Source code for pyadi.pyadi

import os
import inspect
import importlib
import warnings
import numpy as np

from itertools import chain

from astunparse import loadast, unparse2j, unparse2x, unparse
from astunparse.astnode import ASTNode, BinOp, Constant, Name, isgeneric, fields

from .astvisitor import canonicalize, resolvetmpvars, normalize, unnormalize, filterLastFunction
from .astvisitor import infoSignature, filterFunctions, py, getmodule, getast, fqname, fquname, fdname, fddname
from .astvisitor import ASTVisitorID, ASTVisitorImports, ASTVisitorLocals, mkTmp, isbuiltin
from .nodes import *
from .runtime import dzeros, unzd, joind, unjnd, DWith, lzip

from .runtime import binop_add, binop_sub, binop_mult, binop_c_mult, binop_d_mult, binop_matmult, binop_div, binop_floordiv, binop_mod, binop_pow
from .runtime import unaryop_uadd, unaryop_usub
from .runtime import augassign_add, augassign_sub, augassign_mult, augassign_div, augassign_truediv, augassign_mod
from .dtargets import mkActArgFunction, mkKwFunction

from .timer import Timer

from . import astvisitor


Debug = False

dpref_ = 'd_'


def setprefix(diff, tmp, common=''):
    """Set the name prefixes used by the code generator.

    ``common + diff`` is stored in the module-level ``dpref_``, which is
    prepended to names of differentiated entities. ``common + tmp`` is
    forwarded to ``astvisitor.setprefix``.
    """
    global dpref_
    dpref_ = common + diff
    tmp_prefix = common + tmp
    astvisitor.setprefix(tmp_prefix)
dumpDir = '.'


def dumpFile(fname):
    """Return the path of ``fname`` inside the dump directory ``dumpDir``."""
    target = os.path.join(dumpDir, fname)
    return target
def nodiff(tree):
    """Return True when ``tree`` is a ``Constant`` node."""
    is_constant = tree._class == "Constant"
    return is_constant


def isdiff(tree):
    """Return True when ``tree`` is anything other than a ``Constant`` node."""
    if nodiff(tree):
        return False
    return True
