# Source code for pyadi.astvisitor

from astunparse import Unparser
import sys, os, inspect, json
from io import StringIO
import time
import random

import astunparse
from astunparse import loadastpy, unparse, unparse2j
from astunparse.astnode import fields
from .nodes import *

# Prefix prepended to every generated temporary-variable name.
tpref_ = 't_'


def setprefix(t):
    """Set the module-wide prefix used by :func:`mkTmpName`.

    :param t: the new prefix; stored in the global ``tpref_``.
    """
    global tpref_
    tpref_ = t


# Registry of temporaries handed out so far (random key -> name); its size
# supplies the numeric suffix of the next temporary name.
tmpseen = {}
class TmpVar(Name):
    """A ``Name`` node whose identifier is a freshly generated temporary name."""

    def __init__(self, kind='t'):
        # NOTE(review): ``kind`` only shapes the generated identifier and is not
        # stored on the instance, yet ASTReolvetmpvars reads ``tree.kind`` — confirm.
        fresh_name = mkTmpName(kind)
        super().__init__(fresh_name)
def mkTmpName(kind='t'):
    """Return a new temporary variable name of the form ``<tpref_><kind><n>``.

    ``n`` is the number of names registered in ``tmpseen`` so far.  The new
    name is recorded in ``tmpseen`` under a random float key; up to three
    draws are made to find a key that is not already in use.
    """
    for _attempt in range(3):
        key = random.random()
        if key not in tmpseen:
            break
    short = f'{tpref_}{kind}{len(tmpseen):d}'
    tmpseen[key] = short
    return short


def mkTmp(kind='t'):
    """Return a ``Name`` node for a new temporary (see :func:`mkTmpName`)."""
    name = mkTmpName(kind)
    return Name(name)
class NotFound(BaseException):
    """Raised when a requested class or function is not present in an AST."""


class NoSource(BaseException):
    """Raised when no Python source is available for a module or function."""
def fname(func):
    """Return ``__qualname__`` of *func*, else ``__name__``, else None."""
    return getattr(func, '__qualname__', getattr(func, '__name__', None))


def fqname(func):
    """Return the fully qualified name ``<module>.<qualname>`` of *func*."""
    mod, _ = getmodule(func)
    return f'{mod}.{fname(func)}'


def fquname(func):
    """Return a per-object unique name ``<module>.<qualname>_at_0x<id>``."""
    mod, _ = getmodule(func)
    return f'{mod}.{fname(func)}_at_0x{id(func):x}'


def fdname(func):
    """Return :func:`fqname` with all ``<locals>`` components removed."""
    return '.'.join([s for s in fqname(func).split('.') if s != '<locals>'])


def fddname(func):
    """Like :func:`fdname`, but with ``d_`` prefixed to the last component."""
    fields = [s for s in fqname(func).split('.') if s != '<locals>']
    fields[-1] = 'd_' + fields[-1]
    return '.'.join(fields)


def rid(func):
    """Return :func:`fqname` with dots replaced by underscores."""
    return fqname(func).replace('.', '_')


def getmodule(func):
    """Return ``(module_name, module_file)`` for *func*.

    The module name is ``func.__module__``, or, if that is missing or None,
    the module of ``func``'s class.  ``module_file`` is the module's
    ``__file__``.  It is None when the module has no file (e.g. builtins) or
    is not present in ``sys.modules``.
    """
    mod = getattr(func, '__module__', None)
    if mod is None:
        mod = func.__class__.__module__
    # .get(): an object may name a module that was never imported or was
    # removed from sys.modules; report "no file" instead of raising KeyError.
    modfile = getattr(sys.modules.get(mod), '__file__', None)
    return mod, modfile


def isbuiltin(func):
    """Return True if *func*'s module has no source file (see :func:`getmodule`)."""
    _, modfile = getmodule(func)
    return modfile is None
# Cache of parsed modules: module name -> {'name', 'file', 'data', 'dict'}.
modastcache = {}


