Implement and use array-freeze!

Fixes #960.
This commit is contained in:
Alex Shinn 2024-05-24 19:20:14 +09:00
parent f60298b707
commit 33a59952a8
4 changed files with 28 additions and 19 deletions

View file

@ -52,6 +52,7 @@
array-assign! array-ref array-set! array-decurry
specialized-array-reshape
array-copy! array-stack! array-decurry! array-append! array-block!
array-freeze!
)
(include "231/transforms.scm")
(cond-expand

View file

@ -324,6 +324,9 @@
(lambda (val . multi-index)
(setter body (apply indexer multi-index) val)))
(define (array-freeze! array)
(%array-setter-set! array #f)
array)
;; Indexing

View file

@ -44,5 +44,6 @@
specialized-array-share array-ref array-set!
%make-specialized %array-setter-set!
specialized-getter specialized-setter
array-freeze!
)
(include "base.scm"))

View file

@ -68,6 +68,14 @@
;; Array transformations
(define (make-specialized-array/default domain . o)
(let ((storage (if (pair? o) (car o) generic-storage-class)))
(apply make-specialized-array
domain
storage
(storage-class-default storage)
(if (pair? o) (cdr o) '()))))
(define (array-copy array . o)
(assert (array? array))
(let ((specialized? (specialized-array? array))
@ -97,9 +105,7 @@
(res (%make-specialized domain storage body coeffs indexer
safe? #t #t)))
(array-assign! res array)
(unless mutable?
(%array-setter-set! res #f))
res))))
(if mutable? res (array-freeze! res))))))
(define array-copy! array-copy)
@ -422,7 +428,7 @@
(safe? (if (and (pair? o) (pair? (cdr o)) (pair? (cddr o)))
(car (cddr o))
(specialized-array-default-safe?)))
(res (make-specialized-array domain storage (storage-class-default storage) safe?)))
(res (make-specialized-array/default domain storage safe?)))
(assert (and (interval? domain) (storage-class? storage)
(boolean? mutable?) (boolean? safe?)))
(interval-fold
@ -431,7 +437,7 @@
(cdr ls))
ls
domain)
res))
(if mutable? res (array-freeze! res))))
(define (array->vector array)
(list->vector (array->list array)))
@ -502,10 +508,9 @@
(cond
((reshape-without-copy array new-domain))
(copy-on-failure?
(let ((res (make-specialized-array
(let ((res (make-specialized-array/default
new-domain
(array-storage-class array)
(storage-class-default (array-storage-class array))
(array-safe? array))))
(array-assign! res array)
res))
@ -612,7 +617,7 @@
(vector-ref c-hi axis)
(cdr arrays)))
(let* ((c-domain (make-interval c-lo c-hi))
(c (make-specialized-array c-domain storage (storage-class-default storage) safe?))
(c (make-specialized-array/default c-domain storage safe?))
(b-trans (make-vector (array-dimension a) 0)))
(array-assign!
(array-extract c (make-interval c-lo (interval-widths a-domain)))
@ -621,7 +626,7 @@
(b-offset (- (interval-upper-bound a-domain axis)
(interval-lower-bound a-domain axis))))
(if (null? arrays)
c
(if mutable? c (array-freeze! c))
(let* ((b (car arrays))
(b-domain (array-domain b))
(b-offset2 (+ b-offset (interval-width b-domain axis)))
@ -661,11 +666,10 @@
(make-interval
`#(,@(take a-lbs axis) 0 ,@(drop a-lbs axis))
`#(,@(take a-ubs axis) ,(length arrays) ,@(drop a-ubs axis))))
(res (make-specialized-array domain
(or (array-storage-class a)
generic-storage-class)
(storage-class-default storage)
safe?))
(res (make-specialized-array/default domain
(or (array-storage-class a)
generic-storage-class)
safe?))
;; Stack by permuting the desired axis to the first
;; dimension and currying on that, assigning the
;; corresponding array argument to each element.
@ -675,7 +679,7 @@
(get-view (array-getter curried)))
(let lp ((ls arrays) (i 0))
(cond
((null? ls) res)
((null? ls) (if mutable? res (array-freeze! res)))
(else
(array-assign! (get-view i) (car ls))
(lp (cdr ls) (+ i 1))))))))
@ -722,7 +726,7 @@
(vector-iota (array-dimension a) 0)))
(domain
(make-interval (vector-map vector-last tile-offsets)))
(res (make-specialized-array domain storage (storage-class-default storage) safe?)))
(res (make-specialized-array/default domain storage safe?)))
(interval-for-each
(lambda multi-index
(let* ((multi-index/0 (list->vector (map - multi-index index0)))
@ -747,7 +751,7 @@
(interval-lower-bounds->vector
(array-domain subarray)))))))
a-domain)
res))))
(if mutable? res (array-freeze! res))))))
(define array-block! array-block)
@ -763,12 +767,12 @@
(elt0 (apply array-ref a (interval-lower-bounds->list a-domain)))
(elt-domain (array-domain elt0))
(domain (interval-cartesian-product a-domain elt-domain))
(res (make-specialized-array domain storage (storage-class-default storage) safe?))
(res (make-specialized-array/default domain storage safe?))
(curried-res (array-curry res (interval-dimension elt-domain))))
;; Prepare a res with the flattened domain, create a new curried
;; view of the res with the same domain as a, and assign each
;; curried view from a to the res.
(array-for-each array-assign! curried-res a)
res))
(if mutable? res (array-freeze! res))))
(define array-decurry! array-decurry)