class ASTVisitorFMAD(ASTVisitorID):
    """AST visitor producing forward-mode derivative code.

    The inherited ``dispatch`` traversal is entered by ``__call__``; function
    definitions listed in ``active_methods`` are handed to ``ddispatch``,
    which differentiates nodes through the ``_D<NodeClass>`` methods.
    Assignment targets are handled by ``dadispatch`` through the
    ``_Da_<NodeClass>`` methods.
    """

    # Names of the active variables; assigned by differentiate().
    active_fields = []
    # Names of the functions to differentiate; assigned by differentiate()
    # and extended with the local functions found by __call__.
    active_methods = []
    # Names considered local to the code being differentiated (set in __call__).
    localvars = []
    verbose = 0

    def __call__(self, tree):
        """Process the tree.

        Calls dispatch, which will catch only FunctionDefs and enter
        ddispatch traversal when the function is designated as active.
        Calls methods self._XYZ for individual node XYZ handling, there
        is only _FunctionDef.
        """
        self.localvars, self.localfuncs = ASTVisitorLocals()(tree)
        # Rebind instead of ``+=`` so that neither the class-level default
        # list nor a list supplied by the caller is mutated in place.
        self.active_methods = list(self.active_methods) + list(self.localfuncs)
        if self.verbose > 2:
            print(f'Locals of {tree.name}', self.localvars)
        self.result = self.dispatch(tree)
        return self.result

    def dadispatch(self, tree):
        """Traversal used only for the left-hand side of assignments."""
        if isinstance(tree, list):
            return [self.dadispatch(t) for t in tree]
        elif isgeneric(tree):
            return tree
        cname = tree._class
        meth = getattr(self, "_Da_" + cname, None)
        # all nodes must be handled by a method
        assert meth, f'self.{"_Da_"+cname} not found'
        if meth:
            return meth(tree)
        else:
            # only reachable when assertions are disabled
            res = ASTNode()
            for name in vars(tree).keys():
                delem = self.dadispatch(getattr(tree, name))
                setattr(res, name, delem)
            return res

    def _Da_Name(self, t):
        # Non-local targets get the throw-away name ``<prefix>_``.
        if not self.isLocal(t):
            return Name(dpref_ + '_')
        return self.ddispatch(t)

    def _Da_Attribute(self, t):
        if not self.isLocal(t):
            return Name(dpref_ + '_')
        return self.ddispatch(t)

    def _Da_Subscript(self, t):
        if not self.isLocal(t):
            return Name(dpref_ + '_')
        return self.ddispatch(t)

    def _Da_Tuple(self, t):
        t.elts = self.dadispatch(t.elts)
        return t

    def ddispatch(self, tree):
        """The main workhorse, the differentiation traversal.

        Calls methods self._DXYZ for individual node XYZ handling. Nodes
        without such a method are rebuilt as a plain ``ASTNode`` whose fields
        are the differentiated fields of the original.
        """
        if isinstance(tree, list):
            return [self.ddispatch(t) for t in tree]
        elif isgeneric(tree):
            return tree
        cname = tree._class
        meth = getattr(self, "_D" + cname, None)
        if meth:
            return meth(tree)
        else:
            res = ASTNode()
            for name in vars(tree).keys():
                delem = self.ddispatch(getattr(tree, name))
                setattr(res, name, delem)
            return res

    # Function names whose bare call statements are copied, not differentiated.
    nodiffFunctions = []
    # Statement classes that are copied, not differentiated.
    nodiffExpr = ["Raise", "Assert"]

    def isnodiffExpr(self, item):
        """Return True when statement ``item`` must not be differentiated.

        That is the case for expression statements calling a function named
        in ``nodiffFunctions`` and for statements whose class is listed in
        ``nodiffExpr``.
        """
        res = False
        if item._class == "Expr":
            v = item.value
            if v._class == "Call":
                if getattr(v.func, 'id', False):
                    res = v.func.id in self.nodiffFunctions
        elif item._class in self.nodiffExpr:
            res = True
        return res

    def diffStmtList(self, body):
        """Differentiate a list of statements and return the new list.

        Derivative statements are emitted ahead of the original statement.
        For powers with a non-constant exponent the value ``x ** y`` is first
        stored in a temporary (``self.tmpval``) that the derivative code and
        the final assignment reuse.
        """
        nbody = []
        self.tmpval = None
        for item in body:
            if item._class == "Assign":
                atargets = [s for s in item.targets if self.isLocal(s)]
                if len(atargets) < len(item.targets):
                    natargets = [s for s in item.targets if not self.isLocal(s)]
                    warnings.warn(f'Assignment to non-local locations {[str(v).strip() for v in natargets]} cannot be handled, the derivative may be wrong')
                if len(atargets) > 0:
                    ispow = (item.value._class == "BinOp" and item.value.op == "**"
                             and isdiff(item.value.right))
                    if ispow:
                        self.tmpval = mkTmp('s')
                        nbody += [Assign(self.tmpval, self.mkOpPartialC("**", None, None, item.value.left, item.value.right))]
                    nbody += [self.ddispatch(item.clone())]
                    if ispow:
                        # the temporary already holds the value, reuse it
                        nbody += [Assign(item.targets, self.tmpval)]
                        self.tmpval = None
                        continue
                nbody += [self.dispatch(item)]
            elif item._class == "FunctionDef":
                nbody += [self._DFunctionDef(item.clone())]
                nbody += [item]
            elif item._class == "AugAssign":
                if item.op == "**" and isdiff(item.value):
                    self.tmpval = mkTmp('s')
                    nbody += [Assign(self.tmpval, self.mkOpPartialC("**", None, None, item.target, item.value))]
                if item.op in ['+', '-']:
                    if isdiff(item.value):
                        nbody += [self.ddispatch(item.clone())]
                elif item.op == '|':
                    nbody += [self.ddispatch(item.clone())]
                elif item.op == '//':
                    nbody += [Assign(self.ddispatch(item.target.clone()), Constant(0))]
                else:
                    if item.op == '*' or item.op == '@' or item.op == '/':
                        nbody += [AugAssign(item.op, self.ddispatch(item.target.clone()), item.value)]
                    elif item.op == '**':
                        nbody += [AugAssign('*', self.ddispatch(item.target.clone()), self.mkOpPartialL1(item.op, None, item.target, item.value))]
                    if isdiff(item.value):
                        op = '+'
                        rhspartial = self.mkOpPartialR(item.op, None, self.ddispatch(item.value.clone()), item.target, item.value)
                        if rhspartial._class == "UnaryOp":
                            op = rhspartial.op
                            rhspartial = rhspartial.operand
                        nbody += [AugAssign(op, self.ddispatch(item.target.clone()), rhspartial)]
                if item.op == "**" and self.tmpval is not None:
                    nbody += [Assign(item.target, self.tmpval)]
                    # Forget the temporary: a stale value would otherwise
                    # replace a later ``z **= <constant>`` statement.
                    self.tmpval = None
                else:
                    nbody += [item]
            elif self.isnodiffExpr(item):
                nbody += [self.dispatch(item)]
            else:
                nbody += [self.ddispatch(item)]
        return nbody

    def _FunctionDef(self, t):
        return self.ddispatch(t)

    def _DFunctionDef(self, t):
        """Differentiate a function definition if it is an active method."""
        if t.name in self.active_methods:
            t.args = self.ddispatch(t.args)
            t.body = self.diffStmtList(t.body)
            prestmts = []
            # Use the configured prefix (not a literal 'd_') so that the names
            # match those produced by _DName after setprefix() was called.
            if t.args.kwarg:
                prestmts += [Assign(Tuple([Name(dpref_ + t.args.kwarg.arg), Name(t.args.kwarg.arg)]),
                                    Call('unjnd', Name(t.args.kwarg.arg)))]
            if t.args.vararg:
                prestmts += [Assign(Tuple([Name(dpref_ + t.args.vararg.arg), Name(t.args.vararg.arg)]),
                                    Tuple([Subscript(Name(t.args.vararg.arg), Slice(0, s=2)),
                                           Subscript(Name(t.args.vararg.arg), Slice(1, s=2))]))]
            t.body = prestmts + t.body
            decos = []
            for d in t.decorator_list:
                if d._class == "Name":
                    decos += [Lambda('f', Subscript(Call(Call('D', d), Tuple([Name('f'), Name(t.name)])), 0))]
                else:
                    decos += [Lambda('f', Subscript(Call(Subscript(self.diffUnlessIsTupleDiff(d), 0), [Name('f'), Name(t.name)]), 0))]
            t.decorator_list = decos
            t.name = dpref_ + t.name
        else:
            t.args = self.dispatch(t.args)
            t.body = self.dispatch(t.body)
        return t

    def _DJoinedStr(self, t):
        t.values = [self.ddispatch(s) if s._class != "Constant" else s for s in t.values]
        return t

    def _DFormattedValue(self, t):
        t.value = self.diffUnlessIsTupleDiff(t.value)
        return t

    def _DIfExp(self, node):
        node.body = self.diffUnlessIsTupleDiff(node.body)
        node.orelse = self.diffUnlessIsTupleDiff(node.orelse)
        return node

    def _DGeneratorExp(self, t):
        t.elt = self.diffUnlessIsTupleDiff(t.elt)
        t.generators = self.ddispatch(t.generators)
        return t

    def _DLambda(self, node):
        """Differentiate a lambda; its parameters are local only inside it."""
        saved = self.localvars
        # Build a new list: extending in place would also change ``saved``
        # and leak the lambda parameters into the enclosing scope.
        self.localvars = list(saved) + list(ASTVisitorLocals.getVars(node.args.args))
        try:
            node.args = self.ddispatch(node.args)
            node.body = self.diffUnlessIsTupleDiff(node.body)
        finally:
            self.localvars = saved
        return node

    # Node classes whose derivative already evaluates to (derivative, value).
    tupleDiff = ["Call", "List", "ListComp", "Dict", "DictComp", "IfExp", "GeneratorExp", 'Starred']

    def diffUnlessIsTupleDiff(self, t, src=None):
        """Return an expression yielding the pair (derivative, value) of ``t``.

        Nodes listed in ``tupleDiff`` are differentiated directly, tuples are
        zipped element-wise, and everything else is wrapped as
        ``(d_t, t)``. ``src`` is the enclosing node; for a list that is an
        argument of a call the starred ``lzip`` wrapper is unpacked.
        """
        if t._class in self.tupleDiff:
            res = self.ddispatch(t.clone())
            if src and src._class == 'Call':
                if t._class == "List":
                    res = res.elts[0].value
            return res
        elif t._class == "Tuple":
            dargs = [self.diffUnlessIsTupleDiff(e) for e in t.elts]
            return Tuple([Starred(Call('zip', dargs))])
        else:
            return Tuple([self.ddispatch(t.clone()), t])

    def _DList(self, node):
        if len(node.elts):
            dargs = [self.diffUnlessIsTupleDiff(t) for t in node.elts]
            return List([Starred(Call('lzip', dargs))])
        return Tuple([List([]), List([])])

    # NOTE(review): the double underscore name-mangles this method, so
    # ddispatch never finds it as ``_DTuple``; tuples are handled in
    # diffUnlessIsTupleDiff instead. Presumably intentional, please confirm.
    def __DTuple(self, node):
        if len(node.elts):
            dargs = [self.diffUnlessIsTupleDiff(t) for t in node.elts]
            return Tuple([Starred(Call('zip', dargs))])
        return Tuple([List([]), List([])])

    def _DDict(self, node):
        node.values = [self.diffUnlessIsTupleDiff(t) for t in node.values]
        return Tuple([Starred(Call('unzd', node))])

    def _DDictComp(self, node):
        node.value = self.diffUnlessIsTupleDiff(node.value)
        node.generators = self.ddispatch(node.generators)
        return Call('unzd', [node])

    def _DListComp(self, node):
        node.elt = self.diffUnlessIsTupleDiff(node.elt)
        node.generators = self.ddispatch(node.generators)
        return Call('lzip', [Starred(node)])

    def _Dcomprehension(self, node):
        self._DForCommon(node)
        return node

    def _DIf(self, node):
        node.body = self.diffStmtList(node.body)
        node.orelse = self.diffStmtList(node.orelse)
        return node

    def _DWhile(self, node):
        node.body = self.diffStmtList(node.body)
        return node

    def _DForCommon(self, node):
        """Rewrite target and iterable so the loop runs over (derivative, value) pairs."""
        tnode = Tuple([self.ddispatch(node.target.clone()), node.target])
        if node.iter._class == "Call":
            itnode = Call('zip', [Starred(self.ddispatch(node.iter))])
        else:
            itnode = Call('zip', [self.ddispatch(node.iter.clone()), node.iter])
        node.target = tnode
        node.iter = itnode

    def _DFor(self, node):
        node.body = self.diffStmtList(node.body)
        self._DForCommon(node)
        return node

    def _Darguments(self, node):
        """Interleave derivative parameters (and defaults) with the originals."""
        assert isinstance(node.args, list)
        dargs = []
        curargs = node.args
        for t in curargs:
            if t.arg in self.localvars:
                tr1 = self.ddispatch(t.clone())
                dargs += [tr1]
        node.args = list(chain(*zip(dargs, curargs)))
        ddefs = self.ddispatch(node.defaults)
        node.defaults = list(chain(*zip(ddefs, node.defaults)))
        return node

    nonder_builtins = ['next']
    # Functions called directly with the differentiated arguments, without D().
    nondercall_builtins = ['next']

    def _DCall(self, t):
        """Differentiate a call ``f(args)`` to ``D(f)(dargs)`` or ``Dc((d_f, f))(dargs)``."""
        curargs = t.args
        if t.func._class == "Call" or self.isLocal(t.func):
            dcall = Call(Name('Dc'))
            dcall.args = [self.diffUnlessIsTupleDiff(t.func)]
        else:
            dcall = Call(Name('D'))
            dcall.args = [t.func]
        fcallname = getattr(t.func, 'id', '')
        if fcallname in self.nondercall_builtins:
            res = Call(fcallname)
        else:
            res = Call(dcall)
        dargs = [self.diffUnlessIsTupleDiff(a, t) for a in curargs]
        res.args = dargs
        res.keywords = self.diffKeywords(t.keywords)
        return res

    def diffKeywords(self, keywords):
        """Differentiate call keywords into a single ``**joind(*zip(...))`` keyword."""
        res = []
        if len(keywords) == 0:
            return res
        for key in keywords:
            if key.arg is None:
                res += [self.diffUnlessIsTupleDiff(key.value)]
            else:
                res += [Call('unzd', Call('dict', keyword(key.arg, self.diffUnlessIsTupleDiff(key.value))))]
        zcall = Call('joind', Starred(Call('zip', res)))
        return [keyword(None, zcall)]

    def _Darg(self, t):
        if t.arg in self.localvars:
            t.arg = dpref_ + t.arg
        return t

    def _Dkeyword(self, t):
        t.arg = dpref_ + t.arg
        t.value = self.ddispatch(t.value)
        return t

    def _DAssign(self, t):
        """Differentiate an assignment; only local targets receive derivatives."""
        atargets = [s for s in t.targets if self.isLocal(s)]
        if t.value._class in self.tupleDiff:
            # the value yields (derivative, value): assign both at once
            t.targets = [Tuple(self.dadispatch([s.clone() for s in atargets]) + self.dispatch(t.targets))]
        else:
            t.targets = [self.dadispatch(s.clone()) for s in atargets]
        isList = t.value._class == "List" or t.value._class == "Dict"
        t.value = self.ddispatch(t.value)
        if isList:
            if hasattr(t.value, 'elts') and t.value.elts[0]._class == "Starred":
                t.value = t.value.elts[0].value
        return t

    def _DName(self, t):
        if t.id in self.localvars:
            t.id = dpref_ + t.id
            return t
        return Call('dzeros', t)

    def getRoot(self, t):
        """Return the innermost value of a chain of attributes and subscripts."""
        if t._class == "Attribute" or t._class == "Subscript":
            return self.getRoot(t.value)
        return t

    def isLocal(self, t):
        """Return True when the root name of ``t`` (or of any tuple element) is local."""
        if t._class == "Tuple":
            return any([self.isLocal(s) for s in t.elts])
        return getattr(self.getRoot(t), 'id', '') in self.localvars

    def _DStarred(self, node):
        node.value = Call('zip', [Starred(self.diffUnlessIsTupleDiff(node.value))])
        return node

    def _DSubscript(self, node):
        if not self.isLocal(node.value):
            return Call('dzeros', node)
        node.value = self.ddispatch(node.value)
        return node

    def _DAttribute(self, t):
        if not t.value._class == "Call" and not self.isLocal(t.value):
            return Call('dzeros', t)
        t.value = self.ddispatch(t.value)
        return t

    def _DConstant(self, t):
        # numeric constants differentiate to zero, everything else is kept
        if isinstance(t.value, float) or isinstance(t.value, int):
            t = t.clone()
            t.value = 0
        elif isinstance(t.value, complex):
            t = t.clone()
            t.value = 0j
        return t

    def mkOpPartialC(self, op, r, dx, x, y):
        """Return ``r`` if given, else the expression ``x ** y``.

        Raises ``ValueError`` for operators other than ``**``.
        """
        if op != '**':
            raise ValueError(f'mkOpPartialC: unsupported operator {op!r}')
        if r is None:
            t = BinOp(op, x, y)
        else:
            t = r
        return t

    def mkOpPartialL1(self, op, r, x, y):
        """Return the partial derivative of ``x ** y`` w.r.t. ``x``.

        Raises ``ValueError`` for operators other than ``**``.
        """
        if op == '**':
            if y._class == "Constant":
                if y.value == 2:
                    t = BinOp('*', y, x)
                else:
                    t = BinOp('*', y, BinOp('**', x, Constant(y.value - 1)))
            else:
                t = BinOp('*', y, BinOp('**', x, BinOp('-', y, Constant(1))))
        else:
            raise ValueError(f'mkOpPartialL1: unsupported operator {op!r}')
        return t

    def mkOpPartialL(self, op, r, dx, x, y):
        """Return the derivative term of ``x op y`` contributed by ``dx``.

        Supports ``/``, ``%`` and ``**``; raises ``ValueError`` otherwise.
        """
        if op == '/':
            t = BinOp('/', dx, y)
        elif op == '%':
            t = dx
        elif op == '**':
            p1 = self.mkOpPartialL1(op, r, x, y)
            t = BinOp('*', p1, dx)
        else:
            raise ValueError(f'mkOpPartialL: unsupported operator {op!r}')
        return t

    def mkOpPartialR(self, op, r, dy, x, y):
        """Return the derivative term of ``x op y`` contributed by ``dy``.

        Supports ``*``, ``@``, ``/``, ``%`` and ``**``; raises ``ValueError``
        otherwise. For ``**`` the temporary ``self.tmpval`` is used in place
        of ``x ** y`` when one is set.
        """
        if op == '*' or op == '@':
            t = BinOp(op, x, dy)
        elif op == '/':
            sq = BinOp('**', y, Constant(2))
            right_ = BinOp('*', x, dy)
            t = UnaryOp('-', BinOp('/', right_, sq))
        elif op == '%':
            quot = BinOp('/', x, y)
            t = UnaryOp('-', BinOp('*', Call('math.floor', [quot]), dy))
        elif op == '**':
            t = BinOp('*', Call('log', [x]), dy)
            t = BinOp('*', self.tmpval if self.tmpval is not None else BinOp('**', x, y), t)
        else:
            raise ValueError(f'mkOpPartialR: unsupported operator {op!r}')
        return t

    def _DBinOp(self, t):
        """Differentiate a binary operation by the usual derivative rules."""
        if nodiff(t.left) and nodiff(t.right):
            return Constant(0)
        if nodiff(t.left):
            left = t.left
        else:
            left = self.ddispatch(t.left.clone())
        if nodiff(t.right):
            right = t.right
        else:
            right = self.ddispatch(t.right.clone())
        if t.op == '*' or t.op == '@':
            if isdiff(t.left) and isdiff(t.right):
                # product rule
                left_ = BinOp(t.op, left, t.right)
                right_ = BinOp(t.op, t.left, right)
                t = BinOp('+', left_, right_)
            else:
                t.left = left
                t.right = right
        elif t.op == '/':
            if isdiff(t.right):
                right_ = self.mkOpPartialR('/', None, right, t.left, t.right)
                if isdiff(t.left):
                    left_ = self.mkOpPartialL('/', None, left, t.left, t.right)
                    t = BinOp('-')
                    t.left = left_
                    t.right = right_.operand
                else:
                    t = right_
            elif isdiff(t.left):
                t.left = left
        elif t.op == '//':
            t = Constant(0)
        elif t.op == '%':
            # string formatting with ``%`` is returned unchanged
            if t.right._class == "Tuple" or (t.left._class == "Constant" and isinstance(t.left.value, str)):
                return t
            if isdiff(t.left):
                lfact_ = self.mkOpPartialL('%', None, left, t.left, t.right)
            if isdiff(t.right):
                rfact_ = self.mkOpPartialR('%', None, right, t.left, t.right)
                if isdiff(t.left):
                    t = BinOp('+')
                    t.left = lfact_
                    t.right = rfact_
                else:
                    t = rfact_
            elif isdiff(t.left):
                t = lfact_
        elif t.op == '**':
            if isdiff(t.left):
                lder = self.mkOpPartialL('**', None, left, t.left, t.right)
            if isdiff(t.right):
                rder = self.mkOpPartialR('**', None, right, t.left, t.right)
                if isdiff(t.left):
                    term = BinOp('+', lder, rder)
                else:
                    term = rder
            elif isdiff(t.left):
                term = lder
            t = term
        elif t.op == '+' or t.op == '-':
            if nodiff(t.left):
                left = self.ddispatch(t.left.clone())
            if nodiff(t.right):
                right = self.ddispatch(t.right.clone())
            t.left = left
            t.right = right
        else:
            t.left = self.dispatch(t.left)
            t.right = self.dispatch(t.right)
        return t

    def _DReturn(self, t):
        t.value = self.diffUnlessIsTupleDiff(t.value)
        return t

    def _DYield(self, t):
        t.value = self.diffUnlessIsTupleDiff(t.value)
        return t

    def _DTry(self, t):
        t.body = self.diffStmtList(t.body)
        t.handlers = self.ddispatch(t.handlers)
        return t

    def _DExceptHandler(self, t):
        t.body = self.diffStmtList(t.body)
        return t

    def _DWith(self, t):
        t.items = self.ddispatch(t.items)
        t.body = self.diffStmtList(t.body)
        return t

    def _Dwithitem(self, t):
        t.context_expr = Call('DWith', self.diffUnlessIsTupleDiff(t.context_expr))
        if t.optional_vars:
            t.optional_vars = self.diffUnlessIsTupleDiff(t.optional_vars)
        return t

    def _DDelete(self, t):
        dtargets = [self.ddispatch(s.clone()) for s in t.targets if self.isLocal(s)]
        t.targets = dtargets + t.targets
        return t

    def _DNonlocal(self, t):
        # use the configured prefix so the names match those made by _DName
        dtargets = [dpref_ + s for s in t.names]
        t.names = dtargets + t.names
        return t