def getmoddict(mod, **opts):
    """Return dictionary with ASTs of module classes and functions.

    Resolve ``from`` imports with :py:func:`.resolveImports`.

    :param mod: name of a module present in ``sys.modules``.
    :param opts: options; ``verbose`` (int) controls diagnostic output.
    :returns: ``(moddict, imports, modules)``.  ``moddict`` maps qualified
        definition names to AST nodes (built by ``ASTVisitorDict``).
        ``imports`` and ``modules`` come from ``ASTVisitorImports``.
    :raises NoSource: if the module has no parseable Python source file.

    Parsed modules are cached in ``modastcache``.
    """
    import tokenize

    verbose = opts.get('verbose', 0)
    modfile = getattr(sys.modules.get(mod, {}), '__file__', None)
    # Extension modules (.so on POSIX, .pyd on Windows) and sourceless
    # bytecode (.pyc) cannot be parsed as Python source.
    if modfile is None or modfile.endswith(('.so', '.pyd', '.pyc')):
        if verbose:
            print(f'No source for module {mod}')
        raise NoSource(f'No source for module {mod}')
    if mod in modastcache:
        centry = modastcache[mod]
        tree, imports, modules = centry["data"]
        moddict = centry["dict"]
    else:
        t0 = time.time()
        # tokenize.open honours PEP 263 coding cookies (default UTF-8)
        # instead of the platform locale encoding.
        with tokenize.open(modfile) as f:
            csrc = f.read()
        tree = loadastpy(csrc)
        imports, modules = ASTVisitorImports()(tree)
        moddict = ASTVisitorDict()(tree)
        # Cache *before* resolving imports: resolveImports calls back into
        # getmoddict, and cyclic imports then hit this cached entry.
        modastcache[mod] = {'name': mod, "file": modfile, "data": (tree, imports, modules), "dict": moddict}
        resolveImports(mod, modfile, moddict, imports, modules, **opts)
        if verbose > 1:
            print(f'Load and parse module {mod} source from {modfile}')
        t1 = time.time()
        if verbose:
            print(f'Load and parse module {mod} source from {modfile}: {1e3*(t1-t0):.1f} ms')
    return moddict, imports, modules
def resolveImports(mod, modfile, moddict, imports, modules, **opts):
    """Merge definitions imported via ``from ... import`` into *moddict*.

    :param mod: dotted name of the importing module.
    :param modfile: its source path (used to detect package ``__init__.py``).
    :param moddict: name -> AST dictionary of *mod*; updated in place.
    :param imports: mapping from ``ASTVisitorImports``.  ``from`` imports
        appear as ``{local_name: {(module, level): imported_name}}``.
    :param modules: names bound by plain ``import`` (not used here).
    :param opts: ``verbose`` (int) controls diagnostic output.
    :returns: *moddict*.
    """
    verbose = opts.get('verbose', 0)
    pkgs = mod.split('.')
    moduleImports = {}
    if verbose > 1:
        print(f'Resolve imports for {mod}, {modfile}')
    # Group the ``from`` imports by source: {(module, level): {local: imported}}.
    for name, impentry in imports.items():
        if isinstance(impentry, dict):
            assert len(impentry.keys()) == 1
            imod = next(iter(impentry))
            moduleImports.setdefault(imod, {})[name] = impentry[imod]
    if verbose > 2:
        print(f'moduleImports={moduleImports}')
    for imod, imodimps in moduleImports.items():
        modname, level = imod
        if modname is None:
            # ``from . import x``: these are all modules
            continue
        if level > 0:
            # Resolve the relative import against the importing package.
            ilevel = level
            if modfile.endswith('__init__.py'):
                ilevel -= 1
            base = pkgs if ilevel <= 0 else pkgs[0:-ilevel]
            # Build a new list: the former ``prepkgs += [...]`` mutated ``pkgs``
            # itself when ilevel <= 0 and corrupted later relative imports.
            modname_ = '.'.join(base + [modname])
            if verbose > 1:
                print(f'Get local import {modname}, {level}, {ilevel}, {pkgs} => {modname_}')
            modname = modname_
        try:
            impd, _, _ = getmoddict(modname)
        except NoSource:
            continue
        for name, impname in imodimps.items():
            if name == "*":
                if verbose > 1:
                    print(f'Import {impname} from import module {modname} as {name} into {mod}')
                moddict.update(impd)
            elif impname in impd:
                if verbose > 1:
                    print(f'Import {impname} from module {modname} into {mod} as {name}')
                moddict[name] = impd[impname]
            elif verbose > 1:
                print(f'Import {impname} from module {modname} into {mod} as {name} is likely a module.')
    return moddict
def getast(func, **kw):
    """Look up the AST of *func*'s definition in its module's source.

    :param func: function or class to look up (by ``__qualname__``).
    :param kw: options forwarded to :func:`getmoddict`.
    :returns: ``(tree, imports, modules)``.
    :raises NoSource: if the module has no source file or *func* is not
        defined in it.
    """
    mod, modfile = getmodule(func)
    if modfile is None:
        message = f'No source for {mod}.{fqname(func)}'
        print(message)
        raise NoSource(message)
    moddict, imports, modules = getmoddict(mod, **kw)
    qualname = fname(func)
    try:
        tree = moddict[qualname]
    except KeyError:
        print(f'No source for {mod}.{qualname}')
        raise NoSource(f'No source for {mod}.{qualname} => {rid(func)}')
    return tree, imports, modules
def py(func, info=False):
    """Return the unparsed, stripped source code of *func*.

    With ``info=True`` return ``(src, imports, modules)`` instead.
    """
    tree, imports, modules = getast(func)
    src = unparse(tree).strip()
    return (src, imports, modules) if info else src
