# JustUI/juic/eval.py
# 2024-09-21 17:20:35 +02:00
#
# 609 lines
# 24 KiB
# Python
from dataclasses import dataclass, field
from typing import Tuple, Union
import juic.parser
from juic.parser import Node
from juic.datatypes import *
# Storage of a scope's internal data, including all definitions in the scope.
# The data in this object is never copied, so for instance, the name -> value
# mappings are always unique and can only be found in the relevant scope's
# unique MutableScopeData instance. This data is, of course, not constant,
# which makes it unsuitable for building closures. The Closure type captures an
# immutable snapshot of a scope's contents at one point in time.
@dataclass
class MutableScopeData:
    # Definitions by name. Each entry is (timestamp, content): the timestamp
    # records when the definition was added, so that snapshots can ignore
    # definitions newer than themselves by comparing timestamps. The content
    # is either a thunk (evaluated or not) or a runtime value.
    defs: dict[str, Tuple[int, Union[JuiValue, "Thunk"]]] = \
        field(default_factory=dict)
    # Timestamp the next definition will receive; incremented on every
    # addition.
    timestamp: int = 0

    # Bind `name` to `content` in this scope, stamping it with the current
    # timestamp. Raises JuiRuntimeError if `name` is already bound here.
    def addDefinition(self, name: str, content: Union[JuiValue, "Thunk"]):
        if name in self.defs:
            raise JuiRuntimeError(f"{name} already defined in current scope")
        stamp = self.timestamp
        self.defs[name] = (stamp, content)
        self.timestamp = stamp + 1

    # Return the content bound to `name`, or None if it is unbound. When
    # `maxTimestamp` is non-negative, definitions stamped at or after it are
    # treated as unbound. Note that a name bound to the value None also
    # yields None.
    def lookup(self, name: str, maxTimestamp: int = -1) \
            -> Union[JuiValue, "Thunk", None]:
        entry = self.defs.get(name)
        if entry is None:
            return None
        stamp, content = entry
        if 0 <= maxTimestamp <= stamp:
            return None
        return content

    # Debugging aid: print the timestamp and every definition of the scope.
    def dump(self):
        width = len(str(self.timestamp))
        print(f"Scope: timestamp={self.timestamp}, defs:")
        for defName, (stamp, content) in self.defs.items():
            print(f" {stamp: >{width}} {defName}:", juiValueString(content))
# An immutable snapshot of all definitions that can be referred to at a given
# point in time. This specifies the interpretation of all defined identifiers
# but not of the "this" keyword.
@dataclass
class Closure:
    # Parent snapshot on which this one is based, which defines the contents of
    # surrounding scopes at that time.
    parent: Union["Closure", None]
    # Reference to the scope in which the closure was created.
    scope: MutableScopeData
    # Timestamp at which the closure was created. Definitions from `scope`
    # created later than the closure will be ignored during name lookup.
    timestamp: int = field(init=False)

    def __post_init__(self):
        # Automatically get the timestamp at the time of creation
        self.timestamp = self.scope.timestamp

    # Return the content of the innermost visible binding of `name`, or None
    # if no snapshot in the chain binds it. A binding whose value is None
    # (`null`) still shadows outer bindings: only names that are not visible
    # in a snapshot's scope are searched for in the parent snapshot.
    def lookup(self, name: str) -> Union[JuiValue, "Thunk", None]:
        closure = self
        while closure is not None:
            entry = closure.scope.defs.get(name)
            # Same visibility rule as MutableScopeData.lookup(): definitions
            # stamped at or after the snapshot's timestamp are ignored.
            if entry is not None and entry[0] < closure.timestamp:
                return entry[1]
            closure = closure.parent
        return None

    # Debugging aid: print this snapshot and all of its parents.
    def dump(self, level=0):
        print(f"Closure at depth {level}: timestamp={self.timestamp}:")
        self.scope.dump()
        if self.parent is not None:
            self.parent.dump(level+1)
