Update to new make-specialized-array signature.

This commit is contained in:
Alex Shinn 2024-05-14 21:46:25 +09:00
parent 2e09a082c8
commit 7ac3cfebe1
2 changed files with 31 additions and 7 deletions

View file

@ -2101,6 +2101,19 @@
(test-error (array-curry 'a 1))
(test-error
(array-curry (make-array (make-interval '#(0) '#(1)) list) 'a))
(let ((A (make-array (make-interval '#(10 10)) list)))
(test (array-ref A 3 4)
(array-ref (array-ref (array-curry A 1) 3) 4)))
(let ((A (make-array (make-interval '#(10 10 10)) list)))
(test (array-ref A 3 4 5)
(array-ref (array-ref (array-curry A 1) 3 4) 5)))
(test '((4 7) (2 6))
(array->list*
(array-ref
(array-curry (list*->array 3 '(((4 7) (2 6)) ((1 0) (0 1))))
2)
0)))
;; (test-error
;; (array-curry (make-array (make-interval '#(0 0) '#(1 1)) list) 0))
;; (test-error
@ -3169,6 +3182,11 @@
1 ;; along axis 1 (i.e., columns) ...
(map a-column '(1 2 5 8)))) ;; the columns of A you want
))
(test '(((4 7) (2 6))
((1 0) (0 1)))
(array->list*
(array-stack 0 (list (list*->array 2 '((4 7) (2 6)))
(list*->array 2 '((1 0) (0 1)))))))
'(test '((0 1 4 6 7 8)
(2 3 5 9 10 11)
(12 13 14 15 16 17))

View file

@ -423,7 +423,7 @@
(safe? (if (and (pair? o) (pair? (cdr o)) (pair? (cddr o)))
(car (cddr o))
(specialized-array-default-safe?)))
(res (make-specialized-array domain storage safe?)))
(res (make-specialized-array domain storage (storage-class-default storage) safe?)))
(assert (and (interval? domain) (storage-class? storage)
(boolean? mutable?) (boolean? safe?)))
(interval-fold
@ -506,6 +506,7 @@
(let ((res (make-specialized-array
new-domain
(array-storage-class array)
(storage-class-default (array-storage-class array))
(array-safe? array))))
(array-assign! res array)
res))
@ -538,7 +539,7 @@
(interval-lower-bound domain 0)))))
(else
(let ((domain (array-domain a))
(b (array-curry a 1)))
(b (array-curry a (- (array-dimension a) 1))))
(map (lambda (i) (array->list* (array-ref b i)))
(iota (interval-width domain 0)
(interval-lower-bound domain 0)))))))
@ -612,7 +613,7 @@
(vector-ref c-hi axis)
(cdr arrays)))
(let* ((c-domain (make-interval c-lo c-hi))
(c (make-specialized-array c-domain storage mutable? safe?))
(c (make-specialized-array c-domain storage (storage-class-default storage) safe?))
(b-trans (make-vector (array-dimension a) 0)))
(array-assign!
(array-extract c (make-interval c-lo (interval-widths a-domain)))
@ -663,10 +664,15 @@
`#(,@(take a-ubs axis) ,(length arrays) ,@(drop a-ubs axis))))
(res (make-specialized-array domain
(or (array-storage-class a)
generic-storage-class)))
generic-storage-class)
(storage-class-default storage)
safe?))
;; Stack by permuting the desired axis to the first
;; dimension and currying on that, assigning the
;; corresponding array argument to each element.
(perm `#(,axis ,@(delete axis (iota (+ 1 (array-dimension a))))))
(permed (if (zero? axis) res (array-permute res perm)))
(curried (array-curry permed 1))
(curried (array-curry permed (- (array-dimension permed) 1)))
(get-view (array-getter curried)))
(let lp ((ls arrays) (i 0))
(cond
@ -694,7 +700,7 @@
(vector-append (interval-widths a-domain)
(interval-widths (array-domain tile0)))))
(scales (vector->list (interval-widths a-domain)))
(res (make-specialized-array domain storage mutable? safe?)))
(res (make-specialized-array domain (storage-class-default storage) safe?)))
(error "TODO: array-block copy data unimplemented")
res))))
@ -712,7 +718,7 @@
(elt0 (apply array-ref a (interval-lower-bounds->list a-domain)))
(elt-domain (array-domain elt0))
(domain (interval-cartesian-product a-domain elt-domain))
(res (make-specialized-array domain storage mutable? safe?))
(res (make-specialized-array domain storage (storage-class-default storage) safe?))
(curried-res (array-curry res (interval-dimension elt-domain))))
;; Prepare a res with the flattened domain, create a new curried
;; view of the res with the same domain as a, and assign each