# Source code for openalea.metafspm.specializer

import ast, inspect, textwrap, linecache
import numpy as np
from numba import njit


# ------------------- AST helpers -------------------

def _parents_map(tree):
    """Map every node found below ``tree`` to its direct parent node.

    The root ``tree`` itself has no parent and therefore never appears as a key.
    """
    return {
        child: parent
        for parent in ast.walk(tree)
        for child in ast.iter_child_nodes(parent)
    }

def _infer_attrs_to_inline(fdef):
    """Return the names of ``self.<attr>`` attributes that are only read in ``fdef``.

    An attribute is reported when it is loaded somewhere in the function and is
    never mutated there. The following count as a mutation of ``attr``:

    * ``self.attr = ...`` / ``self.attr += ...`` / ``del self.attr``
    * ``self.attr[i] = ...`` / ``del self.attr[i]``, at any subscript depth
      (e.g. ``self.attr[i][j] = ...``)

    ``self.attr(...)`` (the attribute used as a call target) is not counted as
    a read; method calls are handled separately by the caller.

    Parameters
    ----------
    fdef : ast.AST
        The function definition node (or any subtree) to inspect.

    Returns
    -------
    set of str
        Attribute names read but never mutated.
    """
    def _is_self_attr(node):
        return (isinstance(node, ast.Attribute)
                and isinstance(node.value, ast.Name)
                and node.value.id == 'self')

    nodes = list(ast.walk(fdef))
    # Identity set of nodes used as the callee of a Call, so `self.m(...)` can
    # be told apart from a plain `self.m` read without building a parent map.
    call_targets = {id(n.func) for n in nodes if isinstance(n, ast.Call)}
    write_ctx = (ast.Store, ast.Del)

    reads, mutated = set(), set()
    for n in nodes:
        if _is_self_attr(n):
            if isinstance(n.ctx, write_ctx):
                mutated.add(n.attr)
            elif id(n) not in call_targets:
                reads.add(n.attr)
        elif isinstance(n, ast.Subscript) and isinstance(n.ctx, write_ctx):
            # Peel nested subscripts: in `self.a[i][j] = v` only the outermost
            # Subscript carries the Store context; the base is still `self.a`.
            base = n.value
            while isinstance(base, ast.Subscript):
                base = base.value
            if _is_self_attr(base):
                mutated.add(base.attr)
    return reads - mutated

def _find_self_callees(fdef):
    """Collect the names of every method invoked as ``self.<name>(...)`` in ``fdef``."""
    def _targets_self(call):
        func = call.func
        if not isinstance(func, ast.Attribute):
            return False
        owner = func.value
        return isinstance(owner, ast.Name) and owner.id == 'self'

    return {
        node.func.attr
        for node in ast.walk(fdef)
        if isinstance(node, ast.Call) and _targets_self(node)
    }

def _stable_hash(text: str) -> str:
    """Return the first 16 hex characters of the SHA-1 digest of ``text`` (UTF-8 encoded)."""
    from hashlib import sha1
    digest = sha1(text.encode('utf-8'))
    return digest.hexdigest()[:16]

class StripDecorators(ast.NodeTransformer):
    """Remove all decorators from functions/classes in the parsed tree.

    Visiting a tree empties ``decorator_list`` on every ``FunctionDef``,
    ``AsyncFunctionDef`` and ``ClassDef`` node, including nested ones, and
    returns the same (mutated) nodes.
    """

    def _strip(self, node):
        """Clear ``node.decorator_list``, then recurse into the node's children."""
        node.decorator_list = []  # drop @dec(...)
        self.generic_visit(node)
        return node

    def visit_FunctionDef(self, node):
        return self._strip(node)

    def visit_AsyncFunctionDef(self, node):
        return self._strip(node)

    def visit_ClassDef(self, node):
        # in case you ever parse a class block
        return self._strip(node)
# ------------------- Recursive specializer -------------------
# Scalar types whose values can be baked into the generated source as literals.
_INLINE_SCALAR_TYPES = (bool, int, float, np.bool_, np.integer, np.floating)


def _to_py_scalar(val):
    """Convert a NumPy scalar to the matching builtin type; return builtins unchanged."""
    if isinstance(val, np.bool_):
        return bool(val)
    if isinstance(val, np.integer):
        return int(val)
    if isinstance(val, np.floating):
        return float(val)
    return val