class ASTVisitor:
    """Minimal visitor dispatching on a node's ``_class`` name.

    For a node whose ``_class`` is ``X`` the handler method ``_X`` is called.
    Subclasses provide the handlers.
    """

    def __init__(self):
        pass

    def __call__(self, tree):
        """Visit *tree*, store the outcome in ``self.result`` and return it."""
        self.result = self.dispatch(tree)
        return self.result

    def dispatch(self, tree):
        """Call the ``_<class>`` handler for *tree* and return its result.

        A list is visited element-wise and the list of results is returned.
        Previously the handler's return value was discarded, so ``__call__``
        always produced None.

        :raises AttributeError: if there is no handler for the node's class.
        """
        if isinstance(tree, list):
            return [self.dispatch(t) for t in tree]
        meth = getattr(self, "_" + tree._class)
        return meth(tree)
class ASTVisitorID(ASTVisitor):
    """Identity visitor that rebuilds the tree in place.

    A node with a ``_<class>`` handler is replaced by the handler's result.
    Otherwise each attribute of the node is visited recursively.  Non-node
    values are returned unchanged.
    """

    def dispatch(self, tree):
        if isinstance(tree, list):
            return [self.dispatch(item) for item in tree]
        if not isinstance(tree, ASTNode):
            return tree
        handler = getattr(self, "_" + tree._class, None)
        if not handler:
            for attr in vars(tree):
                setattr(tree, attr, self.dispatch(getattr(tree, attr)))
            return tree
        return handler(tree)
def isop(cn):
    """Return True if node *cn* is a BinOp, UnaryOp, BoolOp or AugAssign."""
    return cn._class in ('BinOp', 'UnaryOp', 'BoolOp', 'AugAssign')


def iscall(cn):
    """Return True if node *cn* is a Call."""
    return cn._class == 'Call'


def iscanon(cn):
    """Return True if *cn* is an operator or call node (subject to hoisting)."""
    if isop(cn):
        return True
    return iscall(cn)
class ASTCanonicalizer:
    """Hoist composite sub-expressions into temporaries.

    In selected positions, operator and call sub-expressions (see
    :func:`iscanon`) are replaced by fresh temporaries from :func:`mkTmp`.
    The matching ``tmp = expr`` assignments are collected in ``self._list``.
    :meth:`processStmts` inserts them before the statement they came from.
    """

    def __init__(self):
        pass

    def __call__(self, tree):
        self._list = []
        result = self.dispatch(tree)
        return result

    def edispatch(self, tree, val=None):
        """Emit ``tmp = <canonicalized tree>`` (or ``tmp = val``).

        :returns: ``(tree, tmp)``.  For non-node input nothing is emitted and
            ``(tree, tree)`` is returned.  This path used to fail with
            UnboundLocalError.
        :raises BaseException: if *tree* is a list.
        """
        if type(tree) == type([]):
            raise(BaseException('error'))
        if not isinstance(tree, ASTNode):
            return (tree, tree)
        tmpv = mkTmp('c')
        tmpas = Assign(tmpv, self.dispatch(tree.clone()) if val is None else val)
        self._list.append(tmpas)
        return (tree, tmpv)

    def processStmts(self, stmts):
        """Canonicalize a statement list, inserting the temporaries' assignments."""
        nbody = []
        assert isinstance(stmts, list)
        for stmt in stmts:
            if stmt._class == "For":
                # Temporaries of the loop iterable must be assigned before the
                # loop header.
                self._list = []
                stmt.iter = self.dispatch(stmt.iter)
                nbody += self._list
            self._list = []
            pstmt = self.dispatch(stmt)
            nbody += self._list
            self._list = []
            nbody += [pstmt]
        return nbody

    def dispatch(self, tree):
        """Canonicalize *tree* in place and return the (possibly replaced) node."""
        if type(tree) == type([]):
            return list(map(self.dispatch, tree))
        if not isinstance(tree, ASTNode):
            return tree
        # Statements with a body: canonicalize their statement lists only.
        if getattr(tree, 'body', None) is not None and tree._class != "Module" and tree._class != "IfExp" and tree._class != "Lambda":
            tree.body = self.processStmts(tree.body)
            if tree._class == "If":
                tree.orelse = self.processStmts(tree.orelse)
            return tree
        elif tree._class == "DictComp" or tree._class == "ListComp":
            # NOTE(review): wraps the dispatched generator list in another list — confirm intended.
            tree.generators = [self.dispatch(tree.generators)]
            return tree
        elif tree._class == "Subscript":
            if iscanon(tree.value):
                _, tmpvar = self.edispatch(tree.value)
                tree.value = tmpvar
            return tree
        for k in fields(tree):
            setattr(tree, k, self.dispatch(getattr(tree, k)))
        if tree._class == "AugAssign" or tree._class == "Attribute":
            if iscanon(tree.value):
                _, tmpvar = self.edispatch(tree.value)
                tree.value = tmpvar
        elif tree._class == "UnaryOp":
            if iscanon(tree.operand):
                _, tmpvar = self.edispatch(tree.operand)
                tree.operand = tmpvar
        elif tree._class == "BinOp" and tree.op == "**":
            # A power expression is itself always hoisted into a temporary.
            _, tmpvar = self.edispatch(tree.clone(), val=tree)
            tree = tmpvar
        elif tree._class == "BinOp":
            if iscanon(tree.left):
                _, tmpvar = self.edispatch(tree.left)
                tree.left = tmpvar
            if iscanon(tree.right):
                _, tmpvar = self.edispatch(tree.right)
                tree.right = tmpvar
        elif tree._class == "List":
            tree.elts = [self.edispatch(e)[1] if isop(e) else self.dispatch(e) for e in tree.elts]
        return tree
