from dataclasses import dataclass, field
from typing import Callable, Tuple, Union

import juic.parser
from juic.parser import Node
from juic.datatypes import *


# Storage of a scope's internal data, including all definitions in the scope.
# The data in this object is never copied, so for instance, the name -> value
# mappings are always unique and can only be found in the relevant scope's
# unique MutableScopeData instance. This data is, of course, not constant,
# which makes it unsuitable for building closures. The Closure type captures an
# immutable snapshot of a scope's contents at one point in time.
@dataclass
class MutableScopeData:
    # List of definitions. Each value has a timestamp integer that indicates
    # when it was added; this allows snapshots to ignore definitions newer than
    # the snapshot by simply comparing timestamp values. The value itself is
    # either a thunk (evaluated or not), or a runtime value.
    defs: dict[str, Tuple[int, Union[JuiValue, "Thunk"]]] = \
        field(default_factory=dict)
    # Current timestamp, changed every time a value is added. It is always the
    # timestamp that the next definition will receive.
    timestamp: int = 0

    # Bind `name` to `content` with the current timestamp, then advance the
    # timestamp. Raises JuiRuntimeError if `name` is already defined in this
    # scope (redefinition within one scope is not allowed).
    def addDefinition(self, name: str, content: Union[JuiValue, "Thunk"]):
        if name in self.defs:
            raise JuiRuntimeError(f"{name} already defined in current scope")
        self.defs[name] = (self.timestamp, content)
        self.timestamp += 1

    # Return the content bound to `name`, or None if it is not defined. When
    # `maxTimestamp` is non-negative, only definitions strictly older than it
    # are visible; a negative value (the default) disables the restriction.
    # NOTE: a definition whose content is None (`null`) cannot be told apart
    # from a missing one through this method.
    def lookup(self, name: str, maxTimestamp: int = -1) \
            -> Union[JuiValue, "Thunk", None]:
        entry = self.defs.get(name)
        if entry is None:
            return None
        ts, content = entry
        if 0 <= maxTimestamp <= ts:
            # Defined after the snapshot that is asking: not visible.
            return None
        return content

    # Print every definition along with its timestamp (debugging aid).
    def dump(self):
        digits = len(str(self.timestamp))
        print(f"Scope: timestamp={self.timestamp}, defs:")
        for name, (ts, content) in self.defs.items():
            print(f"  {ts: >{digits}} {name}:", juiValueString(content))


# An immutable snapshot of all definitions that can be referred to at a given
# point in time. This specifies the interpretation of all defined identifiers
# but not of the "this" keyword.
@dataclass
class Closure:
    # Parent snapshot on which this one is based, which defines the contents of
    # surrounding scopes at that time.
    parent: Union["Closure", None]
    # Reference to the scope in which the closure was created.
    scope: MutableScopeData
    # Timestamp at which the closure was created. Definitions from `scope`
    # created later than the closure will be ignored during name lookup.
    timestamp: int = field(init=False)

    def __post_init__(self):
        # Automatically get the timestamp at the time of creation
        self.timestamp = self.scope.timestamp

    # Resolve `name` by walking from this snapshot outwards through its
    # parents. Returns the bound value or thunk, or None if no snapshot in the
    # chain has a visible definition for it.
    # The innermost visible definition always wins, even when its value is
    # None (`null`): it shadows outer bindings instead of falling through to
    # them.
    def lookup(self, name: str) -> Union[JuiValue, "Thunk", None]:
        closure = self
        while closure is not None:
            entry = closure.scope.defs.get(name)
            # Only entries strictly older than the snapshot are visible (the
            # same rule MutableScopeData.lookup applies with maxTimestamp).
            if entry is not None and entry[0] < closure.timestamp:
                return entry[1]
            closure = closure.parent
        return None

    # Print this snapshot's scope and then, recursively, its parents'.
    def dump(self, level=0):
        print(f"Closure at depth {level}: timestamp={self.timestamp}:")
        self.scope.dump()
        if self.parent is not None:
            self.parent.dump(level+1)