def specialize_method_recursive(method, instance, max_depth=2, registry=None, print_src=False, debug=False):
    """
    Recursively specialize `method` and its nested self.method(...) callees.

    For each method, the source is re-parsed, decorators are stripped, read-only
    ``self.<attr>`` accesses are replaced (scalars and tuples of scalars become
    literals, NumPy arrays become ``__C_<attr>`` globals), ``self.<m>(...)``
    calls are redirected to ``__HELP_<m>`` globals, the ``self`` parameter is
    dropped, and the result is compiled with ``numba.njit``.

    Parameters
    ----------
    method : callable
        Method (bound or plain function) whose ``__name__`` is looked up on
        ``instance.__class__``.
    instance : object
        Object providing the attribute values to inline.
    max_depth : int
        How many levels of nested ``self.<m>(...)`` callees to follow below the
        entry method. Callees beyond that depth are not specialized.
    registry : dict or None
        Maps method name -> ``{'jit', 'py', 'src'}``. Names already present are
        reused; a new dict is created when ``None``.
    print_src : bool
        Print the generated source of every specialized method.
    debug : bool
        Print a message when a cycle or the depth limit is hit, or when
        compiling / trial-calling a specialized method fails.

    Returns
    -------
    tuple
        ``(numba_dispatcher, registry)`` when the entry method was specialized,
        otherwise ``(None, None)``.

    Notes
    -----
    A method is registered only if a trial call with the integer ``1`` for each
    positional parameter succeeds. Failures are best-effort: they are reported
    under ``debug`` and otherwise ignored.
    """
    if registry is None:
        registry = {}
    cls = instance.__class__
    visiting = set()  # names currently on the recursion stack (cycle detection)

    def _specialize_by_name(name: str, depth: int):
        if name in registry:
            return
        # Both guards must stop here: continuing would recurse without bound on
        # cyclic (or self-recursive) methods and would make max_depth a no-op.
        if name in visiting:
            if debug:
                print(f"Cyclic method calls detected at '{name}'")
            return
        if depth < 0:
            if debug:
                print(f"Max recursion depth exceeded while specializing '{name}'")
            return

        py_method = getattr(cls, name)  # unbound function
        src = textwrap.dedent(inspect.getsource(py_method))
        tree = ast.parse(src)

        # Remove decorators
        StripDecorators().visit(tree)
        ast.fix_missing_locations(tree)

        fdef = next(n for n in tree.body if isinstance(n, ast.FunctionDef))

        # 1) recurse into nested self.method(...) first
        callees = _find_self_callees(fdef)
        visiting.add(name)
        try:
            for callee in sorted(callees):
                if hasattr(cls, callee) and inspect.isfunction(getattr(cls, callee)):
                    _specialize_by_name(callee, depth - 1)
        finally:
            visiting.discard(name)

        # 2) inline self.<attr>
        const_map, global_arrays = {}, {}
        for attr in sorted(_infer_attrs_to_inline(fdef)):
            val = getattr(instance, attr)
            if isinstance(val, _INLINE_SCALAR_TYPES):
                const_map[attr] = _to_py_scalar(val)
            elif isinstance(val, np.ndarray):
                global_arrays[attr] = val  # inject as read-only global
            elif isinstance(val, tuple) and all(isinstance(x, _INLINE_SCALAR_TYPES) for x in val):
                const_map[attr] = tuple(_to_py_scalar(x) for x in val)

        sym_for = {callee: f"__HELP_{callee}" for callee in callees}

        class Rewriter(ast.NodeTransformer):
            def visit_Attribute(self, node):
                self.generic_visit(node)
                if isinstance(node.value, ast.Name) and node.value.id == 'self':
                    nm = node.attr
                    if nm in const_map:
                        return ast.copy_location(ast.Constant(const_map[nm]), node)
                    if nm in global_arrays:
                        return ast.copy_location(ast.Name(id=f'__C_{nm}', ctx=ast.Load()), node)
                return node

            def visit_Call(self, node):
                self.generic_visit(node)
                f = node.func
                if (isinstance(f, ast.Attribute) and isinstance(f.value, ast.Name)
                        and f.value.id == 'self' and f.attr in sym_for):
                    node.func = ast.copy_location(ast.Name(id=sym_for[f.attr], ctx=ast.Load()), f)
                return node

        Rewriter().visit(tree)
        ast.fix_missing_locations(tree)

        # 3) drop 'self'
        if fdef.args.args and fdef.args.args[0].arg == 'self':
            fdef.args.args = fdef.args.args[1:]
        fdef.name = name

        # 4) pretty source (optional)
        gen_src = ast.unparse(tree)
        if print_src:
            print(f"\n=== specialized {name} ===\n{gen_src}")

        # keep inspect.getsource() working
        fake_fn = f"<specialized:{id(instance)}:{name}:{_stable_hash(gen_src)}>"
        linecache.cache[fake_fn] = (len(gen_src), None, gen_src.splitlines(True), fake_fn)

        # 5) exec globals: start from method's module globals (so np & other helpers are visible)
        glb = dict(py_method.__globals__)
        glb.setdefault('np', np)
        for an, arr in global_arrays.items():
            glb[f"__C_{an}"] = arr
        for callee in callees:
            if callee in registry:
                glb[sym_for[callee]] = registry[callee]['jit']  # compiled helper

        # 6) exec + JIT (best effort: a failure leaves `name` out of the registry)
        try:
            ns = {}
            code = compile(tree, fake_fn, "exec")
            exec(code, glb, ns)
            py_func = ns[name]
            jitted = njit(py_func, cache=False)
        except Exception as exc:
            if debug:
                print(f"Could not build specialized '{name}': {exc!r}")
            return

        # Trial call with the integer 1 for every positional parameter to force
        # compilation now; the method is registered only if this call succeeds.
        # NOTE(review): methods that need non-integer arguments (e.g. arrays)
        # will presumably fail here and stay unregistered - confirm intent.
        try:
            jitted(*([1] * len(fdef.args.args)))
        except Exception as exc:
            if debug:
                print(f"Trial call of specialized '{name}' failed: {exc!r}")
            return
        registry[name] = {'jit': jitted, 'py': py_func, 'src': gen_src}

    entry_name = method.__name__
    _specialize_by_name(entry_name, max_depth)

    if entry_name in registry:
        return registry[entry_name]['jit'], registry
    return None, None