def canonicalize(tree, **kw):
    """Return *tree* processed by a fresh :class:`ASTCanonicalizer` (*kw* is unused)."""
    canonicalizer = ASTCanonicalizer()
    return canonicalizer(tree)
class ASTLocalAction:
    """Generic AST transformer with overridable hooks.

    Hooks:

    * ``Begin(tree)`` runs before the traversal.
    * ``End(tree)`` runs after it; a non-None return replaces the result.
    * ``Before(node)`` may return ``(replacement, stop)``.  The node is then
      replaced, and if ``stop`` is true the replacement is returned without
      visiting its children or calling ``After``.
    * ``After(node)`` runs after the children; a non-None return replaces
      the node.
    """

    def __init__(self):
        pass

    def Before(self, t):
        pass

    def After(self, t):
        pass

    def Begin(self, t):
        pass

    def End(self, t):
        pass

    def __call__(self, tree):
        self.Begin(tree)
        result = self.dispatch(tree)
        final = self.End(tree)
        return result if final is None else final

    def dispatch(self, tree):
        """Transform a node, a list of nodes, or a plain value and return it."""
        if type(tree) == type([]):
            return list(map(self.dispatch, tree))
        if not isinstance(tree, ASTNode):
            return tree
        node = tree
        before = self.Before(node)
        if before is not None:
            node, stop = before
            if stop:
                return node
        for attr in vars(node).keys():
            setattr(node, attr, self.dispatch(getattr(node, attr)))
        after = self.After(node)
        return node if after is None else after
class ASTReolvetmpvars(ASTLocalAction):
    """Replace ``TmpVar`` nodes by ``Name`` nodes with short sequential names.

    The same original id always maps to the same short name
    ``<kind><counter>``.
    """

    def Begin(self, tree):
        self.seen = {}

    def Before(self, tree):
        if tree._class != "TmpVar":
            return None
        # NOTE(review): relies on a ``kind`` attribute on the node — confirm it is set.
        if tree.id not in self.seen:
            self.seen[tree.id] = f'{tree.kind}{len(self.seen):d}'
        short = self.seen[tree.id]
        return (Name(f'{short}'), True)


def resolvetmpvars(tree, **kw):
    """Apply :class:`ASTReolvetmpvars` to *tree* (*kw* is unused)."""
    resolver = ASTReolvetmpvars()
    return resolver(tree)
class ASTPatchSuper(ASTLocalAction):
    """Rewrite zero-argument ``super()`` into ``super(self.__class__, self)``.

    ``self`` stands for the first positional parameter of the innermost
    enclosing function.  A call is left untouched when it has no enclosing
    function or that function has no positional parameter; previously these
    cases raised AttributeError/IndexError.

    NOTE(review): ``self.__class__`` is the runtime class, which differs from
    the defining class when the method is reached via a subclass — confirm
    this is acceptable for the generated code.
    """

    def Begin(self, tree):
        self.func = None
        # Enclosing FunctionDef nodes; After() pops when a function is left.
        # Before, ``self.func`` stayed on an inner function after its end.
        self._funcs = []

    def Before(self, tree):
        if tree._class == "FunctionDef":
            self._funcs.append(tree)
            self.func = tree
        elif tree._class == "Call" and getattr(tree.func, 'id', '') == "super":
            selfname = self._firstArg(self.func)
            if selfname is not None and len(tree.args) == 0:
                tree.args = [Attribute(Name(selfname), '__class__'), Name(selfname)]
            return (tree, True)

    def After(self, tree):
        if tree._class == "FunctionDef" and self._funcs and self._funcs[-1] is tree:
            self._funcs.pop()
            self.func = self._funcs[-1] if self._funcs else None

    @staticmethod
    def _firstArg(func):
        """Return the name of *func*'s first positional parameter, or None."""
        if func is None:
            return None
        params = list(getattr(func.args, 'posonlyargs', None) or []) + list(func.args.args)
        return params[0].arg if params else None