def diff2pys(intree, visitor, **kw):
    """Differentiate the AST ``intree`` with ``visitor`` and return the new AST.

    A clone of the tree is normalized and canonicalized, handed to
    ``visitor``, and a clone of the result is unnormalized. ``kw`` is passed
    to ``normalize`` and ``unnormalize``; ``verbose`` controls diagnostics.
    """
    verbosity = kw.get('verbose', 0)
    if verbosity > 2:
        print(f'Input code for {getattr(intree, "name", "")}:', unparse(intree))
    prepared = canonicalize(normalize(intree.clone(), **kw))
    if verbosity > 1:
        print(f'Preprocessed code for {getattr(prepared, "name", "")}:', unparse(prepared))
    differentiated = visitor(prepared)
    return unnormalize(differentiated.clone(), **kw)
def differentiate(intree, activef=None, active=None, modules=None, filter=False, prefix=None, **kw):
    """Differentiate the AST ``intree`` and return the derivative AST.

    Parameters
    ----------
    intree : AST node
        The code to differentiate.
    activef : str or list, optional
        Names of the functions to differentiate; by default the last
        function in ``intree``.
    active : str or list, optional
        Names of the active variables; by default the signature reported by
        ``infoSignature``.
    modules : optional
        Imported modules; determined with ``ASTVisitorImports`` when None.
    filter : bool
        When true and ``activef`` is given, reduce ``intree`` with
        ``filterFunctions`` first.
    prefix : sequence of str, optional
        Up to three items passed to :py:func:`setprefix`; missing items
        default to ``''``.
    kw : dict
        Passed to :py:func:`diff2pys`; ``verbose`` sets the verbosity.
    """
    fmadtrans = ASTVisitorFMAD()
    fmadtrans.verbose = kw.get('verbose', 0)

    if prefix:
        # Pad a copy: the caller's sequence is left untouched and tuples work too.
        padded = list(prefix) + [''] * (3 - len(prefix))
        setprefix(*padded)

    if modules is None:
        _, modules = ASTVisitorImports()(intree)
    fmadtrans.imports = modules

    if activef is None:
        intree, fname = filterLastFunction(intree)
        fmadtrans.active_methods = [fname]
    else:
        fmadtrans.active_methods = varspec(activef)
        if filter:
            intree = filterFunctions(intree, activef)

    if not active:
        fname, sig = infoSignature(intree)
        fmadtrans.active_fields = sig
    else:
        fmadtrans.active_fields = varspec(active)

    return diff2pys(intree, fmadtrans, **kw)