# Thunk representing a suspended/lazy computation. Thunks are evaluated when
# needed (which is the call-by-need evaluation strategy). They are always
# interpreted in the context of the associated closure.
#
# While the object is technically mutable, the result of the computation
# (whether a value or a cyclic invalid result) is entirely determined by the
# expression AST and closure, so the result is just a lazily computed/memoized
# member. That said, the result might not exist if there is a cyclic dependency
# between this and other thunks referenced by the closure.
@dataclass
class Thunk:
    # Expression AST that the thunk evaluates to
    ast: Node
    # Closure in which the thunk is evaluated, plus the definition for
    # "this"; when present, it is always a PRS
    closure: Closure
    thisReference: Union[None, "PartialRecordSnapshot"] = None
    # Whether the thunk has been evaluated yet
    evaluated: bool = False
    # Whether evaluation produced an invalid result because of a cyclic
    # dependency with other thunks captured by the closure
    invalid: bool = False
    # Resulting value once evaluated (and not invalid). None is a legitimate
    # result (`null` in the language), so never use this field to decide
    # whether the thunk has been successfully evaluated.
    result: JuiValue = None
    # Whether the thunk is currently under evaluation
    _running: bool = False

    # Render as "Thunk(<state>){ <ast> }" for debugging.
    def __str__(self):
        if not self.evaluated:
            state = "unevaluated"
        elif self.invalid:
            state = "evaluated, invalid"
        else:
            state = "evaluated = " + juiValueString(self.result)
        if self._running:
            state = "running, " + state
        return "Thunk(" + state + "){ " + str(self.ast) + " }"
# Object representing the interpretation of a partially-built record. In a
# record construction or update operation, fields of `this` refer to the latest
# definition before the point of use, if there is one, or the definition in the
# underlying record otherwise, if there is one. Because the use of record
# constructor arguments can create forwards dependencies, all these definitions
# are thunked and computed lazily in a somewhat unpredictable order. A partial
# record snapshot fully specifies the interpretation of all fields of a
# partially-constructed record (i.e. `this`) at the location of one record
# entry. It does so by mapping defined field names to the correct thunks. This
# object is how dependencies between fields are discovered and followed.
@dataclass
class PartialRecordSnapshot:
    # Mapping from fields of `this` to thunks
    fieldThunks: dict[str, Thunk] = field(default_factory=dict)
    # Base object for fields not captured in `fieldThunks`
    base: None | Thunk = None

    def __str__(self):
        return "<PartialRecordSnapshot>"  # TODO: show the captured fields

    # Return a snapshot with the same thunks and base but its own field map,
    # so that adding fields to one does not affect the other.
    def copy(self):
        return PartialRecordSnapshot(fieldThunks=dict(self.fieldThunks),
                                     base=self.base)
# Common base class of every error raised while checking or evaluating a JUI
# program, so that callers can catch them all at once. Each subclass is still
# an Exception, so existing `except Exception` handlers keep working.
class JuiError(Exception):
    pass

# A value does not have the type required by an operation.
class JuiTypeError(JuiError):
    pass

# A name (or `this`) is not defined in the current context.
class JuiNameError(JuiError):
    pass

# An error detected when building definitions, before evaluation (e.g. a
# malformed parameter list).
class JuiStaticError(JuiError):
    pass

# An error detected during evaluation (redefinition, cyclic dependency,
# argument count mismatch, ...).
class JuiRuntimeError(JuiError):
    pass
# TODO: Better diagnostics for type errors (e.g. print values with
# juiValueString)
# Raise JuiTypeError unless `predicate(v)` holds. The predicate is reported by
# its name (e.g. "juiIsArith") rather than by its default repr, which would
# include a memory address.
def requireType(v: JuiValue, predicate: Callable[[JuiValue], bool]):
    if not predicate(v):
        needed = getattr(predicate, "__name__", predicate)
        raise JuiTypeError(f"type error: got {v}, needed {needed}")
# Raise JuiTypeError unless all values in `vs` have exactly the same Python
# type and that type satisfies `predicate` (only the first value is tested,
# since they all share a type). An empty list is accepted. As in
# requireType(), the predicate is reported by name.
def requireSameType(vs: list[JuiValue], predicate: Callable[[JuiValue], bool]):
    if len(vs) == 0:
        return
    if len(set(type(v) for v in vs)) != 1:
        vals = ", ".join(str(v) for v in vs)
        raise JuiTypeError(f"type error: heterogeneous types for {vals}")
    if not predicate(vs[0]):
        needed = getattr(predicate, "__name__", predicate)
        raise JuiTypeError(f"type error: got {vs[0]}, needed {needed}")
