Fix array-append signature.

This commit is contained in:
Alex Shinn 2023-05-20 16:31:23 +09:00
parent 870e484b50
commit d6c58a7e11
2 changed files with 18 additions and 12 deletions

View file

@ -61,7 +61,7 @@
(define (interval-width iv i) (define (interval-width iv i)
(- (interval-upper-bound iv i) (interval-lower-bound iv i))) (- (interval-upper-bound iv i) (interval-lower-bound iv i)))
(define (interval-widths iv) (define (interval-widths iv)
(vector-map - (interval-lb iv) (interval-ub iv))) (vector-map - (interval-ub iv) (interval-lb iv)))
(define (interval= iv1 iv2) (define (interval= iv1 iv2)
(assert (and (interval? iv1) (interval? iv2))) (assert (and (interval? iv1) (interval? iv2)))

View file

@ -568,15 +568,23 @@
(interval-lower-bound b-domain d)))) (interval-lower-bound b-domain d))))
(lp (- d 1))))))) (lp (- d 1)))))))
(define (array-append axis a . o) (define (array-append axis arrays . o)
(assert (and (exact-integer? axis) (assert (and (exact-integer? axis)
(array? a) (pair? arrays)
(< -1 axis (array-dimension a)) (< -1 axis (array-dimension (car arrays)))
(every array? o))) (every array? arrays)))
(let ((a-domain (array-domain a))) (let* ((a (car arrays))
(a-domain (array-domain a))
(storage (if (pair? o) (car o) generic-storage-class))
(mutable? (if (and (pair? o) (pair? (cdr o)))
(cadr o)
(specialized-array-default-mutable?)))
(safe? (if (and (pair? o) (pair? (cdr o)) (pair? (cddr o)))
(car (cddr o))
(specialized-array-default-safe?))))
(assert (every (lambda (b) (assert (every (lambda (b)
(dimensions-compatible? a-domain (array-domain b) axis)) (dimensions-compatible? a-domain (array-domain b) axis))
o)) (cdr arrays)))
(let* ((a-lo (interval-lower-bounds->vector a-domain)) (let* ((a-lo (interval-lower-bounds->vector a-domain))
(c-lo (make-vector (interval-dimension a-domain) 0)) (c-lo (make-vector (interval-dimension a-domain) 0))
(c-hi (interval-widths a-domain))) (c-hi (interval-widths a-domain)))
@ -585,16 +593,14 @@
(fold (lambda (b sum) (fold (lambda (b sum)
(+ sum (interval-width (array-domain b) axis))) (+ sum (interval-width (array-domain b) axis)))
(vector-ref c-hi axis) (vector-ref c-hi axis)
o)) (cdr arrays)))
(let* ((c-domain (make-interval c-lo c-hi)) (let* ((c-domain (make-interval c-lo c-hi))
(c (make-specialized-array c-domain (c (make-specialized-array c-domain storage mutable? safe?))
(or (array-storage-class a)
generic-storage-class)))
(b-trans (make-vector (array-dimension a) 0))) (b-trans (make-vector (array-dimension a) 0)))
(array-assign! (array-assign!
(array-extract c (make-interval c-lo (interval-widths a-domain))) (array-extract c (make-interval c-lo (interval-widths a-domain)))
(array-translate a (vector-map - a-lo))) (array-translate a (vector-map - a-lo)))
(let lp ((arrays o) (let lp ((arrays (cdr arrays))
(b-offset (- (interval-upper-bound a-domain axis) (b-offset (- (interval-upper-bound a-domain axis)
(interval-lower-bound a-domain axis)))) (interval-lower-bound a-domain axis))))
(if (null? arrays) (if (null? arrays)