#!/usr/bin/python3
"""pydc - run a Divacon program, optionally with a Python helper library.

  usage:  pydc [--html] <program.dc> [library.py] [name=value ...]

  --html      replace trace() with trace_html() and print the HTML fragment to stdout.
  name=value  bind a variable that overrides any in-source `name = ...` assignment.
              Values are parsed as int, then float, then string (in that order).

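If a library is supplied (and is an existing .py file), it is executed and every
public (non-underscore) top-level name it defines or imports -- functions,
constants and modules alike -- is added to the Divacon STDLIB before the
program runs."""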

import os, sys, importlib.util

# Split the command line into three groups: the --html flag, `name=value`
# bindings, and positional files (the program, then an optional library).
argv = sys.argv[1:]
html = '--html' in argv
args = list(filter(lambda arg: arg != '--html', argv))


def _is_binding(arg):
    """Return True for a `name=value` word; a leading '=' never makes one."""
    return '=' in arg and not arg.startswith('=')


bindings = list(filter(_is_binding, args))
files = [arg for arg in args if not _is_binding(arg)]

# Exactly one program is required; at most one library may follow it.
if not 1 <= len(files) <= 2:
    print("usage: pydc [--html] <program.dc> [library.py] [name=value ...]", file=sys.stderr)
    sys.exit(2)

prog, *_rest = files
lib = _rest[0] if _rest else None

# Put this script's own directory at the front of sys.path so that
# `import divacon` resolves to the copy sitting beside pydc.
_script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path[0:0] = [_script_dir]
import divacon

# --html: programs calling trace() get divacon.trace_html instead.
if html:
    divacon.STDLIB['trace'] = divacon.trace_html

# Optional helper library: execute it as a module and expose every public
# (non-underscore) top-level name -- functions, constants, and anything the
# library itself imports -- to Divacon programs through divacon.STDLIB.
if lib is not None and not (lib.endswith('.py') and os.path.isfile(lib)):
    # Only existing .py files are loaded; say so instead of silently dropping
    # the argument, which used to hide typos in the library path.
    print(f"pydc: warning: ignoring library {lib!r} (not an existing .py file)",
          file=sys.stderr)
elif lib is not None:
    spec = importlib.util.spec_from_file_location("_pydc_user_lib", lib)
    mod = importlib.util.module_from_spec(spec)
    # Register before executing (as in the importlib "import a source file
    # directly" recipe) so library code that looks its own module up in
    # sys.modules -- e.g. pickling classes it defines -- works.
    sys.modules[spec.name] = mod
    spec.loader.exec_module(mod)
    for name, val in vars(mod).items():
        if not name.startswith('_'):
            divacon.STDLIB[name] = val
    # Close the handle promptly rather than leaking it until interpreter exit.
    with open(lib) as fh:
        divacon._LIB_SOURCE = fh.read()
    divacon._LIB_NAME = os.path.basename(lib)

def _coerce_binding_value(text):
    """Interpret a command-line binding value.

    Tries int first, then float; if neither parses, the raw string is kept.
    """
    for convert in (int, float):
        try:
            return convert(text)
        except ValueError:
            pass
    return text


# Each `name=value` word splits at its first '='; later duplicates win.
for binding in bindings:
    key, _, raw = binding.partition('=')
    divacon.OVERRIDES[key] = _coerce_binding_value(raw)

# Read the Divacon program text. A missing or unreadable file gets a short
# message instead of a raw traceback (exit status stays 1, as before).
try:
    with open(prog) as fh:
        src = fh.read()
except OSError as err:
    print(f"pydc: cannot read program {prog!r}: {err}", file=sys.stderr)
    sys.exit(1)
# divacon keeps its own copy of the source text (presumably for its
# diagnostics -- confirm in divacon.py).
divacon._SOURCE = src

import lark

# Terminal names lark puts in its "expected" sets, mapped to the literal text
# a user would type. Names not listed here are shown verbatim.
_NAME_PRETTY = {
    "_LIFT_PLUS": "[+]", "_LIFT_MINUS": "[-]",
    "_LIFT_TIMES": "[*]", "_LIFT_DIV":   "[/]",
    "LPAR": "(", "RPAR": ")", "LSQB": "[", "RSQB": "]",
    "PLUS": "+", "MINUS": "-", "STAR": "*", "SLASH": "/",
    "AT": "@", "COMMA": ",", "DOT": ".", "EQUAL": "=",
    "BANG": "!", "HASH": "#",
}


def _report_parse_error(err, text, path):
    """Print a caret-annotated description of a lark parse error to stderr.

    err  -- the lark.exceptions.UnexpectedInput raised while parsing.
    text -- the complete program source that was parsed.
    path -- the program's file name, used only in the header line.
    """
    lines = text.split("\n")
    ln = getattr(err, "line", None)
    col = getattr(err, "column", None)
    # lark's UnexpectedEOF sets line/column to -1; "line -1, column -1" is
    # meaningless to a user, so name the position instead.
    if isinstance(ln, int) and ln < 1:
        where = "end of input"
    else:
        where = f"line {ln}, column {col}"
    print(f"\n*** Parse error in {path} at {where} ***", file=sys.stderr)
    if ln and 1 <= ln <= len(lines):
        shown = lines[ln - 1].rstrip("\r")  # don't echo CR from CRLF files
        print(f"  {shown}", file=sys.stderr)
        if col:
            # Copy tabs from the source line so the caret lines up under the
            # offending column; pad with spaces past the end of the line.
            pad = "".join("\t" if ch == "\t" else " " for ch in shown[:col - 1])
            print("  " + pad.ljust(col - 1) + "^", file=sys.stderr)
    tok = getattr(err, "token", None)
    if tok is not None:
        print(f"  unexpected: {tok!r}", file=sys.stderr)
    # UnexpectedCharacters carries `allowed`; UnexpectedToken/EOF carry `expected`.
    allowed = getattr(err, "allowed", None) or getattr(err, "expected", None)
    if allowed:
        names = sorted({_NAME_PRETTY.get(str(a), str(a)) for a in allowed})
        print(f"  expected one of: {', '.join(names)}", file=sys.stderr)


try:
    divacon.run(src)
except lark.exceptions.UnexpectedInput as e:
    # Base class of UnexpectedToken, UnexpectedCharacters and UnexpectedEOF.
    _report_parse_error(e, src, prog)
    sys.exit(1)