def diff2py(fname):
    """Read the Python file ``fname`` and return its differentiated source code.

    The source text is handed to :py:func:`diff2pys2s`, which parses,
    differentiates and unparses it. (Previously the text and the file name
    were passed to :py:func:`diff2pys`, which expects an AST and a visitor.)
    """
    with open(fname, "r") as pyfile:
        source = pyfile.read()
    return diff2pys2s(source, fname)
def diff2pys2s(source, fname):
    """Differentiate Python source text and return the derivative source text.

    ``fname`` is accepted but not used.
    """
    tree = loadast(source)
    dtree = differentiate(tree)
    return unparse(dtree)
def execompile(source, fglobals=None, flocals=None, imports=['math', 'sys', 'os', {'pyadi': 'D'}], vars=['x'], fname='', **kw):
    """Compile and execute ``source`` and return the values of ``vars``.

    Parameters
    ----------
    source : str
        Python code to execute.
    fglobals : dict, optional
        Extra globals, merged over this module's globals.
    flocals : dict, optional
        Local namespace for the execution. A fresh dict is used for each
        call when omitted, so definitions no longer leak between calls.
    imports : list
        Not used.
    vars : list of str
        Names to collect after execution.
    fname : str
        File name recorded in the compiled code.
    kw : dict
        Ignored.

    Returns
    -------
    dict
        Maps each name in ``vars`` to its value. The values are also stored
        in the globals of the executed code, so a function defined there
        can refer to itself.

    Raises
    ------
    SyntaxError
        When ``source`` does not compile.
    """
    if fglobals is None:
        fglobals = {}
    if flocals is None:
        flocals = {}
    # After the code has run, copy the requested names into _pyadi_data.
    collectstr = '\n'.join([f'_pyadi_data["{name}"] = {name}' for name in vars])
    dsrc = f"{source}\n{collectstr}"
    res = compile(dsrc, fname, "exec")
    gvars = globals() | fglobals | {'_pyadi_data': {}}
    exec(res, gvars, flocals)
    result = {name: gvars["_pyadi_data"][name] for name in vars}
    gvars |= result
    return result
def Dpy(func, active=[], **kw):
    """Return the differentiated AST of the function ``func``.

    The AST and module information come from ``getast``; ``func.__name__``
    is the active function and ``active`` names the active variables.
    """
    astinfo = getast(func, **kw)
    csrc, _imports, modules = astinfo
    return differentiate(csrc, activef=func.__name__, active=active, modules=modules, **kw)
def mkClosDict(function):
    """Return dictionary of closure variables of ``function``.

    The result maps each free variable name to the current content of its
    closure cell; it is empty when ``function`` has no closure.
    """
    cells = getattr(function, '__closure__', None)
    if cells is None:
        return {}
    names = function.__code__.co_freevars
    return {names[pos]: cell.cell_contents for pos, cell in enumerate(cells)}
def _dumpFailed(dtree, dsrc=None):
    """Write the failed derivative AST (and code, if given) to ``d_failed.*``."""
    with open('d_failed.json', 'w') as fh:
        print(unparse2j(dtree, indent=1), file=fh)
    with open('d_failed.xml', 'w') as fh:
        print(unparse2x(dtree, indent=1), file=fh)
    if dsrc is not None:
        with open('d_failed.py', 'w') as fh:
            print(dsrc, file=fh)


def difffunction(func, active=[], **kw):
    """Differentiate ``func`` in source code and load the result.

    The AST from :py:func:`Dpy` is unparsed and executed with
    :py:func:`execompile`, using the globals of ``func``, the optional
    ``globals`` entry of ``kw`` and the closure variables of ``func``. The
    globals of the new function are recorded in ``adglobalsc``.

    With ``dump > 0`` in ``kw`` the generated code is written to the dump
    directory. When unparsing or compiling fails, the AST (and code) is
    written to ``d_failed.*`` and the exception is re-raised.

    Returns
    -------
    tuple
        ``(dfunc, active)``, the differentiated function and ``active``.
    """
    global adglobalsc
    dtree = Dpy(func, active, **kw)
    try:
        dsrc = unparse(dtree)
    except BaseException as ex:
        _dumpFailed(dtree)
        print(f"""Failed to unparse diff code, exception:
{ex}
Source:
{py(func)}
""")
        raise

    if kw.get('verbose', 0) > 1:
        print(f'Diff code {fdname(func)}:{dsrc}')

    sfname = ''
    if kw.get('dump', 0) > 0:
        sfname = dumpFile(fddname(func) + '.py')
        with open(sfname, 'w') as fh:
            print(dsrc, file=fh)

    cl_data = mkClosDict(func)
    if cl_data and kw.get('verbose', 0) > 1:
        print(f'difffunction: Function {func} has a closure: {cl_data.keys()}')

    fkey = dpref_ + func.__name__
    gvars = func.__globals__ | kw.get('globals', {}) | cl_data

    try:
        dfunc = execompile(dsrc, vars=[fkey], fglobals=gvars, fname=sfname, **kw)
    except BaseException as ex:
        _dumpFailed(dtree, dsrc)
        print(f"""Failed to compile diff code, exception:
{ex}
Diff code:
{dsrc}
Source:
{py(func)}
""")
        raise

    dfunc = dfunc[fkey]
    adglobalsc[fqname(func)] = dfunc.__globals__
    return (dfunc, active)


# Globals of the loaded derivative functions, keyed by fqname(func).
adglobalsc = {}
def fid(func, active):
    """Return an identifier string for ``func`` and the active spec ``active``.

    The string is ``qualname:origin:repr(active)``, where origin is the
    module file reported by ``getmodule``, or the module when there is no file.
    """
    mod, modfile = getmodule(func)
    origin = mod if modfile is None else modfile
    return f'{func.__qualname__}:{origin}:{repr(active)}'
def getsig(f):
    """Return the list of parameter names of ``f``."""
    return list(inspect.signature(f).parameters)


def varspec(x):
    """Return a variable specification as a list.

    A string is split at commas; anything else must already be a list.
    """
    names = x.split(',') if isinstance(x, str) else x
    assert isinstance(names, list)
    return names


def parinds(f, x):
    """Return the parameter indices of ``f`` selected by ``x``.

    When ``x`` names parameters (string or list of strings), the positions
    of those names in the signature of ``f`` are returned in signature
    order. Any other list is returned as is.
    """
    spec = varspec(x)
    if not (len(spec) > 0 and isinstance(spec[0], str)):
        return spec
    return [pos for pos, name in enumerate(getsig(f)) if name in spec]
# Cache of differentiated functions, keyed by fquname(function).
adc = {}


def clear(search=None):
    """Clear the cache of differentiated functions.

    Without argument the whole cache ``adc`` and the AST cache of
    ``astvisitor`` are reset, and the AST of ``mkActArgFunction`` is loaded
    again. Otherwise only the entry of ``search`` is removed; ``search`` is
    a function, or a cache key string. Removing a missing entry is a no-op.
    """
    global adc
    if search is None:
        adc = {}
        astvisitor.modastcache = {}
        astvisitor.getast(mkActArgFunction)
    else:
        # The cache is keyed by fquname(), the same key DiffFunction uses.
        key = search if isinstance(search, str) else fquname(search)
        adc.pop(key, None)
def doSourceDiff(function, opts):
    """Differentiate ``function`` by source transformation.

    Types are handled by :py:func:`mkConstr`, everything else by
    :py:func:`difffunction` called with ``opts`` as keyword arguments.

    Raises
    ------
    NoRule
        When ``function`` is a builtin, which has no source to differentiate.
    """
    if isbuiltin(function):
        fname = fqname(function)
        rid = astvisitor.rid(function)
        raise NoRule(f'No rule for builtin {fname}, function {rid} not found')
    if isinstance(function, type):
        return mkConstr(function)
    adfun, _active = difffunction(function, **opts)
    return adfun
# Installed rule modules: key -> (module, decorator, handle).
rulemodules = {}


def clearrulemodules(name=None):
    """Remove all installed rule modules. ``name`` is not used."""
    global rulemodules
    rulemodules = {}