def mkOprevName(dict):
    """Build a lookup from operator symbol to lower-cased operator name.

    :param dict: mapping of operator names to symbols (e.g. ``{'Add': '+'}``).
    :returns: a function ``op -> name``; it raises KeyError for unknown symbols.
    """
    by_symbol = {}
    for name, symbol in dict.items():
        by_symbol[symbol] = name.lower()

    def lookup(op):
        return by_symbol[op]
    return lookup


def mkUnOprevName(dict):
    """Build a lookup from lower-cased operator name to operator symbol.

    :param dict: mapping of operator names to symbols (e.g. ``{'Add': '+'}``).
    :returns: a function ``name -> op``; it raises KeyError for unknown names.
    """
    by_name = {}
    for name, symbol in dict.items():
        by_name[name.lower()] = symbol

    def lookup(op):
        return by_name[op]
    return lookup
# Binary, boolean and comparison operator tables merged (later tables win).
_binary_op_names = Unparser.binop | Unparser.boolops | Unparser.cmpops
# symbol -> lower-cased operator name
op_revname = mkOprevName(_binary_op_names)
# unary symbol -> lower-cased operator name
op_revname_unary = mkOprevName(Unparser.unop)
# lower-cased operator name -> symbol, over unary and all binary tables
op_unrevname = mkUnOprevName(Unparser.unop | _binary_op_names)
class ASTReplaceOps(ASTLocalAction):
    """Replace operator nodes by calls to named operator functions.

    The ``replace`` keyword selects the node kinds to rewrite.  The default
    is ``['binop', 'unaryop', 'augassign']``; ``'cmpop'``, ``'boolop'`` and
    ``'attribute'`` are also recognised.  Every rewritten node gets
    ``transformed = True``.
    """

    def __init__(self, **kw):
        self.replace = kw.get('replace', ['binop', 'unaryop', 'augassign'])
        #self.replace += ['attribute']

    def Begin(self, tree):
        self.seen = {}

    def Before(self, tree):
        pass

    def After(self, tree):
        kind = tree._class
        enabled = self.replace
        if kind == "BinOp" and 'binop' in enabled:
            # Multiplication with a constant operand gets a distinct name.
            if tree.op == "*" and tree.left._class == "Constant":
                fn = f'binop_c_{op_revname(tree.op)}'
            elif tree.op == "*" and tree.right._class == "Constant":
                fn = f'binop_d_{op_revname(tree.op)}'
            else:
                fn = f'binop_{op_revname(tree.op)}'
            res = Call(fn, [tree.left, tree.right])
        elif kind == "UnaryOp" and 'unaryop' in enabled:
            res = Call(f'unaryop_{op_revname_unary(tree.op)}', [tree.operand])
        elif kind == "CmpOp" and 'cmpop' in enabled:
            # NOTE(review): passes only tree.value although the inverse expects two args.
            res = Call(f'cmpop_{op_revname(tree.op)}', [tree.value])
        elif kind == "BoolOp" and 'boolop' in enabled:
            res = Call(f'boolop_{op_revname(tree.op)}', tree.values)
        elif kind == "AugAssign" and 'augassign' in enabled:
            res = Assign(tree.target, Call(f'binop_{op_revname(tree.op)}', [tree.target, tree.value]))
        elif kind == "Attribute" and 'attribute' in enabled:
            res = Call(f'getattr', [tree.value, Constant(tree.attr)])
        else:
            return tree
        res.transformed = True
        return res
class ASTReplaceOpsInvert(ASTLocalAction):
    """Undo :class:`ASTReplaceOps`: turn marked operator calls back into operators.

    Only nodes with ``transformed = True`` are considered.  Inverted nodes
    get ``untransformed = True``.  Calls that cannot be inverted (unknown
    operator name or unexpected arity) are left unchanged.
    """

    def __init__(self, **kw):
        self.replace = kw.get('replace', ['binop', 'unaryop', 'augassign'])
        #self.replace += ['attribute']

    def Begin(self, tree):
        self.seen = {}

    def Before(self, tree):
        pass

    @staticmethod
    def _binopName(funcid):
        """Map ``binop_[c_|d_]<name>`` to ``<name>``.

        The ``c_``/``d_`` markers are emitted for multiplication with a
        constant operand.
        """
        name = funcid[len('binop_'):]
        for marker in ('c_', 'd_'):
            if name.startswith(marker):
                return name[len(marker):]
        return name

    def _invert(self, tree):
        """Return the operator node equivalent to *tree*, or None."""
        cls = tree._class
        fid = getattr(tree.func, 'id', '') if cls == "Call" else ''
        if 'binop' in self.replace and fid.startswith('binop_'):
            return BinOp(op_unrevname(self._binopName(fid)), tree.args[0], tree.args[1])
        if 'unaryop' in self.replace and fid.startswith('unaryop_'):
            return UnaryOp(op_unrevname(fid[8:]), tree.args[0])
        if 'cmpop' in self.replace and fid.startswith('cmpop_'):
            return CmpOp(op_unrevname(fid[6:]), tree.args[0], tree.args[1])
        if 'boolop' in self.replace and fid.startswith('boolop_'):
            # NOTE(review): only the first two operands are kept.
            return BoolOp(op_unrevname(fid[7:]), tree.args[0], tree.args[1])
        if 'augassign' in self.replace and cls == "Assign" \
                and len(tree.targets) == 1 and tree.value._class == "Call" \
                and getattr(tree.value.func, 'id', '').startswith('binop_'):
            # The old code applied op_revname (symbol -> name) to the full
            # 'binop_<name>' id, which always raised KeyError, so
            # augmented assignments were never restored.
            opname = self._binopName(tree.value.func.id)
            return AugAssign(op_unrevname(opname), tree.targets[0], tree.value.args[1])
        # ASTReplaceOps emits 'getattr'; 'op_getattr' is accepted as before.
        if 'attribute' in self.replace and fid in ('getattr', 'op_getattr'):
            return Attribute(tree.args[0], tree.args[1].value)
        return None

    def After(self, tree):
        if not getattr(tree, 'transformed', False):
            return tree
        try:
            res = self._invert(tree)
        except (KeyError, IndexError):
            res = None
        if res is None:
            return tree
        res.untransformed = True
        return res
