aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lib/mako/ast.py77
-rw-r--r--test/ast.py37
2 files changed, 112 insertions, 2 deletions
diff --git a/lib/mako/ast.py b/lib/mako/ast.py
index 6fcd9e7..19a3208 100644
--- a/lib/mako/ast.py
+++ b/lib/mako/ast.py
@@ -2,6 +2,7 @@
from compiler import ast, parse, visitor
from mako import util
+from StringIO import StringIO
class PythonCode(object):
"""represents information about a string containing Python code"""
@@ -12,12 +13,84 @@ class PythonCode(object):
expr = parse(code, "exec")
class FindIdentifiers(object):
- def visitAssName(s, node, *args, **kwargs):
+ def visitAssName(s, node, *args):
if node.name not in self.undeclared_identifiers:
self.declared_identifiers.add(node.name)
- def visitName(s, node, *args, **kwargs):
+ def visitName(s, node, *args):
if node.name not in __builtins__ and node.name not in self.declared_identifiers:
self.undeclared_identifiers.add(node.name)
f = FindIdentifiers()
visitor.walk(expr, f)
+class walker(visitor.ASTVisitor):
+ def dispatch(self, node, *args):
+ print "Node:", str(node)
+ #print "dir:", dir(node)
+ return visitor.ASTVisitor.dispatch(self, node, *args)
+
+class FunctionDecl(object):
+ """function declaration"""
+ def __init__(self, code):
+ self.code = code
+
+ expr = parse(code, "exec")
+ class ParseFunc(object):
+ def visitFunction(s, node, *args, **kwargs):
+ self.funcname = node.name
+ self.argnames = node.argnames
+ self.defaults = node.defaults
+ f = ParseFunc()
+ visitor.walk(expr, f)
+
+class ExpressionGenerator(object):
+ def __init__(self, astnode):
+ self.buf = StringIO()
+ visitor.walk(astnode, self) #, walker=walker())
+ def value(self):
+ return self.buf.getvalue()
+ def operator(self, op, node, *args):
+ self.buf.write("(")
+ self.visit(node.left, *args)
+ self.buf.write(" %s " % op)
+ self.visit(node.right, *args)
+ self.buf.write(")")
+ def visitConst(self, node, *args):
+ self.buf.write(repr(node.value))
+ def visitName(self, node, *args):
+ self.buf.write(node.name)
+ def visitMul(self, node, *args):
+ self.operator("*", node, *args)
+ def visitAdd(self, node, *args):
+ self.operator("+", node, *args)
+ def visitGetattr(self, node, *args):
+ self.visit(node.expr, *args)
+ self.buf.write(".%s" % node.attrname)
+ def visitSub(self, node, *args):
+ self.operator("-", node, *args)
+ def visitDiv(self, node, *args):
+ self.operator("/", node, *args)
+ def visitSubscript(self, node, *args):
+ self.visit(node.expr)
+ self.buf.write("[")
+ [self.visit(x) for x in node.subs]
+ self.buf.write("]")
+ def visitSlice(self, node, *args):
+ print node, dir(node)
+ self.visit(node.expr)
+ self.buf.write("[")
+ if node.lower is not None:
+ self.visit(node.lower)
+ self.buf.write(":")
+ if node.upper is not None:
+ self.visit(node.upper)
+ self.buf.write("]")
+ def visitCallFunc(self, node, *args):
+ self.visit(node.node)
+ self.buf.write("(")
+ self.visit(node.args[0])
+ for a in node.args[1:]:
+ self.buf.write(", ")
+ self.visit(a)
+ self.buf.write(")")
+
+ \ No newline at end of file
diff --git a/test/ast.py b/test/ast.py
index d144629..501fc4a 100644
--- a/test/ast.py
+++ b/test/ast.py
@@ -1,6 +1,7 @@
import unittest
from mako import ast, util
+from compiler import parse
class AstParseTest(unittest.TestCase):
def setUp(self):
@@ -29,7 +30,43 @@ print "Another expr", c
parsed = ast.PythonCode("x + 5 * (y-z)")
assert parsed.undeclared_identifiers == util.Set(['x', 'y', 'z'])
assert parsed.declared_identifiers == util.Set()
+
+ def test_function_decl(self):
+ """test getting the arguments from a function"""
+ code = "def foo(a, b, c=None, d='hi', e=x, f=y+7):pass"
+ parsed = ast.FunctionDecl(code)
+ assert parsed.funcname=='foo'
+ assert parsed.argnames==['a', 'b', 'c', 'd', 'e', 'f']
+ def test_expr_generate(self):
+ """test the round trip of expressions to AST back to python source"""
+ x = 1
+ y = 2
+ class F(object):
+ def bar(self, a,b):
+ return a + b
+ def lala(arg):
+ return "blah" + arg
+ local_dict = dict(x=x, y=y, foo=F(), lala=lala)
+
+ code = "str((x+7*y) / foo.bar(5,6)) + lala('ho')"
+ astnode = parse(code)
+ newcode = ast.ExpressionGenerator(astnode).value()
+ #print "newcode:" + newcode
+ #print "result:" + eval(code, local_dict)
+ assert (eval(code, local_dict) == eval(newcode, local_dict))
+
+ a = ["one", "two", "three"]
+ hoho = {'somevalue':"asdf"}
+ g = [1,2,3,4,5]
+ local_dict = dict(a=a,hoho=hoho,g=g)
+ code = "a[2] + hoho['somevalue'] + repr(g[3:5]) + repr(g[3:]) + repr(g[:5])"
+ astnode = parse(code)
+ newcode = ast.ExpressionGenerator(astnode).value()
+ print newcode
+ print "result:", eval(code, local_dict)
+ assert(eval(code, local_dict) == eval(newcode, local_dict))
+
if __name__ == '__main__':
unittest.main()