def addrulemodule(module, **kw):
    """Install ``module`` as a rule module.

    ``module.decorator(**kw)`` is called; it returns the decorator, or a
    tuple of decorator and handle. The entry is stored under the ``alias``
    given in ``kw`` or else the module name. When that key is taken, the
    first free key ``alias1``, ``alias2``, ... is used instead of
    overwriting the existing entry.
    """
    deco = module.decorator(**kw)
    handle = None
    if isinstance(deco, tuple):
        deco, handle = deco
    alias = kw.get('alias', module.__name__)
    key = alias
    ind = 0
    while key in rulemodules:
        ind += 1
        key = f'{alias}{ind}'
    rulemodules[key] = module, deco, handle
def initRules(rules='ad=pyadi.forwardad', **opts):
    """Initialize the rule processing mechanism for :py:func:`processRules`
    that performs the mapping of functions to differentiated functions.

    The :py:func:`.decorator` of a rule module may return two function
    handles instead of one. In this case the second one can be retrieved
    using :py:func:`.getHandle`, possibly to manipulate the scope of the
    returned differentiated functions at runtime, as demonstrated by the
    :py:func:`~.trace.decorator` of the rule module :py:mod:`.trace`.

    When the same module shall be used several times in the chain, an
    alias can be defined, for example::

        pyadi.initRules(
            rules='pyadi.trace,pyadi.forwardad,tr2=pyadi.trace',
            tracecalls=True, verbose=True, verboseargs=True)

    Then, ``getHandle('tr2')`` retrieves the handle to the second instance
    of the decorator that the trace module installed, and
    ``getHandle('pyadi.trace')`` gets that of the first, while
    ``getHandle('pyadi.forwardad')`` is None because it does not provide a
    handler function.

    Parameters
    ----------
    rules : str
        Comma-separated list of python modules to use as rule
        modules. Entries can use 'alias=module' to define a name
        alias. Whitespace around entries is ignored and empty entries
        are skipped.

    opts : dict
        Passed to :py:func:`decorator` of all the rule modules upon
        initialization.
    """
    clearrulemodules()
    for entry in rules.split(','):
        rule = entry.strip()
        if not rule:
            continue
        add = {}
        if '=' in rule:
            alias, rule = (part.strip() for part in rule.split('=', 1))
            add['alias'] = alias
        rmod = importlib.import_module(rule)
        addrulemodule(rmod, **add, **opts)


initRules()
def getHandle(alias):
    """Return handle to a rule module.

    Return the second item of the result of a rule module's decorator
    function, or None. This second item is meant to be a function that
    can manipulate or read the local scope of the decorator, for an
    example see the tracing rule module :py:mod:`.trace`.

    The result returned becomes invalid whenever :py:func:`initRules`
    is called again.

    Parameters
    ----------
    alias : str
        The module name or alias used with :py:func:`initRules`.

    Returns
    -------
    object, usually function or None
        The second item that the rule module's decorator returned,
        usually a second inner function.

    Raises
    ------
    KeyError
        When no rule module is installed under ``alias``.
    """
    return rulemodules[alias][2]


def callHandle(name, *args, **kw):
    """Call the handles of all rule modules whose module name is ``name``.

    Each handle is called with ``args`` and ``kw``; the list of results is
    returned. Rule modules that did not provide a handle are skipped.
    """
    results = []
    for module, _deco, handle in rulemodules.values():
        if module.__name__ == name and handle is not None:
            results.append(handle(*args, **kw))
    return results


def getRuleModules(index):
    """Return the dictionary of installed rule modules. ``index`` is not used."""
    return rulemodules
class NoRule(BaseException):
    """Raised by :py:func:`doSourceDiff` for builtin functions, which have no source to differentiate."""
def processRules(function, opts, *args, **kw):
    """Run ``function`` through the chain of installed rule modules.

    Each rule decorator is called as ``deco(nextStep, ind + 1, function,
    *args, **kw)`` and can call ``nextStep(ind + 1)`` to continue with the
    next rule module. After the last rule module the source differentiation
    :py:func:`doSourceDiff` is invoked.
    """
    order = list(rulemodules.keys())

    def nextStep(ind=0):
        if ind < len(order):
            deco = rulemodules[order[ind]][1]
            return deco(nextStep, ind + 1, function, *args, **kw)
        return doSourceDiff(function, opts, *args, **kw)

    return nextStep()
def initType(function, *args, **kw):
    """Create pair of objects d_o, and o by calling the constructor
    ``function`` twice, and then zero all floats in d_o with
    :py:func:`dzeros`.

    ``args`` alternates derivative and regular arguments; only the regular
    ones (every second item) are passed to the constructor.
    """
    values = args[1::2]
    do = function(*values, **kw)
    o = function(*values, **kw)
    return dzeros(do), o
def mkConstr(function):
    """Return a differentiated constructor for the type ``function``.

    The returned function takes one pair per original argument, flattens
    the pairs and returns the result of :py:func:`initType`.
    """
    def constr(*args, **kw):
        flat = list(chain.from_iterable(args))
        return initType(function, *flat, **kw)
    return constr
class GenIter:
    """Iterator over the first items of the tuples produced by ``genobj``.

    ``index`` counts the items fetched from ``genobj`` and ``findex`` the
    items handed out; the last fetched tuple is kept in ``nitem`` so that a
    coupled :py:class:`GenIter2` can hand out its second item.
    """

    def __init__(self, genobj):
        self.genobj = genobj
        # Set the counters here so that next() works without a prior iter().
        self.index = 0
        self.findex = 0

    def __iter__(self):
        self.index = 0
        self.findex = 0
        return self

    def next(self):
        """Fetch the next tuple from the underlying iterator."""
        self.nitem = next(self.genobj)
        self.index += 1

    def __next__(self):
        if self.index == self.findex:
            self.next()
        self.findex += 1
        return self.nitem[0]
class GenIter2:
    """Iterator over the second items of the tuples fetched by a coupled ``GenIter``.

    ``genobj`` is the coupled iterator; it must have fetched an item
    (``genobj.index`` ahead of this object's ``index``) before that item
    can be handed out here.
    """

    def __init__(self, genobj):
        self.genobj = genobj
        # Set the counter here so that next() works without a prior iter().
        self.index = 0

    def __iter__(self):
        self.index = 0
        return self

    def __next__(self):
        if self.genobj.index == self.index:
            # The coupled iterator has not fetched the item yet. This is
            # not expected to happen; fail explicitly so the check also
            # holds when assertions are disabled.
            print('r')
            raise AssertionError('GenIter2 advanced before the coupled iterator')
        self.index += 1
        return self.genobj.nitem[1]
