diff options
-rw-r--r-- | lib/mako/ast.py | 77 | ||||
-rw-r--r-- | test/ast.py | 37 |
2 files changed, 112 insertions, 2 deletions
diff --git a/lib/mako/ast.py b/lib/mako/ast.py index 6fcd9e7..19a3208 100644 --- a/lib/mako/ast.py +++ b/lib/mako/ast.py @@ -2,6 +2,7 @@ from compiler import ast, parse, visitor from mako import util +from StringIO import StringIO class PythonCode(object): """represents information about a string containing Python code""" @@ -12,12 +13,84 @@ class PythonCode(object): expr = parse(code, "exec") class FindIdentifiers(object): - def visitAssName(s, node, *args, **kwargs): + def visitAssName(s, node, *args): if node.name not in self.undeclared_identifiers: self.declared_identifiers.add(node.name) - def visitName(s, node, *args, **kwargs): + def visitName(s, node, *args): if node.name not in __builtins__ and node.name not in self.declared_identifiers: self.undeclared_identifiers.add(node.name) f = FindIdentifiers() visitor.walk(expr, f) +class walker(visitor.ASTVisitor): + def dispatch(self, node, *args): + print "Node:", str(node) + #print "dir:", dir(node) + return visitor.ASTVisitor.dispatch(self, node, *args) + +class FunctionDecl(object): + """function declaration""" + def __init__(self, code): + self.code = code + + expr = parse(code, "exec") + class ParseFunc(object): + def visitFunction(s, node, *args, **kwargs): + self.funcname = node.name + self.argnames = node.argnames + self.defaults = node.defaults + f = ParseFunc() + visitor.walk(expr, f) + +class ExpressionGenerator(object): + def __init__(self, astnode): + self.buf = StringIO() + visitor.walk(astnode, self) #, walker=walker()) + def value(self): + return self.buf.getvalue() + def operator(self, op, node, *args): + self.buf.write("(") + self.visit(node.left, *args) + self.buf.write(" %s " % op) + self.visit(node.right, *args) + self.buf.write(")") + def visitConst(self, node, *args): + self.buf.write(repr(node.value)) + def visitName(self, node, *args): + self.buf.write(node.name) + def visitMul(self, node, *args): + self.operator("*", node, *args) + def visitAdd(self, node, *args): + self.operator("+", node, *args) + def visitGetattr(self, node, *args): + self.visit(node.expr, *args) + self.buf.write(".%s" % node.attrname) + def visitSub(self, node, *args): + self.operator("-", node, *args) + def visitDiv(self, node, *args): + self.operator("/", node, *args) + def visitSubscript(self, node, *args): + self.visit(node.expr) + self.buf.write("[") + [self.visit(x) for x in node.subs] + self.buf.write("]") + def visitSlice(self, node, *args): + print node, dir(node) + self.visit(node.expr) + self.buf.write("[") + if node.lower is not None: + self.visit(node.lower) + self.buf.write(":") + if node.upper is not None: + self.visit(node.upper) + self.buf.write("]") + def visitCallFunc(self, node, *args): + self.visit(node.node) + self.buf.write("(") + self.visit(node.args[0]) + for a in node.args[1:]: + self.buf.write(", ") + self.visit(a) + self.buf.write(")") + +
\ No newline at end of file diff --git a/test/ast.py b/test/ast.py index d144629..501fc4a 100644 --- a/test/ast.py +++ b/test/ast.py @@ -1,6 +1,7 @@ import unittest from mako import ast, util +from compiler import parse class AstParseTest(unittest.TestCase): def setUp(self): @@ -29,7 +30,43 @@ print "Another expr", c parsed = ast.PythonCode("x + 5 * (y-z)") assert parsed.undeclared_identifiers == util.Set(['x', 'y', 'z']) assert parsed.declared_identifiers == util.Set() + + def test_function_decl(self): + """test getting the arguments from a function""" + code = "def foo(a, b, c=None, d='hi', e=x, f=y+7):pass" + parsed = ast.FunctionDecl(code) + assert parsed.funcname=='foo' + assert parsed.argnames==['a', 'b', 'c', 'd', 'e', 'f'] + def test_expr_generate(self): + """test the round trip of expressions to AST back to python source""" + x = 1 + y = 2 + class F(object): + def bar(self, a,b): + return a + b + def lala(arg): + return "blah" + arg + local_dict = dict(x=x, y=y, foo=F(), lala=lala) + + code = "str((x+7*y) / foo.bar(5,6)) + lala('ho')" + astnode = parse(code) + newcode = ast.ExpressionGenerator(astnode).value() + #print "newcode:" + newcode + #print "result:" + eval(code, local_dict) + assert (eval(code, local_dict) == eval(newcode, local_dict)) + + a = ["one", "two", "three"] + hoho = {'somevalue':"asdf"} + g = [1,2,3,4,5] + local_dict = dict(a=a,hoho=hoho,g=g) + code = "a[2] + hoho['somevalue'] + repr(g[3:5]) + repr(g[3:]) + repr(g[:5])" + astnode = parse(code) + newcode = ast.ExpressionGenerator(astnode).value() + print newcode + print "result:", eval(code, local_dict) + assert(eval(code, local_dict) == eval(newcode, local_dict)) + if __name__ == '__main__': unittest.main() |