def normalize(tree, **kw):
    """Patch ``super()`` calls.

    With ``replaceops=True``, operators are also replaced by function calls.
    """
    patched = ASTPatchSuper()(tree)
    if not kw.get('replaceops', False):
        return patched
    return ASTReplaceOps()(patched)


def unnormalize(tree, **kw):
    """Undo the operator replacement of :func:`normalize` when ``replaceops=True``."""
    if not kw.get('replaceops', False):
        return tree
    return ASTReplaceOpsInvert()(tree)
class ASTVisitorLastFunction(ASTLocalAction):
    """Select the last ``FunctionDef`` met in traversal order.

    Function bodies are not descended into.
    """

    def Begin(self, tree):
        self.seen = []

    def Before(self, tree):
        if tree._class != "FunctionDef":
            return None
        self.seen.append((tree.name, tree))
        return (tree, True)

    def End(self, tree):
        """Return ``(Module([last_function]), last_name)``.

        Raises IndexError if no function was found.
        """
        last_name, last_func = self.seen[-1]
        return Module([last_func]), last_name


def filterLastFunction(intree):
    """Return ``(Module([f]), name)`` for the last function definition in *intree*."""
    visitor = ASTVisitorLastFunction()
    return visitor(intree)
class ASTVisitorFilterFunctions(ASTLocalAction):
    """Extract the class/function definition at a dotted path such as ``'Cls.meth'``.

    The path components are matched one class/function nesting level at a
    time.  The first definition whose full path matches is cloned.
    """

    def __init__(self, names):
        # Accept a dotted string or a sequence of name components.
        if isinstance(names, str):
            names = names.split('.')
        self.names = names

    def Begin(self, tree):
        self.seen = []   # clones of matched definitions
        self.index = 0   # number of path components matched so far
        self.pos = 0     # current class/function nesting depth

    def Before(self, tree):
        # found, or too deep in tree, prune
        if self.pos > self.index or self.index >= len(self.names):
            return (tree, True)
        # print(f'XSearch {tree._class} for {self.names[self.index]}')
        if tree._class == "FunctionDef" or tree._class == "ClassDef":
            self.pos += 1
            if tree.name == self.names[self.index]:
                # print(f'XFound {tree._class[:-3]} with name {tree.name} at level {self.index}')
                if self.index == len(self.names) -1:
                    # Full path matched.  After() is skipped for this pruned
                    # node, but index >= len(names) now prunes everything else.
                    self.seen.append(tree.clone())
                    self.index += 1
                    return (tree, True)
                else:
                    self.index += 1

    def After(self, tree):
        # Leaving a class/function: restore the nesting depth.
        if tree._class == "FunctionDef" or tree._class == "ClassDef":
            self.pos -= 1

    def End(self, tree):
        """Return a ``Module`` of the matched definitions.

        :raises NotFound: if nothing matched.
        """
        if len(self.seen) == 0:
            raise(NotFound('Class or function not found: ' + '.'.join(self.names)))
        return Module(self.seen)

def filterFunctions(intree, names):
    """Return a ``Module`` holding the definition at dotted path *names* in *intree*."""
    trans = ASTVisitorFilterFunctions(names)
    return trans(intree)
class ASTVisitorLastFunctionSig(ASTLocalAction):
    """Report the name and positional parameter names of the last ``FunctionDef``.

    Function bodies are not descended into.
    """

    def Begin(self, tree):
        self.name = ''
        self.sig = []
        self.seen = []

    def Before(self, tree):
        if tree._class != "FunctionDef":
            return None
        self.sig = [param.arg for param in tree.args.args]
        self.name = tree.name
        return (tree, True)

    def End(self, tree):
        return (self.name, self.sig)