def doDiffFunction(function, **opts):
    """Produce differentiated functions.

    This function is called to produce a derivative function for
    ``function``, that is, a function that is called with tuples (dx,
    x) for each original argument x in the original code, and that
    returns a tuple (dr, r) where r is the function result and dr is
    the derivative.

    This function will call :py:func:`processRules`, which calls the
    installed rule modules. The default rule module
    :py:mod:`.forwardad` will for example catch calls to
    :py:func:`print`, which is a builtin function, and call
    :py:func:`.mkRule` with :py:func:`.D_builtins_print` to produce a
    suitable result, namely, it will print the differentiated
    arguments in an additional line to the original print, which in
    the case of a formatted :py:doc:`f-string` would be the same
    f-string, but with differentiated expressions.

    When no rule module catches the call and returns a suitable
    function, finally the source differentiation
    :py:func:`doSourceDiff` is invoked. It will produce a suitable
    result by retrieving the AST of function using :py:func:`.getast`
    and differentiating that.

    A few special cases need to be handled as follows:

    - When function is a type, then a constructor is being called.

      - set _C = function

      - When _C.__init__ is not builtin, set function = _C.__init__,
        that is the constructor class method.

    - When function has a closure and the last closure entry fdec is
      a function this might be a call to a decorated function, that
      is, the result of deco(fdec), which is a local function that
      captured fdec. So when fdec has a decorator list (getting the
      ast of fdec), this might be the right thing. TODO: We should
      check if this is really the right function. However, substitute
      function by fdec.

      This will thus in the following get the AST of fdec, which is
      something like::

          @mydeco2(1.23)
          def gdeco2(l):
              return gl_sum(l)

      This gets differentiated to an expression decorating the
      regular D(fdec) with the differentiated decorator expression::

          @(lambda f: D(mydeco2)((0, 1.23))[0](f, gdeco2)[0])
          def d_gdeco2(d_l, l):
              return D(gl_sum)((d_l, l))

      Which when loaded produces the differentiated inner function
      that the differentiated decorator, that D(mydeco2) returns,
      creates when called with d_gdeco2 alias f, and gdeco2. Thus,
      whatever happens in these decorators and the inner functions
      that they produce is getting differentiated regularly. The
      differentiated decorator expression is one of the few cases
      where we have to throw away the second part of the tuples that
      differentiated functions produce, because we only need the
      differentiated result, and even twice in this case.

    Then the differentiated function ``adfun`` is produced by calling
    :py:func:`.processRules`.

    This function returns however a local function ``def
    theADFun(*args, **kw):`` that does the following:

    - first flatten the argument list ``args`` of N tuples to a list
      of 2*N, alternating derivative and regular arguments. This is
      because the source differentiation differentiates ``def f(x,
      y):`` to ``def d_f(d_x, x, d_y, y):``. Hence, the builtin rules
      will also be called with the flattened list of arguments. This
      step also forces the evaluation of potentially lazy zip and
      other iterators that the arguments may be.

    - when ``_C`` is not None, a type, that is, a constructor has
      been called. Initialize two objects d_o and o with
      :py:func:`.initType`. Prepend ``(d_o, o)`` to the list of
      arguments. This requires that _C.__init__ accepts being called
      with no arguments. Both objects will then be reinitialized
      again with the provided arguments to the constructor when the
      differentiated __init__ method ``adres`` is invoked in the next
      step. TODO: check if we can somehow produce uninitialized
      objects, that is, do what Python does before it calls __init__?

    - Another corner case is when ``function`` is not a function but
      a bound method. This can only happen with global objects being
      called, like the method ``get`` of :py:obj:`os.environ`. Then
      the self pointer ``o`` is extracted and a copy ``d_o`` to be
      :py:func:`.dzeros`-ed must be created somehow on the
      fly. Prepend ``(d_o, o)`` to the list of arguments.

    - Finally call ``adres = adfun(*args, **kw)``

    - when ``adres`` is None, which often happens with methods,
      return ``(None, None)``, unless ``_C`` is not None, then a
      constructor has been called, return ``(d_o, o)``.

    - when ``adres`` is an object of the builtin type ``generator``,
      ``function`` was a generator function and ``adfun`` is too,
      with differentiated yield statements that produce
      tuples. Create two coupled iterators ``d_it =``
      :py:class:`GenIter` (``adres``) and ``it =``
      :py:class:`GenIter2` (``d_it``) that in tandem iterate
      ``adres``, one returning ``r[0]`` and the other ``r[1]`` of the
      tuples ``r`` produced, and return ``(d_it, it)``.

    - otherwise, return ``adres``, which is a tuple.
    """
    _class, constr = None, None

    if isinstance(function, type):
        _class = function
        if not isbuiltin(function.__init__):
            # differentiate the constructor method instead of the type
            constr = function = function.__init__
    else:
        clos = getattr(function, '__closure__', None)
        if clos is not None and len(clos) > 0:
            last = clos[-1].cell_contents
            if isinstance(last, type):
                # function is a method
                pass
            elif callable(last) and len(getast(last)[0].decorator_list) > 0:
                # function has a closure and is decorated: use the
                # captured, decorated function instead
                function = last
            # other closures are handled only in case of source diff, later

    self = getattr(function, '__self__', None)
    if self is not None:
        # the __self__ of builtin functions is their module, not an object
        if self.__class__.__name__ == 'module':
            self = None

    adfun = processRules(function, opts)
    if opts.get('verbose', 0):
        print(f'AD function produced for {fqname(function)}: {adfun.__qualname__}')

    def theADFun(*ADargs, **kw):
        args = list(chain(*ADargs))

        if constr is not None:
            d_kw, f_kw = unjnd(kw)
            do, o = initType(_class, *args, **f_kw)
            args = [do, o] + list(args)
        elif self is not None:
            try:
                dself = dzeros(self.__class__())
            except Exception:
                # the class cannot be constructed without arguments
                dself = dzeros(self)
            args = [dself, self] + list(args)

        adres = adfun(*args, **kw)

        if adres is None:
            if _class:
                # was constructor
                adres = do, o
            else:
                adres = None, None
        elif adres.__class__.__name__ == 'generator':
            adres_d = GenIter(adres)
            adres_v = GenIter2(adres_d)
            adres = adres_d, adres_v

        return adres

    return theADFun
def DiffFunction(function, **opts):
    """Runtime decorator to handle function calls.

    Looks ``function`` up in the cache ``adc`` under the key
    ``fquname(function)``. On a miss, :py:func:`doDiffFunction` does the
    actual work, called with the global ``transformOpts`` overlaid by
    ``opts``, and the pair ``(adfun, function)`` is stored in the cache.
    Use :py:func:`clear` to clear this cache, which should be necessary
    only when the processing is redefined at runtime using
    :py:func:`initRules`.

    Afterwards the closure data of ``function``, as returned by
    ``mkClosDict``, is merged into the ``adglobalsc`` entry stored under
    ``fqname(function)``, provided such an entry exists.

    Parameters
    ----------
    function : callable
        The function to differentiate.
    opts : dict
        Options; ``verbose`` controls the diagnostic output printed here.

    Returns
    -------
    callable
        The differentiated function.
    """
    # NOTE(review): the nesting of the closure refresh below was reconstructed
    # from a flattened listing -- confirm it is meant to run on both the
    # cache-hit and the cache-miss path.
    ckey = fquname(function)
    centry = adc.get(ckey, None)

    if centry is not None:
        adfun = centry[0]
        if opts.get('verbose', 0) > 2:
            print(f'Found diff function {fqname(function)} in cache: {adfun.__name__}')
    else:
        adfun = doDiffFunction(function, **(transformOpts | opts))
        adc[ckey] = (adfun, function)

    cl_data = mkClosDict(function)
    if not cl_data:
        return adfun

    if opts.get('verbose', 0) > 1:
        print(f'DiffFunction: Function {function} has a closure: {cl_data.keys()}')

    globals_key = fqname(function)
    if globals_key in adglobalsc:
        adglobalsc[globals_key] |= cl_data

    return adfun
# Short module-level alias bound to the very same function object as
# DiffFunction.
D = DiffFunction
"""An alias for :py:func:`.DiffFunction` so the generated code can be shorter."""
def DiffFunctionObj(tpl, **opts):
    """Runtime decorator to handle calls to local variables.

    Calls to local variables like ``obj.meth`` are differentiated to an
    expression invoking this function as ``Dc((d_obj.meth, obj.meth))``,
    that is, with a tuple of the "differentiated" function and the
    original function. Differentiated is in quotes because different
    cases can happen. Let the tuple ``tpl`` be expanded to ``dfunc, func``.
    This function will usually call :py:func:`DiffFunction` (aka. ``D``)
    with ``func`` after handling the following cases:

    - When ``dfunc != func``:

      1) A method is being called, ``dfunc`` and ``func`` are bound
         methods. Extract the two self pointers from both and substitute
         ``func`` with the actual class function ``T.meth``, where ``T``
         is the type. ``T`` is not necessarily the type of ``obj`` but may
         also be a parent class when an inherited method is being called
         using ``super()``. At runtime, inject the two self pointers to
         the front of the argument list. Functions bound to a module are
         left untouched.

      2) When ``func`` is not a function (it has no ``__qualname__``),
         then an object is being called, ``dfunc`` is the derivative
         object. Substitute ``func`` by ``T.__call__``, where ``T`` is the
         type of ``obj``. At runtime, inject the two self pointers (that
         is, ``dfunc`` and the original ``func``) to the front of the
         argument list.

      3) Otherwise, a local function ``inner`` is being called, which
         will have been differentiated in source code already, the call
         to this decorator then being ``Dc((d_inner, inner))``. Hence,
         ``dfunc`` is already the differentiated function of ``func`` and
         can be called directly. In this case ``D()`` is not called.

    - When ``dfunc == func``: a function alias has been called, that is,
      a global function was assigned to a local variable like
      ``myf = math.sin``. The differentiated variable ``d_myf`` then has
      ``dzeros(math.sin)``, which is also ``math.sin``. Do nothing.

    Finally call ``DiffFunction(func)`` and return that result, unless the
    runtime arguments need patching, then return a local function doing
    that.

    Parameters
    ----------
    tpl : tuple
        The pair ``(dfunc, func)``.
    opts : dict
        Options forwarded to :py:func:`DiffFunction`.

    Returns
    -------
    callable
        The function to call with the joined ``(d_arg, arg)`` tuples.
    """
    dfunc, function = tpl
    dself, self = None, None

    if dfunc != function:
        self = getattr(function, '__self__', None)
        if self is not None:
            if self.__class__.__name__ != 'module':
                # Case 1: bound method. The second-to-last part of the
                # qualified name is the class that defines the method;
                # pick that class from the object's MRO so inherited
                # methods resolve to the right parent.
                cname = function.__qualname__.split('.')[-2]
                _class = [c for c in self.__class__.__mro__ if c.__name__ == cname][0]
                dself = dfunc.__self__
                function = getattr(_class, function.__name__)
        elif not hasattr(function, '__qualname__'):
            # Case 2: a callable object; differentiate its type's __call__.
            _class = function.__class__
            dself, self = dfunc, function
            function = function.__class__.__call__
        else:
            # Case 3: dfunc is the differentiated function already; only
            # the (d_arg, arg) tuples have to be flattened.
            def inner(*args, **kw):
                args = list(chain(*args))
                return dfunc(*args, **kw)
            return inner

    if dself is not None:
        dfname = f'd_{function.__name__}'

    adfun = DiffFunction(function, **opts)

    if dself is None:
        return adfun

    # Best effort: remember the differentiated method on the class. Builtin
    # and extension types refuse new attributes, which is fine to ignore,
    # but KeyboardInterrupt and SystemExit must not be swallowed here.
    try:
        setattr(_class, dfname, adfun)
    except Exception:
        pass

    def inner(*args, **kw):
        # Prepend dself and self to the method call.
        return adfun((dself, self), *args, **kw)

    return inner
