"""Divacon Source v0.2 - a language for parallel divide-and-conquer.

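   Source syntax:
       name = expr                       assignment
       name(params) = expr               function definition; params are names or (name, ...) patterns
       let name = expr in expr           local binding
       expr @ expr                       composition, applied right-to-left (Mou 1990's ':')
       ! expr                            broadcast over both halves of a pair
       # expr                            communication (with a generator)
       name(args)                        application; args positional or kw=val
       expr.N   expr.name                tuple index / named-field access
       'name'                            quoted name, e.g. '+' (the 2-tuple adder)
       (expr, expr, ...)                 tuple; (name=expr, ...) makes a named record
       [expr, expr, ...]                 list
       expr +|-|*|/ expr                 arithmetic
       -- text                           comment

   Run:  python3 divacon.py file.dc      run a program file
         python3 divacon.py [N]          trace the built-in reverse demo on 1..N (default 8)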

"""

import lark

# ----------------------- Grammar -----------------------

# Lark grammar for Divacon source; compiled with parser='lalr' into PARSER.
# Notes on the rules below:
#   * `fundef` requires a leading "def" keyword. Users write `name(params) = body`
#     and _preprocess_fundefs inserts the "def" before parsing.
#   * Precedence, loosest to tightest: `let ... in`, `@` composition, + -, * /,
#     prefix ! # -, then postfix call / .N / .name.
#   * `?rule` inlines a match that has a single child; `-> alias` names the tree
#     node (most aliases have an Interp method of the same name).
#   * QNAME ('...') lets env entries with operator names such as '+' be referenced.
GRAMMAR = r"""
start: stmt+

COMMENT: /--[^\n]*/

%ignore COMMENT

?stmt: "def" NAME "(" args ")" "=" expr  -> fundef
     | NAME "=" expr               -> assign
     | expr                        -> exprstmt

?expr: "let" NAME "=" expr "in" expr  -> let_
     | addexpr ("@" addexpr)*    -> compose

?addexpr: addexpr "+" mulexpr  -> add
        | addexpr "-" mulexpr  -> sub
        | mulexpr

?mulexpr: mulexpr "*" term     -> mul
        | mulexpr "/" term     -> div_
        | term

?term: "!" term            -> bang
     | "#" term            -> hash_
     | "-" term            -> neg
     | postfix

?postfix: atom
        | postfix "(" [args] ")"  -> call
        | postfix "." NUMBER      -> index_
        | postfix "." NAME        -> field_

args: arg ("," arg)*
?arg: NAME "=" expr        -> kwarg
    | expr                 -> posarg

tup_items: tup_item ("," tup_item)+
?tup_item: NAME "=" expr   -> named_field
         | expr            -> pos_field

?atom: NAME                -> var
     | QNAME               -> qvar
     | NUMBER              -> num
     | "[" [expr ("," expr)*] "]"  -> list_lit
     | "(" tup_items ")"           -> tuple_lit
     | "(" expr ")"

NAME: /[a-zA-Z_][a-zA-Z_0-9]*/
QNAME: /'[^']+'/
NUMBER: /\d+(?:\.\d+)?/
%import common.WS
%ignore WS
"""


# ----------------------- Runtime primitives -----------------------

def id_(x):
    """Identity function (bound as `id`; the default pre/post/basef of a PDC)."""
    return x


def atom(A):
    """Default PDC base-case predicate: true when A has at most one element.

       The empty vector counts as atomic so recursion terminates on []. Both
       d_lr([]) and d_eo([]) return ([], []), so without this a PDC applied
       to [] would divide forever."""
    return len(A) <= 1
def d_lr(A):
    """Left/right divide: (first len(A)//2 items, remaining items)."""
    half = len(A) // 2
    left, right = A[:half], A[half:]
    return left, right

def c_lr(pair):
    """Left/right combine: concatenate the two halves (the inverse of d_lr)."""
    left, right = pair
    joined = left + right
    return joined

def d_eo(A):
    """Even/odd divide: (items at even indices, items at odd indices)."""
    evens = A[::2]
    odds = A[1::2]
    return evens, odds

def c_eo(pair):
    """Even/odd combine: interleave L and R as L[0], R[0], L[1], R[1], ...

       This is the inverse of d_eo. Once the shorter side runs out, the rest
       of the longer side is appended in order. For d_eo output that is at most
       one trailing element of L; the general rule also stops unequal inputs
       from silently losing data."""
    L, R = pair
    out = []
    for a, b in zip(L, R):
        out.extend((a, b))
    paired = min(len(L), len(R))
    out.extend(L[paired:])         # at most one of these two tails is non-empty
    out.extend(R[paired:])
    return out

def sum_combine(pair):
    """Combine two sub-results with `+`."""
    left, right = pair
    return left + right


def max_combine(pair):
    """Combine two sub-results by keeping the larger one."""
    left, right = pair
    return max(left, right)

def _min2(t): return t[0] if t[0] <= t[1] else t[1]   # min over a 2-tuple, named for trace clarity
def _max2(t): return t[0] if t[0] >= t[1] else t[1]
_min2.__name__ = 'min'
_max2.__name__ = 'max'

# Binary arithmetic on a 2-tuple t, computing t[0] OP t[1].
# Bound to '+', '*', '-', '/' in the env ('+' is also available as `plus`).
def plus(t):
    lhs, rhs = t[0], t[1]
    return lhs + rhs


def times(t):
    lhs, rhs = t[0], t[1]
    return lhs * rhs


def minus(t):
    lhs, rhs = t[0], t[1]
    return lhs - rhs


def div(t):
    lhs, rhs = t[0], t[1]
    return lhs / rhs

def liftop(a, b, op, sym):
    """Apply binary `op` now, or defer it pointwise when an operand is a function.

         scalar OP scalar -> op(a, b), evaluated immediately
         fn     OP scalar -> x -> op(a(x), b)
         scalar OP fn     -> x -> op(a, b(x))
         fn     OP fn     -> x -> op(a(x), b(x))

       The deferred function is named "(<a><sym><b>)" for traces. This is how
       `self + other` becomes the callable t -> t[0] + t[1] without a lambda keyword."""
    a_is_fn, b_is_fn = callable(a), callable(b)
    if not (a_is_fn or b_is_fn):
        return op(a, b)
    left = a if a_is_fn else (lambda _x: a)
    right = b if b_is_fn else (lambda _x: b)

    def lifted(x):
        return op(left(x), right(x))

    lifted.__name__ = f"({getattr(a, '__name__', a)}{sym}{getattr(b, '__name__', b)})"
    return lifted

class FnTuple(tuple):
    """A tuple of functions/values to be applied (distributed to args) or returned later.
         Polymorphic on argument shape:
           distribution: arg is a tuple of matching arity -> (f0(a0), f1(a1), ...)
           per-element:  distribution where ANY side of arg is a non-empty list whose
                         first item is a CommTuple -> ([f0(x) for x in a0], ...)
                         (every callable fi is mapped over the elements of its side)
           construction: otherwise                        -> (f0(arg), f1(arg), ...)
       Non-callable elements pass through as constants in distribution and construction.
       In per-element mode a non-callable element returns its side of `arg` unchanged.
       NOTE(review): that per-element behavior differs from the other modes. Confirm
       it is intended (e.g. to let a non-callable entry mean "leave this side alone").
       Distribution wins when arities match; that's the documented edge case."""
    def __call__(self, arg):
        if isinstance(arg, tuple) and len(arg) == len(self):
            # Post-communication sides are lists of CommTuples: apply each f to every
            # element of its side rather than to the side as a whole.
            if any(isinstance(a, list) and a and isinstance(a[0], CommTuple) for a in arg):
                return tuple([f(x) for x in a] if callable(f) else a for f, a in zip(self, arg))
            return tuple(f(a) if callable(f) else f for f, a in zip(self, arg))
        return tuple(f(arg) if callable(f) else f for f in self)
    @property
    def __name__(self):
        # Trace display name, e.g. "(first, second)"; items without __name__ use repr().
        return "(" + ", ".join(getattr(f, '__name__', repr(f)) for f in self) + ")"

class NamedRec(FnTuple):
    """FnTuple with field names. Some entries may be unnamed ('').

       Fields are read with .name; index access by .N still works because this is
       a tuple. Calling a NamedRec behaves like FnTuple.__call__ and re-attaches the
       same field names to the result.

       NOTE: __getattr__ runs only when normal attribute lookup fails. A field whose
       name matches a tuple method (e.g. `count`, `index`) is therefore shadowed by
       that method."""
    def __new__(cls, items, fields):
        t = super().__new__(cls, items)
        t._fields = tuple(fields)
        return t

    def __call__(self, arg):
        result = FnTuple.__call__(self, arg)
        return NamedRec(result, self._fields)

    def __getattr__(self, name):
        # Names starting with '_' are never treated as fields.
        if name.startswith('_'):
            raise AttributeError(name)
        try:
            i = self._fields.index(name)
        except ValueError:
            # `from None`: a missing field is the whole story; hide the internal ValueError.
            raise AttributeError(f"no field {name!r}") from None
        return self[i]

    @property
    def __name__(self):
        # Trace display name, e.g. "(re=f, g)"; unnamed entries show no "name=" prefix.
        return "(" + ", ".join(
            (f"{n}=" if n else "") + getattr(v, '__name__', repr(v))
            for n, v in zip(self._fields, self)) + ")"

def first(A):
    """A[0] (also bound as `only`)."""
    return A[0]


def second(A):
    """A[1]."""
    return A[1]


def third(A):
    """A[2]."""
    return A[2]


def fourth(A):
    """A[3]."""
    return A[3]

# Composition: f @ g  means  lambda x: f(g(x))
class Composed:
    """Right-to-left pipeline: Composed([f, g, h])(x) == f(g(h(x)))."""

    def __init__(self, fns):
        self.fns = fns
        self.name = None   # display name; Interp.assign sets it when the pipeline is bound

    def __call__(self, x):
        value = x
        for step in self.fns[::-1]:
            value = step(value)
        return value

    @property
    def __name__(self):
        labels = []
        for f in self.fns:
            labels.append(getattr(f, 'name', None) or getattr(f, '__name__', repr(f)))
        return " @ ".join(labels)

def compose_many(fns):
    """Compose fns (outermost first) into one callable.

       Nested Composed pipelines are spliced in flat. A lone resulting function is
       returned as-is rather than wrapped."""
    flat = [g for f in fns for g in (f.fns if isinstance(f, Composed) else [f])]
    if len(flat) == 1:
        return flat[0]
    return Composed(flat)

# Generators are pure index maps from this-side index i to a partner index on the
# OTHER side, given the other side's length n:
#   corr(i,n)  = i              (same index on the other side)
#   mirr(i,n)  = n-1-i          (mirror image)
#   shift(i,n) = (i+1) % n      (next index, wraparound)
#   lastK(i,n) = n-1-K          (constant: K-th from the end; last0 is the last element;
#                                written last(K) in source)
# `nil` ("no comm to this side") is handled by _comm_side before gen_at is consulted.
def gen_at(gen, i, n):
    """Partner index on the other side for generator `gen` at this-side index i.

       Raises ValueError for an unrecognized generator name."""
    if gen == 'corr':
        return i
    if gen == 'mirr':
        return n - 1 - i
    if gen == 'shift':
        return (i + 1) % n
    if gen.startswith('last'):
        k = gen[4:]
        if k.isdigit():
            return n - 1 - int(k)
    raise ValueError(f"unknown generator: {gen}")

class BiGen:
    """Marker produced by `!gen`: tells `#` to communicate in both directions."""

    def __init__(self, gen):
        self.gen = gen

    def __repr__(self):
        return f"!{self.gen}"

# # - communication. hash_(spec) builds a function (pair -> pair) from a generator spec:
#   #gen           unidirectional: only L is augmented (R unchanged).
#   #!gen (BiGen)  bidirectional:  both sides packed with their counterpart.
#   #(gL, gR)      asymmetric: L via gL, R via gR; 'nil' leaves that side unchanged.

class CommTuple(tuple):
    """A (self_data, other_data) pair produced by communication (see _comm_side).

       It behaves as an ordinary tuple. The distinct type is what lets
       FnTuple.__call__ and bang's element-wise `applied` recognize
       post-communication data."""

def _comm_side(this, other, gen):
    """For each i in `this`, pack (this[i], other[gen(i)]).  `nil` leaves `this` alone."""
    if gen == 'nil':
        return list(this)
    n = len(other)
    return [CommTuple((this[i], other[gen_at(gen, i, n)])) for i in range(len(this))]

def hash_(spec):
    """Build the communication function for `#spec`: pair (L, R) -> (L', R').

       spec may be
         BiGen(g)   (`#!g`)       both directions: L' pairs L with R via g, R' pairs R with L via g
         (gL, gR)   (`#(gL,gR)`)  asymmetric: L uses gL, R uses gR ('nil' leaves a side alone)
         g          (`#g`)        one-way: only L is augmented; R' is list(R)
       Augmented sides are lists of CommTuple (self, other) pairs built by _comm_side."""
    if isinstance(spec, BiGen):
        left_gen = right_gen = spec.gen
        label = f"#!{spec.gen}"
    elif isinstance(spec, tuple):                              # asymmetric (gL, gR)
        left_gen, right_gen = spec
        label = f"#({left_gen},{right_gen})"
    else:
        # One-way: the 'nil' generator makes _comm_side return list(R) untouched.
        left_gen, right_gen = spec, 'nil'
        label = f"#{spec}"

    def comm(pair):
        L, R = pair
        return (_comm_side(L, R, left_gen), _comm_side(R, L, right_gen))

    comm.__name__ = label
    return comm

# ! - overloaded by argument:
#   !gen-string  -> BiGen('gen')          (modifier inside a `#` comm)
#   !(f0, f1)    -> map f0 over the left side's elements, f1 over the right side's
#   !pdc         -> broadcast pdc to each side of a pair as a whole
#   !f-of-elem   -> map f over the elements of each side of a pair
def bang(f):
    """Prefix `!`: turn f into a function over a pair (L, R), or mark a generator.

       Dispatch is on the type of f, checked in this order:
         * generator name ('corr', 'mirr', 'shift', 'lastK')  -> BiGen(f)
         * FnTuple (including NamedRec) -> L mapped with f[0], R mapped with f[1]
           (only the first two entries are used; no trace substitutions are printed)
         * PDC                          -> f applied to each whole side
         * anything else                -> `applied`, which applies f element-wise
       NOTE(review): other strings (e.g. 'nil') fall through to the last case and
       only fail, with TypeError, when the result is applied. Confirm this is intended."""
    if isinstance(f, str) and (f in ('corr', 'mirr', 'shift')
                               or (f.startswith('last') and f[4:].isdigit())):
        return BiGen(f)
    if isinstance(f, FnTuple):                                # !(f1,f2): each node picks f by L/R position
        nm = f.__name__
        def applied_fntuple(arg):
            L, R = arg
            return ([f[0](x) for x in L], [f[1](x) for x in R])
        applied_fntuple.__name__ = f"!{nm}"
        return applied_fntuple
    if isinstance(f, PDC):                                    # !pdc applies the PDC to each whole side
        nm = f.name
        def applied_pdc(arg):
            if isinstance(arg, tuple) and len(arg) == 2:
                return (f(arg[0]), f(arg[1]))
            return f(arg)
        applied_pdc.__name__ = f"!{nm}"
        return applied_pdc
    nm = getattr(f, 'name', None) or getattr(f, '__name__', '?')
    # In every branch below, _print_sub emits a fundef substitution line per
    # element while tracing (it returns immediately when tracing is off).
    def applied(arg):
        # A single communicated element (self, other): apply f to it directly.
        # Checked first because a CommTuple is also a 2-tuple.
        if isinstance(arg, CommTuple):
            _print_sub(f, arg)
            return f(arg)
        # A flat list: map f over it.
        if isinstance(arg, list):
            for x in arg: _print_sub(f, x)
            return [f(x) for x in arg]
        # A pair: map over each side when both sides are lists, else apply f to each side.
        if isinstance(arg, tuple) and len(arg) == 2:
            L, R = arg
            if isinstance(L, list) and isinstance(R, list):
                for x in L: _print_sub(f, x)
                for x in R: _print_sub(f, x)
                return ([f(x) for x in L], [f(x) for x in R])
            _print_sub(f, L); _print_sub(f, R)
            return (f(L), f(R))
        # Anything else: plain application.
        _print_sub(f, arg)
        return f(arg)
    applied.__name__ = f"!{nm}"
    return applied

# Selectors used inside post-adjustments. After `#`, each element is a
# (self, other) pair built by communication.
def other(t):
    """The partner's value received through communication: t[1]."""
    return t[1]


def self_(t):
    """This node's own value: t[0] (bound as `self` in the env)."""
    return t[0]


# Pretty-print arrays/tuples for the tracer
def _fmt(x):
    if isinstance(x, complex):
        r, i = round(x.real, 3), round(x.imag, 3)
        if i == 0:  return f"{r}"
        if r == 0:  return f"{i}j"
        return f"{r}{'+' if i>=0 else ''}{i}j"
    if isinstance(x, float):   return f"{x:.3g}"
    if isinstance(x, tuple):   return "(" + ", ".join(_fmt(e) for e in x) + ")"
    if isinstance(x, list):    return "[" + ",".join(_fmt(e) for e in x) + "]"
    return repr(x)

# Tracing state. These are module-level so applied() (inside bang) can find the
# current indent and emit fundef substitutions under the corresponding trace step.
_TRACE_ON     = False         # True only while trace() runs (reset in its `finally`); gates _print_sub
_TRACE_INDENT = 0             # depth of the trace step now running; _print_sub indents one level deeper
_SOURCE       = ""            # set by pydc before run(); shown at trace start.
_LIB_SOURCE   = ""            # set by pydc if a Python helper lib was supplied; shown at trace start.
_LIB_NAME     = ""            # label for _LIB_SOURCE in trace()'s "-- lib: ... --" header
_FUNDEF_BODY  = {}            # name -> source body text (rhs of `=`), set in _preprocess_fundefs

def _sub_body(body, binds):
    """Substitute bound name -> formatted value into body source text (word-boundary)."""
    out = body
    for nm, val in sorted(binds.items(), key=lambda kv: -len(kv[0])):
        out = re.sub(r'\b' + re.escape(nm) + r'\b', _fmt(val), out)
    return out

def _flatten_params(params):
    """Params from fundef_processed -> flat list of (name, slot)."""
    flat = []
    for p in params:
        if p[0] == 'name':
            flat.append((p[1], p[2]))
        else:
            flat.extend(p[1])
    return flat

def _bind_arg(params, arg):
    """Compute name -> value map matching how fundef binds arg, without mutating slots."""
    binds = {}
    if len(params) == 1:
        p = params[0]
        if p[0] == 'name':
            binds[p[1]] = arg
        elif isinstance(arg, tuple) and len(arg) == len(p[1]):
            for (nm, _), v in zip(p[1], arg):
                binds[nm] = v
    elif isinstance(arg, tuple) and len(arg) == len(params):
        for p, v in zip(params, arg):
            if p[0] == 'name':
                binds[p[1]] = v
            elif isinstance(v, tuple) and len(v) == len(p[1]):
                for (nm, _), sv in zip(p[1], v):
                    binds[nm] = sv
    return binds

def _print_sub(fn, arg):
    """Trace helper: print a user fundef's body with its parameters substituted.

       Does nothing unless tracing is on and fn carries non-empty _dc_body and
       _dc_params attributes. Otherwise prints one line, indented one level
       deeper than the current _TRACE_INDENT:
           name(arg)  [p=v, ...]  -> body  = body-with-values
       Called per element from inside bang's applied()."""
    if not _TRACE_ON:
        return
    body = getattr(fn, '_dc_body', None)
    params = getattr(fn, '_dc_params', None)
    if not (body and params):
        return
    bindings = _bind_arg(params, arg)
    rendered = _sub_body(body, bindings)
    fname = getattr(fn, '__name__', 'fn')
    shown = ", ".join(f"{k}={_fmt(v)}" for k, v in bindings.items())
    indent = "  " * (_TRACE_INDENT + 1)
    print(f"{indent}{fname}({_fmt(arg)})  [{shown}]  -> {body}  = {rendered}")

# PDC value
class PDC:
    """A parallel divide-and-conquer function (Mou's PDC) with a display name.

       The parameters mirror the tuple (divide, combine, pre, post, basep, basef).
       pdc(A) returns basef(A) when basep(A) holds. Otherwise it returns
       combine(post((pdc(L), pdc(R)))), where (L, R) = pre(divide(A)).
       `name` is used in traces."""
    def __init__(self, divide=d_lr, combine=c_lr, pre=id_, post=id_,
                 basep=atom, basef=id_, name='f'):
        self.divide, self.combine = divide, combine
        self.pre, self.post = pre, post
        self.basep, self.basef = basep, basef
        self.name = name

    def __call__(self, A):
        """Evaluate the PDC on A recursively, without tracing."""
        if self.basep(A):
            return self.basef(A)
        pair = self.divide(A)
        pair = self.pre(pair)
        L, R = pair
        pair = (self(L), self(R))
        pair = self.post(pair)
        return self.combine(pair)

    def _trace(self, A, depth=0):
        """Like __call__, but print each divide/pre/recurse/post/combine step,
           indented two spaces per recursion depth."""
        global _TRACE_INDENT
        _TRACE_INDENT = depth
        pad = "  " * depth
        print(f"{pad}{self.name}({_fmt(A)})")
        if self.basep(A):
            r = self.basef(A)
            print(f"{pad}  ⇣ atom; basef -> {_fmt(r)}")
            return r
        pair = self.divide(A)
        print(f"{pad}  divide  {_name_of(self.divide):<10} -> {_fmt(pair)}")
        if self.pre is not id_:
            pair = _trace_step(self.pre, pair, f"{pad}  pre")
        L, R = pair
        L_ = self._trace(L, depth + 1)
        R_ = self._trace(R, depth + 1)
        # The recursive calls leave _TRACE_INDENT at a descendant's depth. Restore
        # this node's depth so substitutions printed while post/combine run
        # (via bang's applied -> _print_sub) line up under this node.
        _TRACE_INDENT = depth
        pair = (L_, R_)
        if self.post is not id_:
            pair = _trace_step(self.post, pair, f"{pad}  post")
        r = self.combine(pair)
        print(f"{pad}  combine {_name_of(self.combine):<10} -> {_fmt(r)}")
        return r


def make_pdc(*args, **kw):
    """Construct a PDC from positional and/or keyword components.

       PDC(d, c, pre, post, basep, basef) - positional matches Mou's tuple order.
       PDC(divide=..., combine=..., pre=..., post=..., basep=..., basef=...) - kw.

       Raises TypeError if more than six positional arguments are given (they used
       to be silently dropped) or if a component is given both ways."""
    fields = ('divide', 'combine', 'pre', 'post', 'basep', 'basef')
    if len(args) > len(fields):
        raise TypeError(f"PDC: takes at most {len(fields)} positional arguments "
                        f"({len(args)} given)")
    for name, val in zip(fields, args):
        if name in kw:
            raise TypeError(f"PDC: {name} given both positionally and by keyword")
        kw[name] = val
    return PDC(**kw)

def _trace_step(fn, x, label):
    """Apply fn to x, printing one labeled line per step, and return the result.

       A Composed fn runs stage by stage (right-to-left, as it executes) so each
       intermediate value is shown; any other fn is a single step. Step names
       come from _name_of, like the rest of the tracer, so a PDC shows its
       assigned name instead of an object repr."""
    steps = fn.fns if isinstance(fn, Composed) else [fn]
    for sub in reversed(steps):
        x = sub(x)
        print(f"{label} {str(_name_of(sub)):<10} -> {_fmt(x)}")
    return x

def _name_of(f):
    return getattr(f, 'name', None) or getattr(f, '__name__', None) or repr(f)

def _trace_any(fn, x, depth):
    """Trace one callable and return its result.

       PDCs trace their own recursion, Composed pipelines are traced stage by
       stage (right-to-left), and anything else prints a single labeled line."""
    global _TRACE_INDENT
    _TRACE_INDENT = depth
    if isinstance(fn, PDC):
        return fn._trace(x, depth)
    if not isinstance(fn, Composed):
        result = fn(x)
        print(f"{'  ' * depth}{_name_of(fn):<10} -> {_fmt(result)}")
        return result
    for stage in fn.fns[::-1]:
        x = _trace_any(stage, x, depth)
    return x

def _echo_block(opening, text, closing):
    """Print `text` between two marker lines, indenting each text line two spaces."""
    print(opening)
    for ln in text.rstrip().split("\n"):
        print(f"  {ln}")
    print(closing)


def _own_source(fn):
    """Python source of fn if it is a plain function with a source file, else None."""
    try:
        import inspect
        if inspect.isfunction(fn) and inspect.getsourcefile(fn):
            return inspect.getsource(fn)
    except (TypeError, OSError):
        pass
    return None


def trace(pdc, *args):
    """Run pdc on the input, printing every step, and return the result.

       Multiple args are packed into a tuple (so `trace(f, A, B)` -> `f((A, B))`).
       Before the trace, a Python-defined lib function shows only its own source.
       Anything else (PDC, Composed, ... defined in .dc) shows the full .dc source
       and the lib source, when those are set."""
    global _TRACE_ON, _TRACE_INDENT
    A = tuple(args) if len(args) != 1 else args[0]
    src = _own_source(pdc)
    if src:
        _echo_block(f"-- source: {_name_of(pdc)} --", src, "-- end source --")
    else:
        if _SOURCE:
            _echo_block("-- source --", _SOURCE, "-- end source --")
        if _LIB_SOURCE:
            _echo_block(f"-- lib: {_LIB_NAME} --", _LIB_SOURCE, "-- end lib --")
    print(f"-- trace: {_name_of(pdc)}({_fmt(A)}) --")
    _TRACE_ON, _TRACE_INDENT = True, 0
    try:
        result = _trace_any(pdc, A, 0)
    finally:
        _TRACE_ON, _TRACE_INDENT = False, 0
    print(f"-- result: {_fmt(result)} --")
    return result

def _fmt_c(x):
    """Format a value compactly for HTML: round complex numbers."""
    if isinstance(x, complex):
        r, i = round(x.real, 3), round(x.imag, 3)
        if i == 0:   return str(r)
        if r == 0:   return f"{i}i"
        sign = '+' if i >= 0 else '-'
        return f"{r}{sign}{abs(i)}i"
    if isinstance(x, tuple): return "(" + ", ".join(_fmt_c(e) for e in x) + ")"
    if isinstance(x, list):  return "[" + ", ".join(_fmt_c(e) for e in x) + "]"
    return repr(x)

def trace_html(pdc, *args):
    """Run trace(pdc, *args) with its output captured and wrapped as HTML.

       The HTML is a <style> rule, an <h2> title and the escaped trace inside
       <div class="dc-trace">. It is printed to stdout, as before, and is now
       also returned as a string, as this docstring always promised."""
    import io, contextlib, html as _html
    buf = io.StringIO()
    with contextlib.redirect_stdout(buf):
        trace(pdc, *args)
    esc = _html.escape(buf.getvalue())
    page = (
        '<style>.dc-trace{background:#1a1a2e;color:#a8d8ea;padding:1em;border-radius:6px;'
        'font-family:monospace;font-size:.85em;overflow-x:auto;white-space:pre;'
        'border:1px solid #444;margin:1em 0}</style>\n'
        f'<h2>{_html.escape(_name_of(pdc))}</h2>'
        f'<div class="dc-trace">{esc}</div>'
    )
    print(page)
    return page


def _zip(arg):
    """zip((A,B)) -> [(a0,b0), (a1,b1), ...]; or zip(A, B) -> same."""
    if isinstance(arg, tuple) and len(arg) == 2:
        A, B = arg
    else:
        raise TypeError(f"zip expects a 2-tuple of vectors, got {arg!r}")
    return [(a, b) for a, b in zip(A, B)]

def _unzip(V):
    """unzip([(a0,b0),...]) -> ([a0,...], [b0,...])."""
    return ([p[0] for p in V], [p[1] for p in V])

def _xvec(x, m): return [x] * m

def _lift1(f):
    """Lift a value-function f to a singleton-list-function: [v] -> [f(v)].
       Bridges pure user functions into PDC's basef convention."""
    def lifted(L):
        return [f(L[0])]
    lifted.__name__ = f"lift1({getattr(f, '__name__', f)})"
    return lifted

# Built-in environment. Every Interp starts from a copy of this dict; the
# optional pydc sidecar (below) may add entries before that happens.
STDLIB = {
    'id': id_, 'atom': atom, 'atomq': atom,
    'zip': _zip, 'unzip': _unzip, 'xvec': _xvec, 'lift1': _lift1,

    # divide / combine pairs
    'd_lr': d_lr, 'c_lr': c_lr, 'd_eo': d_eo, 'c_eo': c_eo,
    'first': first, 'only': first, 'second': second, 'third': third, 'fourth': fourth,
    # 2-tuple arithmetic; the operator-named keys are reachable via quoted names, e.g. '+'
    'plus': plus, '+': plus, '*': times, '-': minus, '/': div,
    'sum_combine': sum_combine, 'max_combine': max_combine,
    'min': _min2,                                              # min over a 2-tuple
    'max': _max2,                                              # max over a 2-tuple
    'self': self_, 'other': other,
    'corr': 'corr', 'mirr': 'mirr', 'shift': 'shift',          # generator tags
    'nil': 'nil',                                              # no-comm sentinel for #(gL, gR)
    'last':     lambda n: f'last{n}',                          # last(n) -> last_n generator name
    'PDC': make_pdc, 'trace': trace,
    'sqrt': __import__('cmath').sqrt,                          # complex-valued sqrt (cmath)
    'complex': complex,
    'magnitude': abs,
    # ap(f, a1, ...): if no ai is callable, call f(a1, ...) now; otherwise return
    # x -> f(a1', ...), where each callable ai is replaced by ai(x).
    'ap': lambda f, *args: (
        (lambda x, _f=f, _a=args: _f(*[a(x) if callable(a) else a for a in _a]))
        if any(callable(a) for a in args) else f(*args)
    ),
}


# ----------------------- Interpreter -----------------------

# Bindings injected from outside (e.g. the command line). Interp seeds its env
# with them, and assign() skips in-source assignments to these names.
OVERRIDES = {}    # name -> value: CLI-injected bindings that win over in-source assigns.

class Interp(lark.Transformer):
    """Evaluate a preprocessed Divacon parse tree bottom-up.

       Each method is named after a grammar rule/alias, or after a node produced
       by the _process_fundefs / _process_lets pre-passes. It receives that node's
       already-transformed children as `items`. Expressions evaluate to Python
       values: plain data, or callables for functions to be applied later.
       Statements (assign, fundef_processed) store into self.env, and start()
       returns that env."""
    def __init__(self):
        super().__init__()
        # Fresh env per interpreter: built-ins first, then CLI overrides on top.
        self.env = dict(STDLIB)
        self.env.update(OVERRIDES)

    # NUMBER literal: float if it contains a decimal point, else int.
    def num(self, items):
        s = str(items[0])
        return float(s) if '.' in s else int(s)

    # [a, b, ...]. None children are dropped. NOTE(review): presumably these are
    # lark's placeholders for the optional item list in `[]`; confirm.
    def list_lit(self, items):
        return [x for x in items if x is not None]

    # Tuple items become (field_name, value); unnamed items use ''.
    def named_field(self, items):
        return (str(items[0]), items[1])

    def pos_field(self, items):
        return ('', items[0])

    def tup_items(self, items): return list(items)

    # (a, b, ...): a NamedRec if any item is named, otherwise a plain FnTuple.
    def tuple_lit(self, items):
        pairs = items[0] if len(items) == 1 and isinstance(items[0], list) else list(items)
        names = [p[0] for p in pairs]
        vals  = [p[1] for p in pairs]
        if any(names): return NamedRec(vals, names)
        return FnTuple(vals)

    # expr.name: read a NamedRec field now; on a callable, defer the attribute
    # lookup to the callable's result; otherwise use ordinary attribute access.
    def field_(self, items):
        base, name = items[0], str(items[1])
        if isinstance(base, NamedRec):
            return getattr(base, name)
        if callable(base):
            f = lambda x, _b=base, _n=name: getattr(_b(x), _n)
            f.__name__ = f"{getattr(base,'__name__',base)}.{name}"
            return f
        return getattr(base, name)

    # NAME lookup in the env.
    def var(self, items):
        n = str(items[0])
        if n not in self.env:
            raise NameError(f"undefined name: {n}")
        return self.env[n]

    # 'name' lookup, which allows operator-named entries such as '+'.
    def qvar(self, items):
        n = str(items[0])[1:-1]               # strip the surrounding quotes
        if n not in self.env:
            raise NameError(f"undefined name: '{n}'")
        return self.env[n]

    # a @ b @ ...: right-to-left composition; a single operand passes through.
    def compose(self, items):
        return items[0] if len(items) == 1 else compose_many(list(items))

    # Arithmetic via liftop: computed now on plain values, deferred pointwise when
    # either operand is a function. Unary minus is 0 - x.
    def add(self, items):  return liftop(items[0], items[1], lambda a, b: a + b, '+')
    def sub(self, items):  return liftop(items[0], items[1], lambda a, b: a - b, '-')
    def neg(self, items):  return liftop(       0, items[0], lambda a, b: a - b, '-')
    def mul(self, items):  return liftop(items[0], items[1], lambda a, b: a * b, '*')
    def div_(self, items): return liftop(items[0], items[1], lambda a, b: a / b, '/')

    # Prefix ! and # delegate to the module-level bang() / hash_() functions
    # (method names are not in scope inside method bodies).
    def bang(self, items): return bang(items[0])
    def hash_(self, items): return hash_(items[0])

    # A fundef parameter or let-bound name, substituted in by the pre-passes.
    # It becomes a function that ignores its argument and reads the slot's current value.
    def slot_ref(self, items):
        slot = items[0]   # _Slot object carried through AST
        f = lambda x: slot.val
        f.__name__ = slot.name
        return f

    # let name = e in body.
    #   * non-callable e: stored in the slot once, now; the body is returned as-is.
    #   * callable e and constant body: the body is returned and e is never evaluated.
    #   * callable e and callable body: let_wrap computes e(x) into the slot and then
    #     evaluates body(x), so every reference in the body shares a single e(x).
    def let_processed(self, items):
        e_val, body_val, slot = items[0], items[1], items[2]
        if not callable(e_val):
            slot.val = e_val
            return body_val
        if not callable(body_val):
            return body_val
        def let_wrap(x):
            slot.val = e_val(x)
            return body_val(x)
        let_wrap.__name__ = f"(let {slot.name})"
        return let_wrap

    # expr.N: index a value now, or defer the indexing onto a callable's result.
    def index_(self, items):
        base, n = items[0], int(str(items[1]))
        if callable(base):
            f = lambda x, _b=base, _n=n: _b(x)[_n]
            f.__name__ = f"{getattr(base,'__name__',base)}.{n}"
            return f
        return base[n]

    # Call arguments are tagged ('kw', name, value) / ('pos', value) so that
    # call() can split them into *args / **kwargs.
    def kwarg(self, items): return ('kw', str(items[0]), items[1])
    def posarg(self, items): return ('pos', items[0])
    def args(self, items): return list(items)

    # f(args): applied immediately during evaluation. A missing or None argument
    # list is treated as no arguments.
    def call(self, items):
        f = items[0]
        raw = items[1] if len(items) > 1 and items[1] is not None else []
        pos, kw = [], {}
        for a in raw:
            if a[0] == 'kw': kw[a[1]] = a[2]
            else: pos.append(a[1])
        return f(*pos, **kw)

    # name = expr. Skipped when the name is CLI-overridden. A PDC, or any callable
    # with a `.name` attribute (e.g. Composed), takes the assigned name so traces
    # show it.
    def assign(self, items):
        name = str(items[0])
        if name in OVERRIDES:
            return                       # CLI binding wins over in-source assign
        val = items[1]
        if isinstance(val, PDC):
            val.name = name
        elif callable(val) and hasattr(val, 'name'):
            val.name = name
        self.env[name] = val

    # Function definition rewritten by _process_fundefs. items are
    # (name token, body with params replaced by slot_refs, Tree('_params', params)).
    # Calling fn binds its argument into the param slots and then evaluates the body:
    # a single parameter receives the whole argument, n > 1 parameters need an
    # n-tuple, and tuple-pattern parameters need a tuple of matching length.
    # _dc_body / _dc_params are read by _print_sub for trace substitutions.
    # NOTE(review): the slots are shared by all calls of fn, so a nested or
    # recursive call of the same fundef during body evaluation would overwrite them.
    def fundef_processed(self, items):
        name     = str(items[0])
        body_val = items[1]
        params   = list(items[2].children)
        n        = len(params)
        def _bind(p, val):
            if p[0] == 'name':
                p[2].val = val
            else:
                if not isinstance(val, tuple) or len(val) != len(p[1]):
                    raise TypeError(f"{name}: expected {len(p[1])}-tuple, got {val!r}")
                for (nm, slot), v in zip(p[1], val):
                    slot.val = v
        def fn(arg):
            if n == 1:
                _bind(params[0], arg)
            else:
                if not isinstance(arg, tuple) or len(arg) != n:
                    raise TypeError(f"{name} expects a {n}-tuple, got {arg!r}")
                for p, v in zip(params, arg): _bind(p, v)
            return body_val(arg) if callable(body_val) else body_val
        fn.__name__ = name
        fn._dc_body   = _FUNDEF_BODY.get(name)
        fn._dc_params = params
        self.env[name] = fn

    # A bare expression statement is evaluated for its effect (e.g. trace(...)).
    def exprstmt(self, items):
        return items[0]

    # Program root: the result of a run is the final environment.
    def start(self, items):
        return self.env


# Optional user sidecar. If a module named `pydc` is importable, every callable in
# its namespace whose name does not start with '_' is added to STDLIB. That includes
# names pydc itself imported, and an entry overrides any built-in of the same name.
# Interp copies STDLIB when it is constructed, so these are visible to later runs.
# NOTE(review): `except ImportError` also hides import errors raised *inside* pydc
# (e.g. a missing dependency), not only "pydc is absent". Consider checking the
# exception's `.name` if that distinction matters.
try:
    import pydc                                            # user-extensible sidecar
    for _name, _val in vars(pydc).items():
        if callable(_val) and not _name.startswith('_'):
            STDLIB[_name] = _val
except ImportError:
    pass


# Compiled once at import time. LALR forces NAME "(" to bind as a call, avoiding
# Earley's ambiguity that split `name = PDC` from `(d_lr, ...)` into a separate
# stmt + tuple_lit.
PARSER = lark.Lark(GRAMMAR, parser='lalr')

class _Slot:
    """Mutable cell carried through the AST for let-binding sharing.
       slot.val holds the per-element computed value during body evaluation."""
    __slots__ = ('name', 'val')
    def __init__(self, name): self.name = name; self.val = None

def _peel(node):
    """Strip single-child wrapper trees down to a meaningful node.

       Stops at non-trees, at trees with zero or several children, and at
       var / tuple_lit / pos_field / named_field nodes."""
    keep = ('var', 'tuple_lit', 'pos_field', 'named_field')
    while True:
        if not isinstance(node, lark.Tree):
            return node
        if len(node.children) != 1 or node.data in keep:
            return node
        node = node.children[0]

def _parse_param(node, fname):
    """Parse one fundef parameter.

       Returns ('name', n, slot) for a bare name, or ('tuple', [(n, slot), ...]) for
       a (name, ..., name) pattern, with a fresh _Slot per name. Raises SyntaxError
       for anything else."""
    def is_tree(t, tag):
        return isinstance(t, lark.Tree) and t.data == tag

    def unwrap(t):
        # Peel wrappers, then look through one pos_field (unnamed tuple item).
        t = _peel(t)
        if is_tree(t, 'pos_field'):
            t = _peel(t.children[0])
        return t

    n = unwrap(node)
    if is_tree(n, 'var'):
        nm = str(n.children[0])
        return ('name', nm, _Slot(nm))
    if is_tree(n, 'tuple_lit'):
        kids = n.children
        if len(kids) == 1 and is_tree(kids[0], 'tup_items'):
            kids = kids[0].children
        sub = []
        for it in kids:
            leaf = unwrap(it)
            if not is_tree(leaf, 'var'):
                raise SyntaxError(f"fundef {fname}: tuple pattern items must be names")
            nm = str(leaf.children[0])
            sub.append((nm, _Slot(nm)))
        return ('tuple', sub)
    raise SyntaxError(f"fundef {fname}: parameters must be names or tuple patterns")

def _process_fundefs(tree):
    """Pre-pass: rewrite `f(p1,...,pn) = body` into a fundef_processed node.

       Params can be NAMEs or (NAME, ..., NAME) tuple patterns. Each parameter name
       gets a _Slot, and every var(name) in the body is replaced by slot_ref(slot).
       Nested fundefs in the body are then processed the same way. The result is
       Tree('fundef_processed', [NAME token, body, Tree('_params', params)]).
       (The unused `all_slots` accumulator has been removed.)"""
    if not isinstance(tree, lark.Tree):
        return tree
    if tree.data != 'fundef':
        return lark.Tree(tree.data, [_process_fundefs(c) for c in tree.children])
    name      = str(tree.children[0])
    args_tree = tree.children[1]
    body      = tree.children[2]
    params    = [_parse_param(a, name) for a in args_tree.children]
    for p in params:
        named_slots = [(p[1], p[2])] if p[0] == 'name' else p[1]
        for nm, slot in named_slots:
            body = _subst_slot(body, nm, slot)
    body = _process_fundefs(body)
    return lark.Tree('fundef_processed',
                     [lark.Token('NAME', name), body,
                      lark.Tree('_params', params)])

def _process_lets(tree):
    """Pre-pass: rewrite let_ nodes into let_processed(e, body, _Slot).

       Inner lets are rewritten first. var(name) references in the body are then
       replaced with slot_ref(slot), so an inner let of the same name keeps its own
       slot. At element-call time the wrapper computes e(x) once into slot.val and
       then calls body(x), and every slot_ref reads that same slot.val."""
    if not isinstance(tree, lark.Tree):
        return tree
    if tree.data != 'let_':
        return lark.Tree(tree.data, [_process_lets(child) for child in tree.children])
    kids = tree.children
    slot = _Slot(str(kids[0]))
    bound_expr = _process_lets(kids[1])
    body = _subst_slot(_process_lets(kids[2]), slot.name, slot)
    return lark.Tree('let_processed', [bound_expr, body, slot])

def _subst_slot(tree, name, slot):
    """Return a copy of tree with every var(name) replaced by slot_ref(slot)."""
    if not isinstance(tree, lark.Tree):
        return tree
    is_target = tree.data == 'var' and str(tree.children[0]) == name
    if is_target:
        return lark.Tree('slot_ref', [slot])
    kids = [_subst_slot(child, name, slot) for child in tree.children]
    return lark.Tree(tree.data, kids)

import re
_FUNDEF_HEAD_RE = re.compile(r'^([ \t]*)([a-zA-Z_]\w*)\s*\(')

def _preprocess_fundefs(src):
    """Source-level rewrite: `name(...) = body` -> `def name(...) = body`.
       Handles balanced parens (nested tuples in destructuring patterns).
       Fundef head + params may span multiple physical lines; comments (--...)
       on each continuation line are stripped during paren scanning."""
    def _strip_comment(s):
        j = s.find('--')
        return s if j < 0 else s[:j]
    lines = src.split('\n')
    out, i = [], 0
    while i < len(lines):
        line = lines[i]
        m    = _FUNDEF_HEAD_RE.match(line)
        if not m:
            out.append(line); i += 1; continue
        head, depth      = m.end(), 1
        joined           = _strip_comment(line)
        j, start_i       = head, i
        while True:
            while j < len(joined):
                c = joined[j]
                if   c == '(': depth += 1
                elif c == ')':
                    depth -= 1
                    if depth == 0: j += 1; break
                j += 1
            if depth == 0:                       break
            i += 1
            if i >= len(lines):                  joined = None; break
            joined = joined + ' ' + _strip_comment(lines[i])
        # Parens balanced. Keep pulling lines until we see a non-whitespace
        # character — to detect '=' that may live on a later line.
        while joined is not None and not joined[j:].strip():
            i += 1
            if i >= len(lines):                  break
            joined = joined + ' ' + _strip_comment(lines[i])
        if joined is not None and joined[j:].lstrip().startswith('='):
            eq_pos = joined.find('=', j)
            body   = joined[eq_pos+1:].strip()
            _FUNDEF_BODY[m.group(2)] = body
            ws = m.group(1)
            out.append(ws + 'def ' + joined[len(ws):])
        else:
            for k in range(start_i, min(i, len(lines) - 1) + 1):
                out.append(lines[k])
        i += 1
    return '\n'.join(out)

def run(source):
    """Preprocess, parse, desugar and evaluate Divacon source; return the final env."""
    tree = PARSER.parse(_preprocess_fundefs(source))
    for rewrite in (_process_fundefs, _process_lets):
        tree = rewrite(tree)
    interp = Interp()
    interp.transform(tree)
    return interp.env


# ----------------------- Demo -----------------------

if __name__ == "__main__":
    import sys, os
    arg = sys.argv[1] if len(sys.argv) > 1 else None
    if arg and os.path.isfile(arg):
        # python3 divacon.py prog.dc -- run a program file. Read it inside `with`
        # so the handle is closed before the (possibly long) run starts.
        with open(arg, encoding="utf-8") as fh:
            program = fh.read()
        run(program)
    else:
        # python3 divacon.py [N] -- demo: trace a PDC that reverses 1..N (default 8).
        N = int(arg) if arg else 8
        A = list(range(1, N + 1))
        run(f"""
        reverse = PDC(d_lr, c_lr, id, !other @ #!corr, atom, id)
        trace(reverse, {A})
        """)