def infoSignature(intree):
    """Return ``(name, [arg names])`` of the last function definition in *intree*."""
    visitor = ASTVisitorLastFunctionSig()
    return visitor(intree)
class ASTVisitorImports(ASTLocalAction):
    """Collect the import statements of a tree.

    ``End`` returns ``(imports, modules)``:

    * ``imports`` maps each bound name to either the imported module name
      (plain ``import``) or ``{(module, level): imported_name}`` (``from``
      import).
    * ``modules`` lists the names bound by plain ``import``.

    NOTE(review): ``from a import *`` and ``from b import *`` share the key
    ``'*'``, so only the last star import is kept.
    """

    def Begin(self, tree):
        self.imports = {}
        self.modules = []

    def Before(self, tree):
        kind = tree._class
        if kind == "ImportFrom":
            source = None if tree.module is None else f'{tree.module}'
            for alias in tree.names:
                bound = alias.asname if alias.asname else alias.name
                self.imports[bound] = {(source, tree.level): f'{alias.name}'}
        elif kind == "Import":
            for alias in tree.names:
                bound = alias.asname if alias.asname else alias.name
                self.imports[bound] = alias.name
                self.modules.append(bound)

    def End(self, tree):
        return self.imports, self.modules
class ASTVisitorDict(ASTLocalAction):
    """Map ``__qualname__``-style paths to ``FunctionDef``/``ClassDef`` nodes.

    Examples: ``f``, ``C.m``, ``f.<locals>.g``, ``f.<locals>.C.m``.
    """

    def Begin(self, tree):
        self.dict = {}
        self.path = []
        self.infunc = 0      # function nesting depth inside the current class scope
        self.infuncs = []    # saved infunc values, one per open class

    def Before(self, tree):
        if tree._class == "FunctionDef":
            self.infunc += 1
            if self.infunc > 1:
                self.path += ['<locals>.' + tree.name]
            else:
                self.path += [tree.name]
            self.dict['.'.join(self.path)] = tree
        elif tree._class == "ClassDef":
            # A class defined inside a function has qualname
            # 'func.<locals>.Cls' (PEP 3155); it was keyed 'func.Cls', so
            # getast() could not find it or its methods.
            if self.infunc > 0:
                self.path += ['<locals>.' + tree.name]
            else:
                self.path += [tree.name]
            self.infuncs += [self.infunc]
            self.infunc = 0
            self.dict['.'.join(self.path)] = tree

    def After(self, tree):
        if tree._class == "FunctionDef":
            self.infunc -= 1
            self.path = self.path[0:-1]
        elif tree._class == "ClassDef":
            self.path = self.path[0:-1]
            self.infunc = self.infuncs[-1]
            self.infuncs = self.infuncs[0:-1]

    def End(self, tree):
        return self.dict
class ASTVisitorLocals(ASTLocalAction):
    """Collect names bound locally in a tree.

    Collected names:

    * function names and their parameters;
    * root names of decorators;
    * assignment targets (names and names in tuples);
    * ``for``/comprehension targets;
    * ``with ... as`` targets.

    ``End`` returns ``(locals, localfuncs)``.
    """

    @classmethod
    def getRoot(cls, t):
        """Strip Attribute/Subscript layers and return the base expression."""
        while t._class == "Attribute" or t._class == "Subscript":
            t = t.value
        return t

    @classmethod
    def getVars(cls, t):
        """Return the variable names bound by target expression *t*."""
        res = []
        if isinstance(t, list) or isinstance(t, tuple):
            for e in t:
                res += ASTVisitorLocals.getVars(e)
        elif t._class == "Tuple" or t._class == "List":
            for e in t.elts:
                res += ASTVisitorLocals.getVars(e)
        elif t._class == "Starred":
            res = ASTVisitorLocals.getVars(t.value)
        elif t._class == "Name":
            res = [t.id]
        elif t._class == "arg":
            res = [t.arg]
        elif t._class == "Attribute" or t._class == "Subscript":
            res = [cls.getRoot(t).id]
        return res

    def Begin(self, tree):
        self.locals = []
        self.localfuncs = []

    def Before(self, tree):
        if tree._class == "FunctionDef":
            self.localfuncs += [tree.name]
            self.locals += [tree.name]
            self.locals += [n.arg for n in tree.args.args]
            for deco in tree.decorator_list:
                self.locals += self.getVars(deco)
            if tree.args.kwarg:
                self.locals += [tree.args.kwarg.arg]
            if tree.args.vararg:
                self.locals += [tree.args.vararg.arg]
            # Positional-only and keyword-only parameters are locals too.
            for extra in ('posonlyargs', 'kwonlyargs'):
                self.locals += [n.arg for n in (getattr(tree.args, extra, None) or [])]
        elif tree._class == "Assign":
            for n in tree.targets:
                if n._class == "Name":
                    self.locals += [n.id]
                elif n._class == "Tuple":
                    self.locals += [m.id for m in n.elts if m._class == "Name"]
        elif tree._class == "For" or tree._class == "comprehension":
            self.locals += self.getVars(tree.target)
        elif tree._class == "With":
            # getVars also handles tuple targets (``with f() as (a, b)``),
            # where the previous ``getRoot(...).id`` raised AttributeError.
            for item in tree.items:
                if item.optional_vars is not None:
                    self.locals += self.getVars(item.optional_vars)

    def End(self, tree):
        return self.locals, self.localfuncs