# Short module-level alias bound to the very same function object as
# DiffFunctionObj.
Dc = DiffFunctionObj
"""An alias for :py:func:`DiffFunctionObj` so the generated code can be shorter."""
def nvars(args):
    """Compute recursively the total number of values in ``args``.

    Lists, tuples and the values of dicts are traversed recursively. A
    value for which ``isgeneric`` is true counts as one, an object with a
    ``flat`` attribute contributes its ``size``, and anything else
    contributes ``len(args)``.
    """
    if isinstance(args, (list, tuple)):
        return sum(nvars(item) for item in args)
    if isinstance(args, dict):
        return sum(nvars(value) for _, value in args.items())
    if isgeneric(args):
        return 1
    if hasattr(args, 'flat'):
        return args.size
    return len(args)
def varv(args):
    """Return the values in ``args`` in traversal order.

    Lists, tuples and the values of dicts are traversed recursively and
    yield an :py:class:`itertools.chain` over the results for their
    items. A value for which ``isgeneric`` is true is returned as a
    one-element list, an object with a ``flat`` attribute as the list of
    its flat entries. Any other object yields ``None``.
    """
    if isinstance(args, (list, tuple)):
        parts = [varv(item) for item in args]
        return chain(*parts)
    if isinstance(args, dict):
        parts = [varv(value) for _, value in args.items()]
        return chain(*parts)
    if isgeneric(args):
        return [args]
    if hasattr(args, 'flat'):
        return list(args.flat)
    return None
class FillHelper:
    """A simple iterator used to source floats one by one from either a
    list or a :py:mod:`numpy` array, used by :py:func:`.fill`.

    The seed is copied into a flat one-dimensional array once, so scalar
    and multi-dimensional seeds are served value by value as well.

    Attributes
    ----------
    seed : numpy.ndarray
        Flat copy of the seed values.
    len : int
        Number of values available in ``seed``.
    offs : int
        Index of the next value to hand out.
    """

    def __init__(self, seed):
        # Flatten so that indexing always returns a single value, also for
        # 0-d and n-d input.
        self.seed = np.array(seed).reshape(-1)
        # The bound is what the flat buffer can actually serve.
        self.len = self.seed.size
        self.offs = 0

    def __iter__(self):
        return self

    def __next__(self):
        """Return a single float from the data.
        """
        if self.offs >= self.len:
            raise StopIteration
        value = self.seed[self.offs]
        self.offs += 1
        return value

    def __repr__(self):
        """Print the fill status like "FillHelper(i/n)".
        """
        return f'FillHelper({self.offs}/{self.len})'

    def get(self, N):
        """Batch-return N values to speed up filling arrays."""
        chunk = self.seed[self.offs:(self.offs + N)]
        self.offs += N
        return chunk
def fill(arg, seed):
    """Fill arg with values from seed.

    ``seed`` is wrapped in a :py:class:`.FillHelper` unless it is one
    already; the same helper is shared by all recursive calls, so values
    are consumed in traversal order. Lists, tuples and dicts are rebuilt
    with every item filled recursively. Each generic value, as per
    :py:func:`astunparse.astnode.isgeneric`, is replaced by the next value
    from the helper. Objects with a ``flat`` attribute, such as
    :py:mod:`numpy` arrays, are batch-filled in-place using
    :py:meth:`.FillHelper.get` and returned as they are, so
    :py:func:`.dzeros` should be used before if it is desired that arrays
    are cloned and the original arrays not modified. For any other object
    the result is ``None``.

    Parameters
    ----------
    arg : list of objects
        Function arguments.

    seed : list of floats or a :py:mod:`numpy` array
        Values to fill into arg.

    Returns
    -------
    arg
        The structure of arg filled with seed.
    """
    source = seed if isinstance(seed, FillHelper) else FillHelper(seed)

    if isinstance(arg, list):
        return [fill(item, source) for item in arg]
    if isinstance(arg, tuple):
        return tuple(fill(item, source) for item in arg)
    if isinstance(arg, dict):
        return {key: fill(value, source) for key, value in arg.items()}
    if isgeneric(arg):
        return next(source)
    if hasattr(arg, 'flat'):
        arg.flat[:] = source.get(arg.size)
        return arg
    return None
def dargs(args, seed=1):
    """Create derivative arguments for ``args`` filled with ``seed``.

    Parameters
    ----------
    args : list
        Function arguments; a zeroed counterpart is created with
        :py:func:`.dzeros`.

    seed : 1 or list or array
        When ``seed`` is the scalar 1, the first of the ``nvars(args)``
        values is set to one and all others to zero. Otherwise the values
        are handed to :py:func:`.fill` as they are.

    Returns
    -------
    The zeroed arguments filled with the seed values.
    """
    zargs = dzeros(args)
    # Test for a scalar first: comparing a numpy array seed with ``== 1``
    # yields an array whose truth value is ambiguous.
    if isgeneric(seed) and seed == 1:
        count = nvars(args)
        seed = [0] * count
        if count > 0:
            seed[0] = 1
    return fill(zargs, seed)
def createFullGradients(args):
    """Return the list of all unit seed directions for ``args``.

    With ``N = nvars(args)``, the result is a list of ``N`` lists of
    length ``N``, the i-th of which has a one at position i and zeros
    everywhere else.
    """
    count = nvars(args)
    return [
        [1 if col == row else 0 for col in range(count)]
        for row in range(count)
    ]
# Global options of the differentiation process. DiffFor rebinds this name to
# its options on every call, and DiffFunction overlays it with its own opts
# when it calls doDiffFunction.
transformOpts = {}
def DiffFor(function, *args, seed=1, active=[], f_kw=None, timings=True, verbose=0, dump=0, dumpdir='dump', **opts):
    """Differentiate ``function`` and compute first-order derivatives
    evaluated at ``*args`` and ``**f_kw``, w.r.t. all floats in ``args``,
    possibly restricted by ``active``.

    Differentiate function ``function(*args)`` with forward mode AD to
    produce ``adfun``. This function is the main entry point to start the
    differentiation process. It basically does the following:

    - Differentiate ``function`` with :py:func:`.DiffFunction` alias
      :py:func:`.D`

    - Create one set of derivative arguments ``dx = dzeros(args)`` using
      :py:func:`.dzeros`

    - For each seeddir in seed, initialize ``dx`` with seeddir using
      :py:func:`.fill` and call ``adfun`` with ``dx`` and ``args``
      appropriately.

    The result is a tuple of the list of the derivative results thus
    produced, and the function result.

    Although PyADi supports almost the full set of Python language
    features including keyword arguments, lambda functions, etc. the
    ``function`` given here must adhere to some restrictions:

    - ``function`` must be a regular Python function, defined with
      ``def``, not a lambda expression.

    - This function only processes the positional arguments ``args`` and
      considers all keyword arguments as options to the process,
      additional keyword arguments can be passed using ``f_kw``.

    - ``function`` can have parameter default values,

    - ``function`` can be a local function returned by whatever other
      function. This setup process will not be differentiated.

    - ``function`` can also have a decorator, which will be
      differentiated.

    It may in some cases be required, and it is no problem, to create
    additional toplevel functions that can be given to this function, for
    example to wrap a lambda expression. It should not be required to
    build extra toplevel functions to inject global variables into the
    scope, since ``function`` can be a local function already. It will
    have access to the parent scopes as usual, but the values in it are
    treated as global values with zero derivative. However, when a
    function returning a function is called within ``function``, then this
    entire process, including possible calls to the result later, will be
    differentiated.

    Parameters
    ----------
    function : function
        Function to differentiate. Must be a regular function, defined
        with ``def``, can be a local function.

    args : list
        Function arguments. ``function`` will be differentiated with
        respect to all arguments or to those listed by ``active``.

    active : list or str
        Active arguments, like [0,1], ['x', 'y'], or a comma-separated
        string like 'x,y'. The empty list or string means all arguments.
        What actually happens is that a local function of only the active
        ``args``, calling ``function``, is generated by
        :py:func:`.mkActArgFunction` and that is differentiated instead.

    seed : 1 or list
        Seed, derivative directions. When seed == 1, all derivative
        directions are computed. When seed is a list, then each entry
        must be a list or array of the same size as the total length of
        the active arguments. The function :py:func:`.nvars` can compute
        that value.

    f_kw : dict
        Further keyword arguments that will be passed to ``function`` as
        ``**f_kw``. This wraps ``function`` with :py:func:`.mkKwFunction`.

    timings : bool
        When true, the plain run, the differentiation and every run of
        the differentiated function are wrapped in a ``Timer``. This
        evaluates ``function`` one additional time.

    opts : dict
        Further options ``opts`` including also timings, verbose, dump
        and dumpdir are stored in a global variable
        :py:data:`.transformOpts`. These global options are added to the
        options of :py:func:`.doDiffFunction` by each call to
        :py:func:`.D` in the subsequent process. When ``opts`` has an
        entry ``dx``, it is used as the derivative arguments directly,
        ``seed`` is ignored and ``adfun`` is called exactly once.

    Returns
    -------
    tuple
        A tuple of the derivative and the function result. The derivative
        is a list with as many entries as there were seed directions, or
        the single derivative result when ``dx`` was given.

    Raises
    ------
    ValueError
        When ``seed`` is neither the scalar 1 nor a list.
    """
    global transformOpts, dumpDir
    transformOpts = opts | dict(timings=timings, verbose=verbose, dump=dump, dumpdir=dumpdir)

    if dump > 0 and dumpdir != '.':
        dumpDir = dumpdir
        if not os.path.exists(dumpDir):
            print(f'mkdir {dumpDir}')
            os.makedirs(dumpDir)

    if f_kw is not None:
        assert isinstance(f_kw, dict)
        function = mkKwFunction(function, f_kw)

    if len(active) > 0:
        inds = parinds(function, active)
        function, args = mkActArgFunction(function, args, inds)

    if timings:
        # Plain run for the timing statistics only; the function result
        # returned below comes from the differentiated function.
        with Timer(function.__qualname__, 'run', verbose=verbose-1):
            function(*args)

        with Timer(function.__qualname__, 'diff', verbose=verbose-1):
            adfunOrig = D(function, **opts)

        def TimeIt(*args, **kw):
            with Timer(function.__qualname__, 'adrun', verbose=verbose):
                return adfunOrig(*args, **kw)

        adfun = TimeIt
    else:
        adfun = D(function, **opts)

    if 'dx' in opts:
        # Caller-supplied derivative arguments: pair them with args as is.
        dresult, result = adfun(*zip(opts['dx'], args))
        return dresult, result

    if isgeneric(seed) and seed == 1:
        seed = createFullGradients(args)
    elif not isinstance(seed, list):
        raise ValueError(
            f'seed must be 1 or a list of seed directions, got {type(seed).__name__}')

    dargs = dzeros(args)
    pairs = [adfun(*zip(fill(dargs, s), args)) for s in seed]

    result = pairs[0][1] if pairs else None
    dresult = [d for d, _ in pairs]
    return dresult, result
