juic: fix all errors raised by mypy

This commit is contained in:
Lephenixnoir 2024-08-28 13:58:27 +02:00
parent 4579acc0f4
commit ce39929bb4
No known key found for this signature in database
GPG key ID: 1BBA026E13FC0495
3 changed files with 30 additions and 19 deletions

View file

@ -33,7 +33,7 @@ class Function:
# Expression node to evaluate when calling the function # Expression node to evaluate when calling the function
body: "juic.parser.Node" body: "juic.parser.Node"
# Parameter names, must all be unique. May be empty # Parameter names, must all be unique. May be empty
params: list[str] = field(default_factory=[]) params: list[str] = field(default_factory=list)
# Name of variadic argument if one; must also be unique # Name of variadic argument if one; must also be unique
variadic: str | None = None variadic: str | None = None

View file

@ -253,7 +253,7 @@ class Context:
raise JuiStaticError(f"multiple variadic arguments: {names}") raise JuiStaticError(f"multiple variadic arguments: {names}")
if variadics and variadics != [len(params) - 1]: if variadics and variadics != [len(params) - 1]:
name = params[variadic][0] name = params[variadics[0]][0]
raise JuiStaticError(f"variadic argument {name} not at the end") raise JuiStaticError(f"variadic argument {name} not at the end")
variadic = params[variadics[0]][0] if variadics else None variadic = params[variadics[0]][0] if variadics else None
@ -297,7 +297,10 @@ class Context:
case "+": return x + y case "+": return x + y
case "-": return x - y case "-": return x - y
case "*": return x * y case "*": return x * y
case "/": return x / y if type(x) == float else x // y case "/":
match x, y:
case float(), float(): return (x / y)
case _, _: return (x // y)
case "%": return x % y case "%": return x % y
case Node.T.OP, [("!" | "+" | "-") as op, X]: case Node.T.OP, [("!" | "+" | "-") as op, X]:
@ -384,7 +387,6 @@ class Context:
case Node.T.REC_VALUE, _: case Node.T.REC_VALUE, _:
raise NotImplementedError raise NotImplementedError
case _, _:
raise Exception("invalid expr o(x_x)o: " + str(node)) raise Exception("invalid expr o(x_x)o: " + str(node))
def project(self, v: JuiValue, field: str) -> JuiValue: def project(self, v: JuiValue, field: str) -> JuiValue:
@ -392,14 +394,14 @@ class Context:
case Record() as r: case Record() as r:
if field not in r.attr: if field not in r.attr:
raise Exception(f"access to undefined field {field}") raise Exception(f"access to undefined field {field}")
return r.attr[field] return self.evalThunk(r.attr[field])
case PartialRecordSnapshot() as prs: case PartialRecordSnapshot() as prs:
if field in prs.fieldThunks: if field in prs.fieldThunks:
return prs.fieldThunks[field] return self.evalThunk(prs.fieldThunks[field])
elif prs.base is None: elif prs.base is None:
raise Exception(f"access to undefined field {field} of 'this'") raise Exception(f"access to undefined field {field} of 'this'")
else: else:
return self.project(prs.base, v) return self.project(self.evalThunk(prs.base), field)
case _: case _:
raise NotImplementedError # unreachable raise NotImplementedError # unreachable
@ -407,9 +409,11 @@ class Context:
match node.ctor, node.args: match node.ctor, node.args:
case Node.T.LET_DECL, [name, X]: case Node.T.LET_DECL, [name, X]:
self.addDefinition(name, self.evalExpr(X)) self.addDefinition(name, self.evalExpr(X))
return None
case Node.T.FUN_DECL, [name, params, body]: case Node.T.FUN_DECL, [name, params, body]:
self.addDefinition(name, self.makeFunction(params, body)) self.addDefinition(name, self.makeFunction(params, body))
return None
case Node.T.REC_DECL, _: case Node.T.REC_DECL, _:
raise NotImplementedError raise NotImplementedError
@ -427,8 +431,8 @@ class Context:
print(" " + juiValueString(vs)) print(" " + juiValueString(vs))
print("vs.") print("vs.")
print(" " + juiValueString(ve)) print(" " + juiValueString(ve))
return None
case _, _:
return self.evalExpr(node) return self.evalExpr(node)
# TODO: Context.eval*: continue failed computations to find other errors? # TODO: Context.eval*: continue failed computations to find other errors?
@ -498,7 +502,7 @@ class Context:
# Base record constructor: starts out with an empty record. All # Base record constructor: starts out with an empty record. All
# arguments are children. Easy. # arguments are children. Easy.
if type(ctor) == RecordType: if type(ctor) == RecordType:
r = Record(base=ctor, attr=dict(), children=dict()) r = Record(base=ctor, attr=dict(), children=[])
if len(args) > 0: if len(args) > 0:
raise JuiRuntimeError(f"arguments given to type rec ctor") raise JuiRuntimeError(f"arguments given to type rec ctor")
@ -517,6 +521,8 @@ class Context:
return r return r
assert isinstance(ctor, Function)
# TODO: Factor this with function. In fact, this should reduce to a call # TODO: Factor this with function. In fact, this should reduce to a call
# Check number of arguments # Check number of arguments
req = str(len(ctor.params)) + ("+" if ctor.variadic is not None else "") req = str(len(ctor.params)) + ("+" if ctor.variadic is not None else "")
@ -603,16 +609,16 @@ class Context:
# How about that's what Record does, and it just has accessors? # How about that's what Record does, and it just has accessors?
raise NotImplementedError raise NotImplementedError
def force(self, v: JuiValue) -> JuiValue: def force(self, v: JuiValue | Thunk) -> JuiValue:
match v: match v:
case Record() as r: case Record() as r:
for a in r.attr: for a in r.attr:
r.attr[a] = self.force(r.attr[a]) self.force(r.attr[a])
return r return r
case Thunk() as th: case Thunk() as th:
self.evalThunk(th) self.evalThunk(th)
if th.evaluated: if th.evaluated:
return self.force(th.result) return self.force(th.result)
return th return th.result
case _: case _:
return v return v

View file

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
import typing from typing import Any, Tuple
import enum import enum
import sys import sys
import re import re
@ -23,8 +23,8 @@ class SyntaxError(Exception):
@dataclass @dataclass
class Token: class Token:
type: typing.Any type: Any
value: typing.Any value: Any
loc: Loc loc: Loc
def __str__(self): def __str__(self):
@ -45,7 +45,8 @@ class NaiveRegexLexer:
# Override with list of (regex, token type, token value). Both the token # Override with list of (regex, token type, token value). Both the token
# type and value can be functions, in which case they'll be called with the # type and value can be functions, in which case they'll be called with the
# match object as parameter. # match object as parameter.
TOKEN_REGEX = [] Rule = Tuple[str, Any, Any] | Tuple[str, Any, Any, int]
TOKEN_REGEX: list[Rule] = []
# Override with token predicate that matches token to be discarded and not # Override with token predicate that matches token to be discarded and not
# sent to the parser (typically, whitespace and comments). # sent to the parser (typically, whitespace and comments).
TOKEN_DISCARD = lambda _: False TOKEN_DISCARD = lambda _: False
@ -175,6 +176,7 @@ class LL1Parser:
# Rule combinators implementing unary and binary operators with precedence # Rule combinators implementing unary and binary operators with precedence
@staticmethod
def binaryOpsLeft(ctor, ops): def binaryOpsLeft(ctor, ops):
def decorate(f): def decorate(f):
def symbol(self): def symbol(self):
@ -185,6 +187,7 @@ class LL1Parser:
return symbol return symbol
return decorate return decorate
@staticmethod
def binaryOps(ctor, ops, *, rassoc=False): def binaryOps(ctor, ops, *, rassoc=False):
def decorate(f): def decorate(f):
def symbol(self): def symbol(self):
@ -197,9 +200,11 @@ class LL1Parser:
return symbol return symbol
return decorate return decorate
@staticmethod
def binaryOpsRight(ctor, ops): def binaryOpsRight(ctor, ops):
return LL1Parser.binaryOps(ctor, ops, rassoc=True) return LL1Parser.binaryOps(ctor, ops, rassoc=True)
@staticmethod
def unaryOps(ctor, ops, assoc=True): def unaryOps(ctor, ops, assoc=True):
def decorate(f): def decorate(f):
def symbol(self): def symbol(self):
@ -276,7 +281,7 @@ class Node:
"LET_DECL", "FUN_DECL", "REC_DECL", "SET_STMT", "LET_DECL", "FUN_DECL", "REC_DECL", "SET_STMT",
"UNIT_TEST"]) "UNIT_TEST"])
ctor: T ctor: T
args: list[typing.Any] args: list[Any]
def dump(self, indent=0): def dump(self, indent=0):
print(" " * indent + self.ctor.name, end=" ") print(" " * indent + self.ctor.name, end=" ")