# Thunk representing a suspended/lazy computation. Thunks are evaluated when
# needed (which is the call-by-need evaluation strategy). They are always
# interpreted in the context of the associated closure.
#
# While the object is technically mutable, the result of the computation
# (whether a value or a cyclic invalid result) is entirely determined by the
# expression AST and closure, so the result is just a lazily computed/memoized
# member. That said, the result might not exist if there is a cyclic dependency
# between this and other thunks referenced by the closure.
@dataclass
class Thunk:
    # AST of the suspended expression
    ast: Node
    # Closure the expression is evaluated in, together with the binding to
    # install for "this" (when present, always a PartialRecordSnapshot)
    closure: Closure
    thisReference: Union[None, "PartialRecordSnapshot"] = None
    # Set once evaluation has finished
    evaluated: bool = False
    # Set when evaluation finished with an invalid result caused by a cyclic
    # dependency with other thunks captured by the closure
    invalid: bool = False
    # Memoized result; meaningful only when `evaluated` is set and `invalid`
    # is not. None is a legitimate result (`null` in the language), so never
    # inspect this field to decide whether the thunk has been evaluated.
    result: JuiValue = None
    # True while the thunk is under evaluation (lets re-entry be detected)
    _running: bool = False

    def __str__(self):
        pieces = []
        if self._running:
            pieces.append("running, ")
        if not self.evaluated:
            pieces.append("unevaluated")
        elif self.invalid:
            pieces.append("evaluated, invalid")
        else:
            pieces.append("evaluated = " + juiValueString(self.result))
        return "Thunk(" + "".join(pieces) + "){ " + str(self.ast) + " }"


