diff --git a/lib/chibi/type-inference.scm b/lib/chibi/type-inference.scm index 605a442a..137cc251 100644 --- a/lib/chibi/type-inference.scm +++ b/lib/chibi/type-inference.scm @@ -1,3 +1,7 @@ +;; type-inference.scm -- general type-inference for Scheme +;; +;; Copyright (c) 2010 Alex Shinn. All rights reserved. +;; BSD-style license: http://synthcode.com/license.txt (define (typed? x) (and (lambda? x) @@ -10,7 +14,13 @@ (and (pair? a) (equal? (car a) 'and))) (define (unfinalized-type? a) - (and (pair? a) (memq (car a) '(return-type param-type)))) + (and (pair? a) + (or (memq (car a) '(return-type param-type)) + (and (memq (car a) '(and or)) + (any unfinalized-type? (cdr a)))))) + +(define (finalized-type? a) + (not (unfinalized-type? a))) (define (numeric-type? a) (or (eq? a ) (eq? a ) (eq? a ))) @@ -55,6 +65,17 @@ (cons (car a) (lset-adjoin equal? (cdr a) b)))) (else (list 'and a b)))) +(define (type-of x) + (cond ((boolean? x) ) + ((char? x) ) + ((symbol? x) ) + ((string? x) ) + ((and (integer? x) (exact? x)) ) + ((flonum? x) ) + ((pair? x) ) + ((vector? x) ) + (else ))) + (define (lambda-param-types-initialize! f) (lambda-param-types-set! f (map (lambda (p) (list 'param-type f p)) (lambda-params f)))) @@ -79,24 +100,32 @@ => (lambda (cell) (set-car! cell y))))) (define (type-analyze-expr x) - ;;(write `(type-analyze-expr ,x ,(ast->sexp x)) (current-error-port)) (newline (current-error-port)) (match x (($ name params body defs) - (lambda-return-type-set! x (list 'return-type x)) - (lambda-param-types-initialize! x) - (let ((ret-type (type-analyze-expr body))) - (lambda-return-type-set! x ret-type) - (cons 'lambda (cons ret-type (lambda-param-types x))))) + (cond + ((not (lambda-return-type x)) + (lambda-return-type-set! x (list 'return-type x)) + (lambda-param-types-initialize! x) + (let ((ret-type (type-analyze-expr body))) + (lambda-return-type-set! x ret-type) + (cons 'lambda (cons ret-type (lambda-param-types x))))))) (($ ref value) (type-analyze-expr value) (if #f #f)) - (($ name (_ . loc) source) - (if (lambda? loc) - (lambda-param-type-ref loc name) - )) + (($ name (value . loc) source) + (cond + ((lambda? loc) (lambda-param-type-ref loc name)) + ((procedure? loc) + (let ((sig (procedure-signature loc))) + (if (and (pair? sig) (car sig)) + (cons 'lambda sig) + (list 'return-type (procedure-analysis loc))))) + (else ))) (($ test pass fail) - (type-analyze-expr test) - (type-union (type-analyze-expr pass) (type-analyze-expr fail))) + (let ((test-type (type-analyze-expr test)) + (pass-type (type-analyze-expr pass)) + (fail-type (type-analyze-expr fail))) + (type-union pass-type fail-type))) (($ ls) (let lp ((ls ls)) (cond ((null? (cdr ls)) @@ -124,7 +153,10 @@ (else (let ((t (type-analyze-expr (car a)))) (cond - ((not (type-subset? t p-type)) + ((and t p-type + (finalized-type? t) + (finalized-type? p-type) + (not (type-subset? t p-type))) (display "WARNING: incompatible type: " (current-error-port)) (write (list x t p-type) (current-error-port)) @@ -138,32 +170,76 @@ (let ((f-type (type-analyze-expr f))) ;; XXXX apply f-type to params (for-each type-analyze-expr args) - (if (and (pair? f-type) (eq? 'lambda (car f-type))) - (cadr f-type) - ))))) + (cond + ((and (pair? f-type) (eq? (car f-type) 'lambda)) + (cadr f-type)) + ((and (pair? f-type) (memq (car f-type) '(return-type param-type))) + f-type) + (else + )))))) (else - ;;(write `(unknown type ,x) (current-error-port)) (newline (current-error-port)) - ))) + (type-of x)))) + +(define (resolve-delayed-type x) + (let lp ((x x) (seen '()) (default )) + (match x + (('return-type f) + (if (memq f seen) + default + (lp (lambda-return-type f) (cons f seen) default))) + (('param-type f p) + (if (member x seen) + default + (lp (lambda-param-type-ref f p) (cons x seen) default))) + (('or y ...) + (let ((z (find finalized-type? y))) + (if z + (let ((default (if (eq? default ) + (lp z seen default) + (type-union (lp z seen default) default)))) + (fold type-union + default + (map (lambda (y1) (lp y1 seen default)) (delete z y)))) + (fold type-union default (map (lambda (y1) (lp y1 seen default)) y))))) + (('and y ...) + (fold type-intersection default (map (lambda (y1) (lp y1 seen default)) y))) + (('not y) + (list 'not (lp y seen default))) + (else + x)))) (define (type-resolve-circularities x) - #f) + (match x + (($ name params body defs) + (if (unfinalized-type? (lambda-return-type x)) + (lambda-return-type-set! x (resolve-delayed-type + (lambda-return-type x)))) + (for-each + (lambda (p t) + (if (unfinalized-type? t) + (lambda-param-type-set! x p (resolve-delayed-type t)))) + params + (lambda-param-types x)) + (type-resolve-circularities (lambda-body x))) + (($ ref value) + (type-resolve-circularities value)) + (($ test pass fail) + (type-resolve-circularities test) + (type-resolve-circularities pass) + (type-resolve-circularities fail)) + (($ ls) + (for-each type-resolve-circularities ls)) + ((app ...) + (for-each type-resolve-circularities app)) + (else #f))) -;; basic type inference on the body of a module -;; - internal references are to lambdas -;; - external references are to procedures (with completed type info) -;; - for each lambda -;; + add parameter constraints (intersection) from body -;; + add return type constaints (union) from last form(s) -;; - when complete, resolve cycles (e.g. even/odd => boolean) (define (type-analyze-module-body name ls) - ;;(write `(type-analyze-module-body ,name) (current-error-port)) (newline (current-error-port)) (for-each type-analyze-expr ls) (for-each type-resolve-circularities ls)) (define (type-analyze-module name) (let* ((mod (analyze-module name)) (ls (and (vector? mod) (module-ast mod)))) - ;;(write `(analyzing ,ls) (current-error-port)) (newline (current-error-port)) (and ls (let ((x (let lp ((ls ls)) ;; first lambda (and (pair? ls) @@ -184,9 +260,15 @@ res (lp (- n 1) (cons (opcode-param-type x n) res))))) +(define (opcode-type x) + (cons 'lambda (cons (opcode-return-type x) (opcode-param-types x)))) + +(define (lambda-type x) + (cons 'lambda (cons (lambda-return-type x) (lambda-param-types x)))) + (define (procedure-signature x) (if (opcode? x) - (cons (opcode-return-type x) (opcode-param-types x)) + (cdr (opcode-type x)) (let lp ((count 0)) (let ((lam (procedure-analysis x))) (cond @@ -196,7 +278,6 @@ (and (type-analyze-module (car mod)) (lp (+ count 1))))) ((lambda? lam) - (cons (lambda-return-type lam) - (lambda-param-types lam))) + (cdr (lambda-type lam))) (else #f))))))