# Evaluation context. This tracks the current scope and provides evaluation
# functions, error tracking, etc.
class Context:
# Current top-level scope. Can be None. This is the only scope in which we
# can add definitions.
currentScope: MutableScopeData
# Current closure defining the scopes that surround `currentScope`.
currentClosure: Closure | None
def __init__(self, initialClosure = None):
self.currentScope = MutableScopeData()
self.currentClosure = initialClosure
self._contextStack = []
#=== Context switch commodity ===#
_contextStack: list[Tuple[MutableScopeData, Closure]]
class ContextSwitchContextManager:
def __init__(self, parent: "Context", c: Closure):
self.parent = parent
self.targetClosure = c
def __enter__(self):
p = self.parent
p._contextStack.append((p.currentScope, p.currentClosure))
p.currentScope = MutableScopeData()
p.currentClosure = self.targetClosure
def __exit__(self, exc_type, exc_value, traceback):
p = self.parent
s, c = p._contextStack.pop()
p.currentScope = s
p.currentClosure = c
# Temporarily switch the execution context. This is used when we need e.g.
# to evaluate a thunk in its associated closure. Morally one could
# construct a different Context object, but it's easier to accumulate error
# messages and other things in the same Context object.
# Example:
# with self.contextSwitchAndPushNewScope(myClosure):
# myResult = self.eval(myExprNode)
def contextSwitchAndPushNewScope(self, c: Closure):
return Context.ContextSwitchContextManager(self, c)
# Freeze the current scope and push a new empty one to start working in.
# Example:
# with self.pushNewScope():
# self.addDefinition(myParamName, myParamValue)
# myResult = self.eval(myFunctionBody)
def pushNewScope(self):
return self.contextSwitchAndPushNewScope(self.currentStateClosure())
#=== State management ===#
# Build a the closure of the current state
def currentStateClosure(self):
return Closure(parent=self.currentClosure, scope=self.currentScope)
# Add a new definition in the innermost scope
def addDefinition(self, name: str, content: JuiValue | Thunk):
return self.currentScope.addDefinition(name, content)
# Lookup a name
def lookup(self, name: str) -> JuiValue | Thunk | None:
return self.currentStateClosure().lookup(name)
#=== Evaluation helpers ===#
def makeFunction(self, params: list[Tuple[str, bool]], body: Node) \
-> Function:
# Check validity of the variadic parameter specification
variadics = [i for i, (_, v, _) in enumerate(params) if v]
for (name, _, _) in params:
count = len([0 for (n, _, _) in params if n == name])
if count > 1:
raise JuiStaticError(f"duplicate argument {name}")
if len(variadics) > 1:
names = ", ".join(params[i][0] for i in variadics)
raise JuiStaticError(f"multiple variadic arguments: {names}")
if variadics and variadics != [len(params) - 1]:
name = params[variadics[0]][0]
raise JuiStaticError(f"variadic argument {name} not at the end")
variadic = params[variadics[0]][0] if variadics else None
# Build function object
# TODO: Make a full scope not just an expr
return Function(closure = self.currentStateClosure(),
body = Node(Node.T.SCOPE_EXPR, [body]),
params = [name for (name, v, _) in params if not v],
variadic = variadic)
def makeRecordCtor(self, params: list[Tuple[str, bool]], body: Node) \
-> RecordCtor:
return RecordCtor(func=self.makeFunction(params, body))
#=== Main evaluation functions ===#
# Expression evaluator; this function is pure. It can cause thunks to be
# evaluated and thus reveal previously-undiscovered cyclic dependencies,
# but it doesn't modify the context.
def evalExpr(self, node: Node) -> JuiValue:
match node.ctor, node.args:
case Node.T.LIT, [x]:
return x
case Node.T.IDENT, [name]:
match self.lookup(name):
case None:
raise JuiNameError(f"name {name} not defined")
case Thunk() as t:
return self.evalThunk(t)
case value:
return value
case Node.T.OP, ["+", X, Y]:
x = self.evalExpr(X)
y = self.evalExpr(Y)
requireSameType([x, y], juiIsAddable)
return x + y
# TODO: Arithmetic: integer division, type conversions
case Node.T.OP, [("+" | "-" | "*" | "/" | "%") as op, X, Y]:
x = self.evalExpr(X)
y = self.evalExpr(Y)
requireSameType([x, y], juiIsArith)
match op:
case "+": return x + y
case "-": return x - y
case "*": return x * y
case "/":
match x, y:
case float(), float(): return (x / y)
case _, _: return (x // y)
case "%": return x % y
case Node.T.OP, [("!" | "+" | "-") as op, X]:
x = self.evalExpr(X)
requireType(x, juiIsArith)
match op:
case "!": return not x
case "+": return +x
case "-": return -x
case Node.T.OP, [(">"|">="|"<"|"<="|"=="|"!=") as op, X, Y]:
x = self.evalExpr(X)
y = self.evalExpr(Y)
requireSameType([x, y], juiIsComparable)
match op:
case ">": return x > y
case "<": return x < y
case ">=": return x >= y
case "<=": return x <= y
case "==": return x == y
case "!=": return x != y
case Node.T.OP, [("&&" | "||") as op, X, Y]:
x = self.evalExpr(X)
y = self.evalExpr(Y)
requireSameType([x, y], juiIsLogical)
match op:
case "&&": return x and y
case "||": return x or y
case Node.T.OP, ["...", X]:
x = self.evalExpr(X)
requireType(x, juiIsUnpackable)
raise NotImplementedError("unpack operator o(x_x)o")
case Node.T.OP, [("|"|"|>"), X, F]:
f = self.evalExpr(F)
requireType(f, juiIsCallable)
return self.evalCall(f, [X])
case Node.T.OP, ["<|", F, X]:
f = self.evalExpr(F)
requireType(f, juiIsCallable)
return self.evalCall(f, [X])
case Node.T.THIS, []:
v = self.lookup("this")
if v is None:
raise JuiNameError(f"no 'this' in current context")
assert isinstance(v, PartialRecordSnapshot)
return v
case Node.T.PROJ, [R, field]:
r = self.evalExpr(R)
requireType(r, juiIsProjectable)
f = self.project(r, field)
if isinstance(f, Thunk):
return self.evalThunk(f)
else:
return f
case Node.T.CALL, [F, *A]:
f = self.evalExpr(F)
requireType(f, juiIsCallable)
return self.evalCall(f, A)
case Node.T.IF, [C, T, E]:
c = self.evalExpr(C)
requireType(c, juiIsLogical)
if bool(c):
return self.evalExpr(T)
else:
return None if E is None else self.evalExpr(E)
case Node.T.RECORD, [R, *A]:
r = self.evalExpr(R)
requireType(r, juiIsConstructible)
return self.evalRecordConstructor(r, A)
case Node.T.REC_ATTR, _:
raise NotImplementedError
case Node.T.REC_VALUE, args:
raise NotImplementedError
raise Exception("invalid expr o(x_x)o: " + str(node))
def project(self, v: JuiValue, field: str) -> JuiValue:
match v:
case Record() as r:
if field not in r.attr:
raise Exception(f"access to undefined field {field}")
return self.evalValueOrThunk(r.attr[field])
case PartialRecordSnapshot() as prs:
if field in prs.fieldThunks:
return self.evalThunk(prs.fieldThunks[field])
elif prs.base is None:
raise Exception(f"access to undefined field {field} of 'this'")
else:
return self.project(self.evalThunk(prs.base), field)
case _:
raise NotImplementedError # unreachable
def execStmt(self, node: Node) -> JuiValue:
match node.ctor, node.args:
case Node.T.LET_DECL, [name, X, let_type]:
self.addDefinition(name, self.evalExpr(X))
# TODO: Check type?
return None
case Node.T.FUN_DECL, [name, params, body, body_type]:
self.addDefinition(name, self.makeFunction(params, body))
# TODO: Check type when called?
return None
case Node.T.REC_DECL, [name, params, body, body_type]:
self.addDefinition(name, self.makeRecordCtor(params, body))
return None
case Node.T.SET_STMT, _:
raise NotImplementedError
case Node.T.UNIT_TEST, [subject, expected]:
vs = self.evalExpr(subject)
ve = self.evalExpr(expected)
vs = self.force(vs)
ve = self.force(ve)
if not juiValuesEqual(vs, ve):
print("unit test failed:")
print(" " + juiValueString(vs))
print("vs.")
print(" " + juiValueString(ve))
return None
case Node.T.SCOPE_EXPR, [e]:
return self.evalExpr(e)
raise Exception(f"execStmt: unrecognized node {node.ctor.name}")
# TODO: Context.eval*: continue failed computations to find other errors?
def evalThunk(self, th: Thunk) -> JuiValue:
if th.evaluated:
if th.invalid:
raise JuiRuntimeError("cyclic dependency result encountered")
return th.result
if th._running:
raise JuiRuntimeError("cyclic dependency detected!")
with self.contextSwitchAndPushNewScope(th.closure):
if th.thisReference is not None:
self.addDefinition("this", th.thisReference)
th._running = True
result = self.evalExpr(th.ast)
th._running = False
# TODO: Set invalid = True if we continue failed computations
th.evaluated = True
th.result = result
return result
def evalValueOrThunk(self, v: JuiValue | Thunk) -> JuiValue:
return self.evalThunk(v) if isinstance(v, Thunk) else v
def evalCall(self, f: JuiValue, args: list[Node]) -> JuiValue:
# Built-in functions: just evaluate arguments and go
# TODO: Check types of built-in function calls
if type(f) == BuiltinFunction:
return f.func(*[self.evalExpr(a) for a in args])
assert type(f) == Function and "evalCall: bad type check precondition"
# Check number of arguments
req = str(len(f.params)) + ("+" if f.variadic is not None else "")
if len(args) < len(f.params):
raise JuiRuntimeError(f"not enough args (need {req}, got {len(args)})")
if len(args) > len(f.params) and f.variadic is None:
raise JuiRuntimeError(f"too many args (need {req}, got {len(args)})")
# TODO: In order to build variadic set I need a LIST node
if f.variadic is not None:
raise NotImplementedError("list node for building varargs o(x_x)o")
# Run into the function's scope
with self.contextSwitchAndPushNewScope(f.closure):
for name, node in zip(f.params, args):
th = Thunk(ast=node, closure=self.currentStateClosure())
self.addDefinition(name, th)
# self.currentScope.dump()
# self.currentClosure.dump()
assert f.body.ctor == Node.T.SCOPE_EXPR
return self.execStmt(f.body)
def evalRecordConstructor(self, ctor: JuiValue, entries: list[Node]) \
-> JuiValue:
# Base record constructor: starts out with an empty record. All
# arguments are children. Easy.
if isinstance(ctor, RecordType):
r = Record(base=ctor, attr=dict(), children=[])
# Create thunks for all entries while providing them with
# progressively more complete PRS of r.
prs = PartialRecordSnapshot()
for i, e in enumerate(entries):
if e.ctor == Node.T.REC_ATTR:
name, label, node = e.args
elif e.ctor == Node.T.REC_VALUE:
name, label, node = None, None, e.args[0]
th = Thunk(ast=node, closure=self.currentStateClosure())
th.thisReference = prs.copy()
if name is not None:
r.attr[name] = th
prs.fieldThunks[name] = th
else:
r.children.append(th)
return r
# NOTE: NO WAY TO SPECIFY AN ATTRIBUTE IN A NON-STATIC WAY.
# Create thunks for all entries that have everything but the "this".
entry_thunks = []
for e in entries:
if e.ctor == Node.T.REC_ATTR:
name, label, node = e.args
elif e.ctor == Node.T.REC_VALUE:
name, label, node = None, None, e.args[0]
th = Thunk(ast=node, closure=self.currentStateClosure())
entry_thunks.append(th)
# Collect arguments to the constructor and build a thunk the call.
args = [entry_thunks[i]
for i, e in enumerate(entries) if e.ctor == Node.T.REC_VALUE]
# TODO: Merge with an internal version of evalCall()
#---
assert isinstance(ctor, RecordCtor)
f = ctor.func
# Check number of arguments
req = str(len(f.params)) + ("+" if f.variadic is not None else "")
if len(args) < len(f.params):
raise JuiRuntimeError(f"not enough args (need {req}, got {len(args)})")
if len(args) > len(f.params) and f.variadic is None:
raise JuiRuntimeError(f"too many args (need {req}, got {len(args)})")
# TODO: In order to build variadic set I need a LIST node
if f.variadic is not None:
raise NotImplementedError("list node for building varargs o(x_x)o")
# Run into the function's scope to build the thunk
with self.contextSwitchAndPushNewScope(f.closure):
for name, th in zip(f.params, args):
self.addDefinition(name, th)
assert f.body.ctor == Node.T.SCOPE_EXPR
call_thunk = Thunk(ast=f.body.args[0],
closure=self.currentStateClosure())
#---
# Use the call as base for a PRS and assign "this" in all thunks.
prs = PartialRecordSnapshot()
prs.base = call_thunk
for i, e in enumerate(entries):
entry_thunks[i].thisReference = prs.copy()
if e.ctor == Node.T.REC_ATTR:
name, label, node = e.args
prs.fieldThunks[name] = th
baseRecord = self.evalThunk(call_thunk)
if not isinstance(baseRecord, Record):
raise Exception("record ctor did not return a record")
for i, e in enumerate(entries):
if e.ctor == Node.T.REC_ATTR:
name, label, node = e.args
baseRecord.attr[name] = entry_thunks[i]
return baseRecord
def force(self, v: JuiValue | Thunk) -> JuiValue:
match v:
case Record() as r:
for a in r.attr:
r.attr[a] = self.force(r.attr[a])
for i, e in enumerate(r.children):
r.children[i] = self.force(e)
return r
case Thunk() as th:
self.evalThunk(th)
if th.evaluated:
return self.force(th.result)
return th.result
case _:
return v