# Object representing the interpretation of a partially-built record. In a
# record construction or update operation, fields of `this` refer to the latest
# definition before the point of use, if there is one, or the definition in the
# underlying record otherwise, if there is one. Because the use of record
# constructor arguments can create forwards dependencies, all these definitions
# are thunked and computed lazily in a somewhat unpredictable order. A partial
# record snapshot fully specifies the interpretation of all fields of a
# partially-constructed record (i.e. `this`) at the location of one record
# entry. It does so by mapping defined field names to the correct thunks. This
# object is how dependencies between fields are discovered and followed.
@dataclass class PartialRecordSnapshot: # Mapping from fields of `this` to thunks fieldThunks: dict[str, Thunk] = field(default_factory=dict) # Base object for fields not captured in `fieldThunks` base: None | Thunk = None def __str__(self): return "" # TODO def copy(self): return PartialRecordSnapshot(self.fieldThunks.copy(), self.base) class JuiTypeError(Exception): pass class JuiNameError(Exception): pass class JuiStaticError(Exception): pass class JuiRuntimeError(Exception): pass # TODO: Better diagnostics for type errors def requireType(v: JuiValue, predicate: Callable[[JuiValue], bool]): if not predicate(v): raise JuiTypeError(f"type error: got {v}, needed {predicate}") def requireSameType(vs: list[JuiValue], predicate: Callable[[JuiValue], bool]): if len(vs) == 0: return if len(set(type(v) for v in vs)) != 1: vals = ", ".join(str(v) for v in vs) raise JuiTypeError(f"type error: heterogeneous types for {vals}") if not predicate(vs[0]): raise JuiTypeError(f"type error: got {vs[0]}, needed {predicate}") # Evaluation context. This tracks the current scope and provides evaluation # functions, error tracking, etc. class Context: # Current top-level scope. Can be None. This is the only scope in which we # can add definitions. currentScope: MutableScopeData # Current closure defining the scopes that surround `currentScope`. 
currentClosure: Closure | None def __init__(self, initialClosure = None): self.currentScope = MutableScopeData() self.currentClosure = initialClosure self._contextStack = [] #=== Context switch commodity ===# _contextStack: list[Tuple[MutableScopeData, Closure]] class ContextSwitchContextManager: def __init__(self, parent: "Context", c: Closure): self.parent = parent self.targetClosure = c def __enter__(self): p = self.parent p._contextStack.append((p.currentScope, p.currentClosure)) p.currentScope = MutableScopeData() p.currentClosure = self.targetClosure def __exit__(self, exc_type, exc_value, traceback): p = self.parent s, c = p._contextStack.pop() p.currentScope = s p.currentClosure = c # Temporarily switch the execution context. This is used when we need e.g. # to evaluate a thunk in its associated closure. Morally one could # construct a different Context object, but it's easier to accumulate error # messages and other things in the same Context object. # Example: # with self.contextSwitchAndPushNewScope(myClosure): # myResult = self.eval(myExprNode) def contextSwitchAndPushNewScope(self, c: Closure): return Context.ContextSwitchContextManager(self, c) # Freeze the current scope and push a new empty one to start working in. 
# Example: # with self.pushNewScope(): # self.addDefinition(myParamName, myParamValue) # myResult = self.eval(myFunctionBody) def pushNewScope(self): return self.contextSwitchAndPushNewScope(self.currentStateClosure()) #=== State management ===# # Build a the closure of the current state def currentStateClosure(self): return Closure(parent=self.currentClosure, scope=self.currentScope) # Add a new definition in the innermost scope def addDefinition(self, name: str, content: JuiValue | Thunk): return self.currentScope.addDefinition(name, content) # Lookup a name def lookup(self, name: str) -> JuiValue | Thunk | None: return self.currentStateClosure().lookup(name) #=== Evaluation helpers ===# def makeFunction(self, params: list[Tuple[str, bool]], body: Node) \ -> Function: # Check validity of the variadic parameter specification variadics = [i for i, (_, v, _) in enumerate(params) if v] for (name, _, _) in params: count = len([0 for (n, _, _) in params if n == name]) if count > 1: raise JuiStaticError(f"duplicate argument {name}") if len(variadics) > 1: names = ", ".join(params[i][0] for i in variadics) raise JuiStaticError(f"multiple variadic arguments: {names}") if variadics and variadics != [len(params) - 1]: name = params[variadics[0]][0] raise JuiStaticError(f"variadic argument {name} not at the end") variadic = params[variadics[0]][0] if variadics else None # Build function object # TODO: Make a full scope not just an expr return Function(closure = self.currentStateClosure(), body = Node(Node.T.SCOPE_EXPR, [body]), params = [name for (name, v, _) in params if not v], variadic = variadic) def makeRecordCtor(self, params: list[Tuple[str, bool]], body: Node) \ -> RecordCtor: return RecordCtor(func=self.makeFunction(params, body)) #=== Main evaluation functions ===# # Expression evaluator; this function is pure. It can cause thunks to be # evaluated and thus reveal previously-undiscovered cyclic dependencies, # but it doesn't modify the context. 
def evalExpr(self, node: Node) -> JuiValue: match node.ctor, node.args: case Node.T.LIT, [x]: return x case Node.T.IDENT, [name]: match self.lookup(name): case None: raise JuiNameError(f"name {name} not defined") case Thunk() as t: return self.evalThunk(t) case value: return value case Node.T.OP, ["+", X, Y]: x = self.evalExpr(X) y = self.evalExpr(Y) requireSameType([x, y], juiIsAddable) return x + y # TODO: Arithmetic: integer division, type conversions case Node.T.OP, [("+" | "-" | "*" | "/" | "%") as op, X, Y]: x = self.evalExpr(X) y = self.evalExpr(Y) requireSameType([x, y], juiIsArith) match op: case "+": return x + y case "-": return x - y case "*": return x * y case "/": match x, y: case float(), float(): return (x / y) case _, _: return (x // y) case "%": return x % y case Node.T.OP, [("!" | "+" | "-") as op, X]: x = self.evalExpr(X) requireType(x, juiIsArith) match op: case "!": return not x case "+": return +x case "-": return -x case Node.T.OP, [(">"|">="|"<"|"<="|"=="|"!=") as op, X, Y]: x = self.evalExpr(X) y = self.evalExpr(Y) requireSameType([x, y], juiIsComparable) match op: case ">": return x > y case "<": return x < y case ">=": return x >= y case "<=": return x <= y case "==": return x == y case "!=": return x != y case Node.T.OP, [("&&" | "||") as op, X, Y]: x = self.evalExpr(X) y = self.evalExpr(Y) requireSameType([x, y], juiIsLogical) match op: case "&&": return x and y case "||": return x or y case Node.T.OP, ["...", X]: x = self.evalExpr(X) requireType(x, juiIsUnpackable) raise NotImplementedError("unpack operator o(x_x)o") case Node.T.OP, [("|"|"|>"), X, F]: f = self.evalExpr(F) requireType(f, juiIsCallable) return self.evalCall(f, [X]) case Node.T.OP, ["<|", F, X]: f = self.evalExpr(F) requireType(f, juiIsCallable) return self.evalCall(f, [X]) case Node.T.THIS, []: v = self.lookup("this") if v is None: raise JuiNameError(f"no 'this' in current context") assert isinstance(v, PartialRecordSnapshot) return v case Node.T.PROJ, [R, field]: r = 
self.evalExpr(R) requireType(r, juiIsProjectable) f = self.project(r, field) if isinstance(f, Thunk): return self.evalThunk(f) else: return f case Node.T.CALL, [F, *A]: f = self.evalExpr(F) requireType(f, juiIsCallable) return self.evalCall(f, A) case Node.T.IF, [C, T, E]: c = self.evalExpr(C) requireType(c, juiIsLogical) if bool(c): return self.evalExpr(T) else: return None if E is None else self.evalExpr(E) case Node.T.RECORD, [R, *A]: r = self.evalExpr(R) requireType(r, juiIsConstructible) return self.evalRecordConstructor(r, A) case Node.T.REC_ATTR, _: raise NotImplementedError case Node.T.REC_VALUE, args: raise NotImplementedError raise Exception("invalid expr o(x_x)o: " + str(node)) def project(self, v: JuiValue, field: str) -> JuiValue: match v: case Record() as r: if field not in r.attr: raise Exception(f"access to undefined field {field}") return self.evalValueOrThunk(r.attr[field]) case PartialRecordSnapshot() as prs: if field in prs.fieldThunks: return self.evalThunk(prs.fieldThunks[field]) elif prs.base is None: raise Exception(f"access to undefined field {field} of 'this'") else: return self.project(self.evalThunk(prs.base), field) case _: raise NotImplementedError # unreachable def execStmt(self, node: Node) -> JuiValue: match node.ctor, node.args: case Node.T.LET_DECL, [name, X, let_type]: self.addDefinition(name, self.evalExpr(X)) # TODO: Check type? return None case Node.T.FUN_DECL, [name, params, body, body_type]: self.addDefinition(name, self.makeFunction(params, body)) # TODO: Check type when called? 
return None case Node.T.REC_DECL, [name, params, body, body_type]: self.addDefinition(name, self.makeRecordCtor(params, body)) return None case Node.T.SET_STMT, _: raise NotImplementedError case Node.T.UNIT_TEST, [subject, expected]: vs = self.evalExpr(subject) ve = self.evalExpr(expected) vs = self.force(vs) ve = self.force(ve) if not juiValuesEqual(vs, ve): print("unit test failed:") print(" " + juiValueString(vs)) print("vs.") print(" " + juiValueString(ve)) return None case Node.T.SCOPE_EXPR, [e]: return self.evalExpr(e) raise Exception(f"execStmt: unrecognized node {node.ctor.name}") # TODO: Context.eval*: continue failed computations to find other errors? def evalThunk(self, th: Thunk) -> JuiValue: if th.evaluated: if th.invalid: raise JuiRuntimeError("cyclic dependency result encountered") return th.result if th._running: raise JuiRuntimeError("cyclic dependency detected!") with self.contextSwitchAndPushNewScope(th.closure): if th.thisReference is not None: self.addDefinition("this", th.thisReference) th._running = True result = self.evalExpr(th.ast) th._running = False # TODO: Set invalid = True if we continue failed computations th.evaluated = True th.result = result return result def evalValueOrThunk(self, v: JuiValue | Thunk) -> JuiValue: return self.evalThunk(v) if isinstance(v, Thunk) else v def evalCall(self, f: JuiValue, args: list[Node]) -> JuiValue: # Built-in functions: just evaluate arguments and go # TODO: Check types of built-in function calls if type(f) == BuiltinFunction: return f.func(*[self.evalExpr(a) for a in args]) assert type(f) == Function and "evalCall: bad type check precondition" # Check number of arguments req = str(len(f.params)) + ("+" if f.variadic is not None else "") if len(args) < len(f.params): raise JuiRuntimeError(f"not enough args (need {req}, got {len(args)})") if len(args) > len(f.params) and f.variadic is None: raise JuiRuntimeError(f"too many args (need {req}, got {len(args)})") # TODO: In order to build variadic set 
I need a LIST node if f.variadic is not None: raise NotImplementedError("list node for building varargs o(x_x)o") # Run into the function's scope with self.contextSwitchAndPushNewScope(f.closure): for name, node in zip(f.params, args): th = Thunk(ast=node, closure=self.currentStateClosure()) self.addDefinition(name, th) # self.currentScope.dump() # self.currentClosure.dump() assert f.body.ctor == Node.T.SCOPE_EXPR return self.execStmt(f.body) def evalRecordConstructor(self, ctor: JuiValue, entries: list[Node]) \ -> JuiValue: # Base record constructor: starts out with an empty record. All # arguments are children. Easy. if isinstance(ctor, RecordType): r = Record(base=ctor, attr=dict(), children=[]) # Create thunks for all entries while providing them with # progressively more complete PRS of r. prs = PartialRecordSnapshot() for i, e in enumerate(entries): if e.ctor == Node.T.REC_ATTR: name, label, node = e.args elif e.ctor == Node.T.REC_VALUE: name, label, node = None, None, e.args[0] th = Thunk(ast=node, closure=self.currentStateClosure()) th.thisReference = prs.copy() if name is not None: r.attr[name] = th prs.fieldThunks[name] = th else: r.children.append(th) return r # NOTE: NO WAY TO SPECIFY AN ATTRIBUTE IN A NON-STATIC WAY. # Create thunks for all entries that have everything but the "this". entry_thunks = [] for e in entries: if e.ctor == Node.T.REC_ATTR: name, label, node = e.args elif e.ctor == Node.T.REC_VALUE: name, label, node = None, None, e.args[0] th = Thunk(ast=node, closure=self.currentStateClosure()) entry_thunks.append(th) # Collect arguments to the constructor and build a thunk the call. 
args = [entry_thunks[i] for i, e in enumerate(entries) if e.ctor == Node.T.REC_VALUE] # TODO: Merge with an internal version of evalCall() #--- assert isinstance(ctor, RecordCtor) f = ctor.func # Check number of arguments req = str(len(f.params)) + ("+" if f.variadic is not None else "") if len(args) < len(f.params): raise JuiRuntimeError(f"not enough args (need {req}, got {len(args)})") if len(args) > len(f.params) and f.variadic is None: raise JuiRuntimeError(f"too many args (need {req}, got {len(args)})") # TODO: In order to build variadic set I need a LIST node if f.variadic is not None: raise NotImplementedError("list node for building varargs o(x_x)o") # Run into the function's scope to build the thunk with self.contextSwitchAndPushNewScope(f.closure): for name, th in zip(f.params, args): self.addDefinition(name, th) assert f.body.ctor == Node.T.SCOPE_EXPR call_thunk = Thunk(ast=f.body.args[0], closure=self.currentStateClosure()) #--- # Use the call as base for a PRS and assign "this" in all thunks. prs = PartialRecordSnapshot() prs.base = call_thunk for i, e in enumerate(entries): entry_thunks[i].thisReference = prs.copy() if e.ctor == Node.T.REC_ATTR: name, label, node = e.args prs.fieldThunks[name] = th baseRecord = self.evalThunk(call_thunk) if not isinstance(baseRecord, Record): raise Exception("record ctor did not return a record") for i, e in enumerate(entries): if e.ctor == Node.T.REC_ATTR: name, label, node = e.args baseRecord.attr[name] = entry_thunks[i] return baseRecord def force(self, v: JuiValue | Thunk) -> JuiValue: match v: case Record() as r: for a in r.attr: r.attr[a] = self.force(r.attr[a]) for i, e in enumerate(r.children): r.children[i] = self.force(e) return r case Thunk() as th: self.evalThunk(th) if th.evaluated: return self.force(th.result) return th.result case _: return v