juic: fix all errors raised by mypy

This commit is contained in:
Lephenixnoir 2024-08-28 13:58:27 +02:00
parent 4579acc0f4
commit ce39929bb4
No known key found for this signature in database
GPG key ID: 1BBA026E13FC0495
3 changed files with 30 additions and 19 deletions

View file

@ -33,7 +33,7 @@ class Function:
# Expression node to evaluate when calling the function
body: "juic.parser.Node"
# Parameter names, must all be unique. May be empty
params: list[str] = field(default_factory=[])
params: list[str] = field(default_factory=list)
# Name of variadic argument if one; must also be unique
variadic: str | None = None

View file

@ -253,7 +253,7 @@ class Context:
raise JuiStaticError(f"multiple variadic arguments: {names}")
if variadics and variadics != [len(params) - 1]:
name = params[variadic][0]
name = params[variadics[0]][0]
raise JuiStaticError(f"variadic argument {name} not at the end")
variadic = params[variadics[0]][0] if variadics else None
@ -297,7 +297,10 @@ class Context:
case "+": return x + y
case "-": return x - y
case "*": return x * y
case "/": return x / y if type(x) == float else x // y
case "/":
match x, y:
case float(), float(): return (x / y)
case _, _: return (x // y)
case "%": return x % y
case Node.T.OP, [("!" | "+" | "-") as op, X]:
@ -384,22 +387,21 @@ class Context:
case Node.T.REC_VALUE, _:
raise NotImplementedError
case _, _:
raise Exception("invalid expr o(x_x)o: " + str(node))
raise Exception("invalid expr o(x_x)o: " + str(node))
def project(self, v: JuiValue, field: str) -> JuiValue:
match v:
case Record() as r:
if field not in r.attr:
raise Exception(f"access to undefined field {field}")
return r.attr[field]
return self.evalThunk(r.attr[field])
case PartialRecordSnapshot() as prs:
if field in prs.fieldThunks:
return prs.fieldThunks[field]
return self.evalThunk(prs.fieldThunks[field])
elif prs.base is None:
raise Exception(f"access to undefined field {field} of 'this'")
else:
return self.project(prs.base, v)
return self.project(self.evalThunk(prs.base), field)
case _:
raise NotImplementedError # unreachable
@ -407,9 +409,11 @@ class Context:
match node.ctor, node.args:
case Node.T.LET_DECL, [name, X]:
self.addDefinition(name, self.evalExpr(X))
return None
case Node.T.FUN_DECL, [name, params, body]:
self.addDefinition(name, self.makeFunction(params, body))
return None
case Node.T.REC_DECL, _:
raise NotImplementedError
@ -427,9 +431,9 @@ class Context:
print(" " + juiValueString(vs))
print("vs.")
print(" " + juiValueString(ve))
return None
case _, _:
return self.evalExpr(node)
return self.evalExpr(node)
# TODO: Context.eval*: continue failed computations to find other errors?
def evalThunk(self, th: Thunk) -> JuiValue:
@ -498,7 +502,7 @@ class Context:
# Base record constructor: starts out with an empty record. All
# arguments are children. Easy.
if type(ctor) == RecordType:
r = Record(base=ctor, attr=dict(), children=dict())
r = Record(base=ctor, attr=dict(), children=[])
if len(args) > 0:
raise JuiRuntimeError(f"arguments given to type rec ctor")
@ -517,6 +521,8 @@ class Context:
return r
assert isinstance(ctor, Function)
# TODO: Factor this with function. In fact, this should reduce to a call
# Check number of arguments
req = str(len(ctor.params)) + ("+" if ctor.variadic is not None else "")
@ -603,16 +609,16 @@ class Context:
# How about that's what Record does, and it just has accessors?
raise NotImplementedError
def force(self, v: JuiValue) -> JuiValue:
def force(self, v: JuiValue | Thunk) -> JuiValue:
match v:
case Record() as r:
for a in r.attr:
r.attr[a] = self.force(r.attr[a])
self.force(r.attr[a])
return r
case Thunk() as th:
self.evalThunk(th)
if th.evaluated:
return self.force(th.result)
return th
return th.result
case _:
return v

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass
import typing
from typing import Any, Tuple
import enum
import sys
import re
@ -23,8 +23,8 @@ class SyntaxError(Exception):
@dataclass
class Token:
type: typing.Any
value: typing.Any
type: Any
value: Any
loc: Loc
def __str__(self):
@ -45,7 +45,8 @@ class NaiveRegexLexer:
# Override with list of (regex, token type, token value). Both the token
# type and value can be functions, in which case they'll be called with the
# match object as parameter.
TOKEN_REGEX = []
Rule = Tuple[str, Any, Any] | Tuple[str, Any, Any, int]
TOKEN_REGEX: list[Rule] = []
# Override with token predicate that matches token to be discarded and not
# sent to the parser (typically, whitespace and comments).
TOKEN_DISCARD = lambda _: False
@ -175,6 +176,7 @@ class LL1Parser:
# Rule combinators implementing unary and binary operators with precedence
@staticmethod
def binaryOpsLeft(ctor, ops):
def decorate(f):
def symbol(self):
@ -185,6 +187,7 @@ class LL1Parser:
return symbol
return decorate
@staticmethod
def binaryOps(ctor, ops, *, rassoc=False):
def decorate(f):
def symbol(self):
@ -197,9 +200,11 @@ class LL1Parser:
return symbol
return decorate
@staticmethod
def binaryOpsRight(ctor, ops):
return LL1Parser.binaryOps(ctor, ops, rassoc=True)
@staticmethod
def unaryOps(ctor, ops, assoc=True):
def decorate(f):
def symbol(self):
@ -276,7 +281,7 @@ class Node:
"LET_DECL", "FUN_DECL", "REC_DECL", "SET_STMT",
"UNIT_TEST"])
ctor: T
args: list[typing.Any]
args: list[Any]
def dump(self, indent=0):
print(" " * indent + self.ctor.name, end=" ")