various type inference fixes, mutually recursive functions now resolve correctly

This commit is contained in:
Alex Shinn 2010-08-01 18:12:49 +09:00
parent 2fe2e9f002
commit fbf7319a30

View file

@ -1,3 +1,7 @@
;; type-inference.scm -- general type-inference for Scheme
;;
;; Copyright (c) 2010 Alex Shinn. All rights reserved.
;; BSD-style license: http://synthcode.com/license.txt
(define (typed? x)
(and (lambda? x)
@ -10,7 +14,13 @@
(and (pair? a) (equal? (car a) 'and)))
(define (unfinalized-type? a)
(and (pair? a) (memq (car a) '(return-type param-type))))
(and (pair? a)
(or (memq (car a) '(return-type param-type))
(and (memq (car a) '(and or))
(any unfinalized-type? (cdr a))))))
(define (finalized-type? a)
(not (unfinalized-type? a)))
(define (numeric-type? a)
(or (eq? a <number>) (eq? a <flonum>) (eq? a <integer>)))
@ -55,6 +65,17 @@
(cons (car a) (lset-adjoin equal? (cdr a) b))))
(else (list 'and a b))))
(define (type-of x)
(cond ((boolean? x) <boolean>)
((char? x) <char>)
((symbol? x) <symbol>)
((string? x) <string>)
((and (integer? x) (exact? x)) <integer>)
((flonum? x) <flonum>)
((pair? x) <pair>)
((vector? x) <vector>)
(else <object>)))
(define (lambda-param-types-initialize! f)
(lambda-param-types-set! f (map (lambda (p) (list 'param-type f p))
(lambda-params f))))
@ -79,24 +100,32 @@
=> (lambda (cell) (set-car! cell y)))))
(define (type-analyze-expr x)
;;(write `(type-analyze-expr ,x ,(ast->sexp x)) (current-error-port)) (newline (current-error-port))
(match x
(($ <lam> name params body defs)
(lambda-return-type-set! x (list 'return-type x))
(lambda-param-types-initialize! x)
(let ((ret-type (type-analyze-expr body)))
(lambda-return-type-set! x ret-type)
(cons 'lambda (cons ret-type (lambda-param-types x)))))
(cond
((not (lambda-return-type x))
(lambda-return-type-set! x (list 'return-type x))
(lambda-param-types-initialize! x)
(let ((ret-type (type-analyze-expr body)))
(lambda-return-type-set! x ret-type)
(cons 'lambda (cons ret-type (lambda-param-types x)))))))
(($ <set> ref value)
(type-analyze-expr value)
(if #f #f))
(($ <ref> name (_ . loc) source)
(if (lambda? loc)
(lambda-param-type-ref loc name)
<object>))
(($ <ref> name (value . loc) source)
(cond
((lambda? loc) (lambda-param-type-ref loc name))
((procedure? loc)
(let ((sig (procedure-signature loc)))
(if (and (pair? sig) (car sig))
(cons 'lambda sig)
(list 'return-type (procedure-analysis loc)))))
(else <object>)))
(($ <cnd> test pass fail)
(type-analyze-expr test)
(type-union (type-analyze-expr pass) (type-analyze-expr fail)))
(let ((test-type (type-analyze-expr test))
(pass-type (type-analyze-expr pass))
(fail-type (type-analyze-expr fail)))
(type-union pass-type fail-type)))
(($ <seq> ls)
(let lp ((ls ls))
(cond ((null? (cdr ls))
@ -124,7 +153,10 @@
(else
(let ((t (type-analyze-expr (car a))))
(cond
((not (type-subset? t p-type))
((and t p-type
(finalized-type? t)
(finalized-type? p-type)
(not (type-subset? t p-type)))
(display "WARNING: incompatible type: "
(current-error-port))
(write (list x t p-type) (current-error-port))
@ -138,32 +170,76 @@
(let ((f-type (type-analyze-expr f)))
;; XXXX apply f-type to params
(for-each type-analyze-expr args)
(if (and (pair? f-type) (eq? 'lambda (car f-type)))
(cadr f-type)
<object>)))))
(cond
((and (pair? f-type) (eq? (car f-type) 'lambda))
(cadr f-type))
((and (pair? f-type) (memq (car f-type) '(return-type param-type)))
f-type)
(else
<object>))))))
(else
;;(write `(unknown type ,x) (current-error-port)) (newline (current-error-port))
<object>)))
(type-of x))))
(define (resolve-delayed-type x)
(let lp ((x x) (seen '()) (default <object>))
(match x
(('return-type f)
(if (memq f seen)
default
(lp (lambda-return-type f) (cons f seen) default)))
(('param-type f p)
(if (member x seen)
default
(lp (lambda-param-type-ref f p) (cons x seen) default)))
(('or y ...)
(let ((z (find finalized-type? y)))
(if z
(let ((default (if (eq? default <object>)
(lp z seen default)
(type-union (lp z seen default) default))))
(fold type-union
default
(map (lambda (y1) (lp y1 seen default)) (delete z y))))
(fold type-union default (map (lambda (y1) (lp y1 seen default)) y)))))
(('and y ...)
(fold type-intersection default (map (lambda (y1) (lp y1 seen default)) y)))
(('not y)
(list 'not (lp y seen default)))
(else
x))))
(define (type-resolve-circularities x)
#f)
(match x
(($ <lam> name params body defs)
(if (unfinalized-type? (lambda-return-type x))
(lambda-return-type-set! x (resolve-delayed-type
(lambda-return-type x))))
(for-each
(lambda (p t)
(if (unfinalized-type? t)
(lambda-param-type-set! x p (resolve-delayed-type t))))
params
(lambda-param-types x))
(type-resolve-circularities (lambda-body x)))
(($ <set> ref value)
(type-resolve-circularities value))
(($ <cnd> test pass fail)
(type-resolve-circularities test)
(type-resolve-circularities pass)
(type-resolve-circularities fail))
(($ <seq> ls)
(for-each type-resolve-circularities ls))
((app ...)
(for-each type-resolve-circularities app))
(else #f)))
;; basic type inference on the body of a module
;; - internal references are to lambdas
;; - external references are to procedures (with completed type info)
;; - for each lambda
;; + add parameter constraints (intersection) from body
;; + add return type constaints (union) from last form(s)
;; - when complete, resolve cycles (e.g. even/odd => boolean)
(define (type-analyze-module-body name ls)
;;(write `(type-analyze-module-body ,name) (current-error-port)) (newline (current-error-port))
(for-each type-analyze-expr ls)
(for-each type-resolve-circularities ls))
(define (type-analyze-module name)
(let* ((mod (analyze-module name))
(ls (and (vector? mod) (module-ast mod))))
;;(write `(analyzing ,ls) (current-error-port)) (newline (current-error-port))
(and ls
(let ((x (let lp ((ls ls)) ;; first lambda
(and (pair? ls)
@ -184,9 +260,15 @@
res
(lp (- n 1) (cons (opcode-param-type x n) res)))))
(define (opcode-type x)
(cons 'lambda (cons (opcode-return-type x) (opcode-param-types x))))
(define (lambda-type x)
(cons 'lambda (cons (lambda-return-type x) (lambda-param-types x))))
(define (procedure-signature x)
(if (opcode? x)
(cons (opcode-return-type x) (opcode-param-types x))
(cdr (opcode-type x))
(let lp ((count 0))
(let ((lam (procedure-analysis x)))
(cond
@ -196,7 +278,6 @@
(and (type-analyze-module (car mod))
(lp (+ count 1)))))
((lambda? lam)
(cons (lambda-return-type lam)
(lambda-param-types lam)))
(cdr (lambda-type lam)))
(else
#f))))))