various type inference fixes, mutually recursive functions now resolve correctly

This commit is contained in:
Alex Shinn 2010-08-01 18:12:49 +09:00
parent 2fe2e9f002
commit fbf7319a30

View file

@ -1,3 +1,7 @@
;; type-inference.scm -- general type-inference for Scheme
;;
;; Copyright (c) 2010 Alex Shinn. All rights reserved.
;; BSD-style license: http://synthcode.com/license.txt
(define (typed? x) (define (typed? x)
(and (lambda? x) (and (lambda? x)
@ -10,7 +14,13 @@
(and (pair? a) (equal? (car a) 'and))) (and (pair? a) (equal? (car a) 'and)))
(define (unfinalized-type? a) (define (unfinalized-type? a)
(and (pair? a) (memq (car a) '(return-type param-type)))) (and (pair? a)
(or (memq (car a) '(return-type param-type))
(and (memq (car a) '(and or))
(any unfinalized-type? (cdr a))))))
(define (finalized-type? a)
(not (unfinalized-type? a)))
(define (numeric-type? a) (define (numeric-type? a)
(or (eq? a <number>) (eq? a <flonum>) (eq? a <integer>))) (or (eq? a <number>) (eq? a <flonum>) (eq? a <integer>)))
@ -55,6 +65,17 @@
(cons (car a) (lset-adjoin equal? (cdr a) b)))) (cons (car a) (lset-adjoin equal? (cdr a) b))))
(else (list 'and a b)))) (else (list 'and a b))))
(define (type-of x)
(cond ((boolean? x) <boolean>)
((char? x) <char>)
((symbol? x) <symbol>)
((string? x) <string>)
((and (integer? x) (exact? x)) <integer>)
((flonum? x) <flonum>)
((pair? x) <pair>)
((vector? x) <vector>)
(else <object>)))
(define (lambda-param-types-initialize! f) (define (lambda-param-types-initialize! f)
(lambda-param-types-set! f (map (lambda (p) (list 'param-type f p)) (lambda-param-types-set! f (map (lambda (p) (list 'param-type f p))
(lambda-params f)))) (lambda-params f))))
@ -79,24 +100,32 @@
=> (lambda (cell) (set-car! cell y))))) => (lambda (cell) (set-car! cell y)))))
(define (type-analyze-expr x) (define (type-analyze-expr x)
;;(write `(type-analyze-expr ,x ,(ast->sexp x)) (current-error-port)) (newline (current-error-port))
(match x (match x
(($ <lam> name params body defs) (($ <lam> name params body defs)
(lambda-return-type-set! x (list 'return-type x)) (cond
(lambda-param-types-initialize! x) ((not (lambda-return-type x))
(let ((ret-type (type-analyze-expr body))) (lambda-return-type-set! x (list 'return-type x))
(lambda-return-type-set! x ret-type) (lambda-param-types-initialize! x)
(cons 'lambda (cons ret-type (lambda-param-types x))))) (let ((ret-type (type-analyze-expr body)))
(lambda-return-type-set! x ret-type)
(cons 'lambda (cons ret-type (lambda-param-types x)))))))
(($ <set> ref value) (($ <set> ref value)
(type-analyze-expr value) (type-analyze-expr value)
(if #f #f)) (if #f #f))
(($ <ref> name (_ . loc) source) (($ <ref> name (value . loc) source)
(if (lambda? loc) (cond
(lambda-param-type-ref loc name) ((lambda? loc) (lambda-param-type-ref loc name))
<object>)) ((procedure? loc)
(let ((sig (procedure-signature loc)))
(if (and (pair? sig) (car sig))
(cons 'lambda sig)
(list 'return-type (procedure-analysis loc)))))
(else <object>)))
(($ <cnd> test pass fail) (($ <cnd> test pass fail)
(type-analyze-expr test) (let ((test-type (type-analyze-expr test))
(type-union (type-analyze-expr pass) (type-analyze-expr fail))) (pass-type (type-analyze-expr pass))
(fail-type (type-analyze-expr fail)))
(type-union pass-type fail-type)))
(($ <seq> ls) (($ <seq> ls)
(let lp ((ls ls)) (let lp ((ls ls))
(cond ((null? (cdr ls)) (cond ((null? (cdr ls))
@ -124,7 +153,10 @@
(else (else
(let ((t (type-analyze-expr (car a)))) (let ((t (type-analyze-expr (car a))))
(cond (cond
((not (type-subset? t p-type)) ((and t p-type
(finalized-type? t)
(finalized-type? p-type)
(not (type-subset? t p-type)))
(display "WARNING: incompatible type: " (display "WARNING: incompatible type: "
(current-error-port)) (current-error-port))
(write (list x t p-type) (current-error-port)) (write (list x t p-type) (current-error-port))
@ -138,32 +170,76 @@
(let ((f-type (type-analyze-expr f))) (let ((f-type (type-analyze-expr f)))
;; XXXX apply f-type to params ;; XXXX apply f-type to params
(for-each type-analyze-expr args) (for-each type-analyze-expr args)
(if (and (pair? f-type) (eq? 'lambda (car f-type))) (cond
(cadr f-type) ((and (pair? f-type) (eq? (car f-type) 'lambda))
<object>))))) (cadr f-type))
((and (pair? f-type) (memq (car f-type) '(return-type param-type)))
f-type)
(else
<object>))))))
(else (else
;;(write `(unknown type ,x) (current-error-port)) (newline (current-error-port)) (type-of x))))
<object>)))
(define (resolve-delayed-type x)
(let lp ((x x) (seen '()) (default <object>))
(match x
(('return-type f)
(if (memq f seen)
default
(lp (lambda-return-type f) (cons f seen) default)))
(('param-type f p)
(if (member x seen)
default
(lp (lambda-param-type-ref f p) (cons x seen) default)))
(('or y ...)
(let ((z (find finalized-type? y)))
(if z
(let ((default (if (eq? default <object>)
(lp z seen default)
(type-union (lp z seen default) default))))
(fold type-union
default
(map (lambda (y1) (lp y1 seen default)) (delete z y))))
(fold type-union default (map (lambda (y1) (lp y1 seen default)) y)))))
(('and y ...)
(fold type-intersection default (map (lambda (y1) (lp y1 seen default)) y)))
(('not y)
(list 'not (lp y seen default)))
(else
x))))
(define (type-resolve-circularities x) (define (type-resolve-circularities x)
#f) (match x
(($ <lam> name params body defs)
(if (unfinalized-type? (lambda-return-type x))
(lambda-return-type-set! x (resolve-delayed-type
(lambda-return-type x))))
(for-each
(lambda (p t)
(if (unfinalized-type? t)
(lambda-param-type-set! x p (resolve-delayed-type t))))
params
(lambda-param-types x))
(type-resolve-circularities (lambda-body x)))
(($ <set> ref value)
(type-resolve-circularities value))
(($ <cnd> test pass fail)
(type-resolve-circularities test)
(type-resolve-circularities pass)
(type-resolve-circularities fail))
(($ <seq> ls)
(for-each type-resolve-circularities ls))
((app ...)
(for-each type-resolve-circularities app))
(else #f)))
;; basic type inference on the body of a module
;; - internal references are to lambdas
;; - external references are to procedures (with completed type info)
;; - for each lambda
;; + add parameter constraints (intersection) from body
;; + add return type constaints (union) from last form(s)
;; - when complete, resolve cycles (e.g. even/odd => boolean)
(define (type-analyze-module-body name ls) (define (type-analyze-module-body name ls)
;;(write `(type-analyze-module-body ,name) (current-error-port)) (newline (current-error-port))
(for-each type-analyze-expr ls) (for-each type-analyze-expr ls)
(for-each type-resolve-circularities ls)) (for-each type-resolve-circularities ls))
(define (type-analyze-module name) (define (type-analyze-module name)
(let* ((mod (analyze-module name)) (let* ((mod (analyze-module name))
(ls (and (vector? mod) (module-ast mod)))) (ls (and (vector? mod) (module-ast mod))))
;;(write `(analyzing ,ls) (current-error-port)) (newline (current-error-port))
(and ls (and ls
(let ((x (let lp ((ls ls)) ;; first lambda (let ((x (let lp ((ls ls)) ;; first lambda
(and (pair? ls) (and (pair? ls)
@ -184,9 +260,15 @@
res res
(lp (- n 1) (cons (opcode-param-type x n) res))))) (lp (- n 1) (cons (opcode-param-type x n) res)))))
(define (opcode-type x)
(cons 'lambda (cons (opcode-return-type x) (opcode-param-types x))))
(define (lambda-type x)
(cons 'lambda (cons (lambda-return-type x) (lambda-param-types x))))
(define (procedure-signature x) (define (procedure-signature x)
(if (opcode? x) (if (opcode? x)
(cons (opcode-return-type x) (opcode-param-types x)) (cdr (opcode-type x))
(let lp ((count 0)) (let lp ((count 0))
(let ((lam (procedure-analysis x))) (let ((lam (procedure-analysis x)))
(cond (cond
@ -196,7 +278,6 @@
(and (type-analyze-module (car mod)) (and (type-analyze-module (car mod))
(lp (+ count 1))))) (lp (+ count 1)))))
((lambda? lam) ((lambda? lam)
(cons (lambda-return-type lam) (cdr (lambda-type lam)))
(lambda-param-types lam)))
(else (else
#f)))))) #f))))))