# NOTE(review): JStructBuilder, JStructUnparser, unparse2Jt and py2jsons are
# not imported in the visible part of this module — verify they are in scope.

def py2pys_check(jdict, visitor):
    """Round-trip a JSON AST through *visitor* and assert nothing changes.

    :param jdict: JSON AST as a dict or a JSON string.
    :param visitor: callable applied to the tree built from *jdict*.
    :returns: the text written by ``JStructUnparser`` for the visited tree.

    Also writes ``in.json`` and ``out.json`` to the working directory.
    """
    if type(jdict) == type(''):
        jdict = json.loads(jdict)
    intree = JStructBuilder(jdict).result
    # Use ``with`` so the debug files are closed (and flushed) immediately;
    # the former ``file=open(...)`` leaked the handles.
    in_text = json.dumps(jdict, indent=1)
    with open('in.json', 'w') as f:
        print(in_text, file=f)
    out_text = unparse2Jt(intree)
    with open('out.json', 'w') as f:
        print(out_text, file=f)
    assert(json.loads(unparse2Jt(intree)) == jdict)
    outtree = visitor(intree)
    assert(json.loads(unparse2Jt(outtree)) == jdict)
    jbuf = StringIO()
    JStructUnparser(outtree, jbuf)
    return jbuf.getvalue()


def py2pys(jdict, visitor):
    """Apply *visitor* to the tree built from JSON AST *jdict* and unparse it."""
    if type(jdict) == type(''):
        jdict = json.loads(jdict)
    intree = JStructBuilder(jdict).result
    outtree = visitor(intree)
    jbuf = StringIO()
    JStructUnparser(outtree, jbuf)
    return jbuf.getvalue()


def py2py(fname):
    """Round-trip the Python file *fname* through the identity visitor."""
    with open(fname, "r") as pyfile:
        source = pyfile.read()
    # py2pys expects a JSON AST and a visitor callable.  The former code
    # passed the raw source and the file name (a str), which cannot work.
    return py2pys(py2jsons(source), ASTVisitorID())


def roundtrip2JIDs(source, fname):
    """Check that *source* survives a JSON round trip through ASTVisitorID."""
    return py2pys_check(py2jsons(source), ASTVisitorID())


def roundtrip2JID(fname):
    """Run :func:`roundtrip2JIDs` on the contents of file *fname*."""
    with open(fname, "r") as pyfile:
        source = pyfile.read()
    return roundtrip2JIDs(source, fname)
def run():
    """Legacy debugging entry point (not called by default)."""
    # NOTE(review): ``ast``, ``ASTFragment``, ``Test`` and ``roundtrip2Js`` are
    # not imported or defined in the visible part of this module, so calling
    # run() likely raises NameError — verify before use.
    fname = 'pyphy/parser.py'
    with open(fname) as f:
        code = f.read()
    output=sys.stdout
    tree = compile(code, fname, "exec", ast.PyCF_ONLY_AST, dont_inherit=True)
    c1 = ASTFragment(tree)
    Unparser(tree, output)
    src = inspect.getsource(Test.energy).strip()
    print(f'prop: "\n{src}"')
    roundtrip2Js(src, 'test_py')
# tree = compile(src, 'test_py', "exec", ast.PyCF_ONLY_AST, dont_inherit=True)
# c2 = ASTFragment(tree)
# UnparserJ(tree, output)
def testdir():
    """Round-trip every ``*.py`` file in ``examples/`` and compare the JSON ASTs."""
    base = 'examples'
    for entry in sorted(os.listdir(base)):
        fname = os.path.join(base, entry)
        if not (os.path.isfile(fname) and fname.endswith('.py')):
            continue
        with open(fname) as handle:
            source = handle.read()
        res = roundtrip2JIDs(source, fname)
        print(f"""File {fname} Source: {source} Result: {res}""")
        assert(py2jsons(source, 'a.py') == py2jsons(res, 'b.py'))
if __name__ == "__main__":
    # Script entry point: run the examples round-trip self test.
    testdir()
    #run()

# (c) 2023 AI & IT UG
# Author: Johannes Willkomm jwillkomm@ai-and-it.de