aboutsummaryrefslogtreecommitdiffstats
path: root/test/ast.py
blob: 3b1ec8b9a97ae90bd1213f88f2deb283a3244689 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import sys
import unittest
from compiler import parse

from mako import ast, util, exceptions

class AstParseTest(unittest.TestCase):
    """Tests for mako.ast: identifier discovery in code blocks and fragments,
    function-signature parsing, and AST-to-source expression regeneration.

    TestCase assertion methods are used instead of bare ``assert`` statements:
    Python strips ``assert`` when run with ``-O``, which would let every check
    in this class pass vacuously.
    """

    def test_locate_identifiers(self):
        """test the location of identifiers in a python code string"""
        code = """
a = 10
b = 5
c = x * 5 + a + b + q
(g,h,i) = (1,2,3)
[u,k,j] = [4,5,6]
foo.hoho.lala.bar = 7 + gah.blah + u + blah
for lar in (1,2,3):
    gh = 5
    x = 12
print "hello world, ", a, b
print "Another expr", c
"""
        parsed = ast.PythonCode(code, 0, 0)
        # x is expected to be undeclared even though the loop assigns it
        # (presumably because it is read first, on the "c = ..." line), and
        # assigning to the attribute foo.hoho.lala.bar is expected not to
        # declare foo.
        self.assertEqual(
            parsed.declared_identifiers,
            util.Set(['a', 'b', 'c', 'g', 'h', 'i', 'u', 'k', 'j', 'gh', 'lar']))
        self.assertEqual(
            parsed.undeclared_identifiers,
            util.Set(['x', 'q', 'foo', 'gah', 'blah']))

        # A bare expression declares nothing; every name it reads is undeclared.
        parsed = ast.PythonCode("x + 5 * (y-z)", 0, 0)
        self.assertEqual(parsed.undeclared_identifiers, util.Set(['x', 'y', 'z']))
        self.assertEqual(parsed.declared_identifiers, util.Set())

    def test_locate_identifiers_2(self):
        """test that imported names and assignment/loop targets are declared"""
        code = """
import foobar
from lala import hoho, yaya
import bleep as foo
result = []
data = get_data()
for x in data:
    result.append(x+7)
"""
        parsed = ast.PythonCode(code, 0, 0)
        self.assertEqual(parsed.undeclared_identifiers, util.Set(['get_data']))
        # "import bleep as foo" is expected to declare the alias foo, not bleep.
        self.assertEqual(
            parsed.declared_identifiers,
            util.Set(['result', 'data', 'x', 'hoho', 'foobar', 'foo', 'yaya']))

    def test_no_global_imports(self):
        """test that 'import *' in a code block raises CompileException"""
        code = """
from foo import *
import x as bar
"""
        try:
            ast.PythonCode(code, 0, 0)
        except exceptions.CompileException:
            # sys.exc_info() works both on old Python 2 (which lacks
            # "except ... as e") and on Python 3 (which lacks "except ..., e").
            message = str(sys.exc_info()[1])
        else:
            # self.fail rather than "assert False" so the check survives -O.
            self.fail("CompileException not raised for 'import *'")
        prefix = "'import *' is not supported"
        self.assertEqual(message[:len(prefix)], prefix)

    def test_python_fragment(self):
        """test identifier discovery in incomplete statement fragments"""
        parsed = ast.PythonFragment("for x in foo:", 0, 0)
        self.assertEqual(parsed.declared_identifiers, util.Set(['x']))
        self.assertEqual(parsed.undeclared_identifiers, util.Set(['foo']))

        # Only checks that a bare "try:" fragment is accepted without raising.
        ast.PythonFragment("try:", 0, 0)

        parsed = ast.PythonFragment("except (MyException, e):", 0, 0)
        self.assertEqual(parsed.declared_identifiers, util.Set(['e']))
        self.assertEqual(parsed.undeclared_identifiers, util.Set())

    def test_function_decl(self):
        """test getting the arguments from a function"""
        code = "def foo(a, b, c=None, d='hi', e=x, f=y+7):pass"
        parsed = ast.FunctionDecl(code, 0, 0)
        self.assertEqual(parsed.funcname, 'foo')
        self.assertEqual(parsed.argnames, ['a', 'b', 'c', 'd', 'e', 'f'])

    def test_function_decl_2(self):
        """test getting the arguments from a function"""
        code = "def foo(a, b, c=None, *args, **kwargs):pass"
        parsed = ast.FunctionDecl(code, 0, 0)
        self.assertEqual(parsed.funcname, 'foo')
        self.assertEqual(parsed.argnames, ['a', 'b', 'c', 'args', 'kwargs'])

    def _assert_round_trip(self, code, namespace):
        """Parse ``code``, regenerate source from the AST with
        ExpressionGenerator, and check that the original and regenerated
        expressions evaluate to the same value using ``namespace`` as globals.
        """
        newcode = ast.ExpressionGenerator(parse(code)).value()
        expected = eval(code, namespace)
        self.assertEqual(eval(newcode, namespace), expected)

    def test_expr_generate(self):
        """test the round trip of expressions to AST back to python source"""
        class F(object):
            def bar(self, a, b):
                return a + b

        def lala(arg):
            return "blah" + arg

        # Arithmetic, method calls and plain function calls.
        self._assert_round_trip(
            "str((x+7*y) / foo.bar(5,6)) + lala('ho')",
            dict(x=1, y=2, foo=F(), lala=lala))

        # Subscripts and the three slice forms.
        self._assert_round_trip(
            "a[2] + hoho['somevalue'] + repr(g[3:5]) + repr(g[3:]) + repr(g[:5])",
            dict(a=["one", "two", "three"],
                 hoho={'somevalue': "asdf"},
                 g=[1, 2, 3, 4, 5]))
        
if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it.
    unittest.main(module="__main__")