def Diff(active='all', **opts):
    """Decorator factory: calls to the decorated function return the pair
    ``(dresult, result)``.

    The decorated function is differentiated on the first call and the
    resulting pair ``(adfun, actind)`` is kept for all later calls.

    NOTE(review): this decorator depends on ``difffunction`` and
    ``createGradients``, neither of which is visible in this part of the
    file -- confirm that they exist before relying on it. The statement
    nesting was reconstructed from a flattened listing.

    Parameters
    ----------
    active : list or str
        Passed on to ``difffunction`` as ``active``.

    opts : dict
        Further options passed on to ``difffunction``.

    Returns
    -------
    function
        The decorator. The wrapped function accepts the keyword argument
        ``dx`` with ready-made derivative arguments; without it, these
        are produced by ``createGradients(args, actind)``.
    """
    def _pyadi_diff(function):
        # Cache private to this decorated function, unrelated to the
        # module-level cache of differentiated functions.
        ad_cache = {'f': None}

        def inner(*args, **kw):
            # The plain function is evaluated first, whatever kw['mode'] is.
            # NOTE(review): presumably mode 'f' was meant to stop here --
            # confirm the intent.
            function(*args)

            if ad_cache['f'] is None:
                (adfun, actind) = difffunction(function, active=active, **opts)
                ad_cache['f'] = (adfun, actind)
            else:
                (adfun, actind) = ad_cache['f']

            if 'dx' in kw:
                dargs = kw['dx']
            else:
                dargs = createGradients(args, actind)

            (dresult, result) = adfun(dargs, args)
            return (dresult, result)

        return inner

    return _pyadi_diff
def DiffFD(f, *args, active=[], seed=1, h=1e-8, f_kw={}, **opts):
    """Evaluate derivatives using central finite differences.

    The function f is called two times for each derivative direction
    provided by seed, to evaluate a central finite difference with step
    size h. The function f is called once more to compute the original
    result.

    Parameters
    ----------
    f : function
        Function to differentiate.

    args : list
        Function arguments.

    h : float
        Step size, default 1e-8

    active : list or str
        Active arguments, like [0,1], ['x', 'y'], or a comma-separated
        string like 'x,y'. The empty list or string means all arguments.
        What actually happens is that a local function of only the active
        ``args``, calling ``function``, is generated by
        :py:func:`.mkActArgFunction` and that is differentiated instead.

    seed : 1 or list
        Seed, derivative directions. When seed == 1, all derivative
        directions are computed. When seed is a list, then each entry
        must be a list or array of the same size as the total length of
        the active arguments. The function :py:func:`.nvars` can compute
        that value.

    f_kw : dict
        Further keyword arguments that will be passed to ``f`` as
        ``**f_kw``.

    opts : dict
        Options, not used.

    Returns
    -------
    tuple
        A tuple of the derivative and the function result. The derivative
        is a list with as many entries as there were seed directions.

    Raises
    ------
    ValueError
        When ``seed`` is neither the scalar 1 nor a list.
    """
    if len(active) == 0:
        func = f
    else:
        inds = parinds(f, active)
        func, args = mkActArgFunction(f, args, inds)

    v = np.array(list(varv(args)))
    dargs = dzeros(args)
    N = v.size
    h2 = h * 2

    r = func(*args, **f_kw)

    def dirder(func, seed):
        r1 = func(*fill(dargs, v + h * seed), **f_kw)
        # Snapshot the values of r1 right away: fill() writes into the
        # arrays of dargs in place, so a result that is one of its own
        # arguments would be overwritten by the second evaluation.
        # varv() yields a lazy chain for lists and tuples, hence list().
        rv1 = np.array(list(varv(r1)))
        r2 = func(*fill(dargs, v - h * seed), **f_kw)
        rv2 = np.array(list(varv(r2)))
        der = (rv1 - rv2) / h2
        return fill(dzeros(r), der)

    if isgeneric(seed) and seed == 1:
        dres = []
        for i in range(N):
            direction = np.zeros(N)
            direction[i] = 1
            dres.append(dirder(func, direction))
    elif isinstance(seed, list):
        dres = [dirder(func, np.array(seeddir)) for seeddir in seed]
    else:
        raise ValueError(
            f'seed must be 1 or a list of seed directions, got {type(seed).__name__}')

    return dres, r
def DiffFDNP(f, *args, active=[0], seed=1, h=1e-8, f_kw={}, **opts):
    """An optimized version of :py:func:`DiffFD` with some restrictions:

    - there can be only one active argument, selected by ``active``.

    - the only active argument, the seeds, and the function result must
      all be :py:mod:`numpy` arrays.

    Parameters
    ----------
    f : function
        Function to differentiate.

    args : list
        Function arguments.

    h : float
        Step size.

    active : list or str
        Active arguments, like [0,1], ['x', 'y'], or a comma-separated
        string like 'x,y'. What actually happens is that a local function
        of only the active ``args``, calling ``function``, is generated by
        :py:func:`.mkActArgFunction` and that is differentiated instead.

    seed : 1 or list or array
        Seed, derivative directions. When seed == 1, all derivative
        directions are computed. When seed is a list, then each entry
        must be an array of the same size as the active argument.
        Otherwise seed is indexed as a two-dimensional array holding one
        direction per column.

    f_kw : dict
        Further keyword arguments that will be passed to ``f`` as
        ``**f_kw``.

    opts : dict
        Options, not used.

    Returns
    -------
    tuple
        A tuple of the derivative and the function result. The derivative
        is an array with one row per entry of the function result and one
        column per seed direction.
    """
    if len(active) == 0:
        func = f
    else:
        inds = parinds(f, active)
        func, args = mkActArgFunction(f, args, inds)

    assert len(args) == 1
    point = args[0]
    nvalues = point.size
    shape = point.shape
    h2 = h * 2

    r = func(*args, **f_kw)

    if isgeneric(seed) and seed == 1:
        seed = np.eye(nvalues)

    if isinstance(seed, list):
        ndd = len(seed)

        def direction(i):
            return seed[i]
    else:
        _, ndd = seed.shape

        def direction(i):
            return seed[:, i]

    Jac = np.zeros((r.size, ndd))
    for i in range(ndd):
        step = h * direction(i).reshape(shape)
        r_plus = func(point + step, **f_kw)
        r_minus = func(point - step, **f_kw)
        Jac[:, i] = (r_plus.flat[:] - r_minus.flat[:]) / h2

    return Jac, r
# (c) 2023 AI & IT UG # Author: Johannes Willkomm jwillkomm@ai-and-it.de