Implement and use array-freeze!

Fixes #960.
This commit is contained in:
Alex Shinn 2024-05-24 19:20:14 +09:00
parent f60298b707
commit 33a59952a8
4 changed files with 28 additions and 19 deletions

View file

@ -52,6 +52,7 @@
array-assign! array-ref array-set! array-decurry array-assign! array-ref array-set! array-decurry
specialized-array-reshape specialized-array-reshape
array-copy! array-stack! array-decurry! array-append! array-block! array-copy! array-stack! array-decurry! array-append! array-block!
array-freeze!
) )
(include "231/transforms.scm") (include "231/transforms.scm")
(cond-expand (cond-expand

View file

@ -324,6 +324,9 @@
(lambda (val . multi-index) (lambda (val . multi-index)
(setter body (apply indexer multi-index) val))) (setter body (apply indexer multi-index) val)))
(define (array-freeze! array)
(%array-setter-set! array #f)
array)
;; Indexing ;; Indexing

View file

@ -44,5 +44,6 @@
specialized-array-share array-ref array-set! specialized-array-share array-ref array-set!
%make-specialized %array-setter-set! %make-specialized %array-setter-set!
specialized-getter specialized-setter specialized-getter specialized-setter
array-freeze!
) )
(include "base.scm")) (include "base.scm"))

View file

@ -68,6 +68,14 @@
;; Array transformations ;; Array transformations
(define (make-specialized-array/default domain . o)
(let ((storage (if (pair? o) (car o) generic-storage-class)))
(apply make-specialized-array
domain
storage
(storage-class-default storage)
(if (pair? o) (cdr o) '()))))
(define (array-copy array . o) (define (array-copy array . o)
(assert (array? array)) (assert (array? array))
(let ((specialized? (specialized-array? array)) (let ((specialized? (specialized-array? array))
@ -97,9 +105,7 @@
(res (%make-specialized domain storage body coeffs indexer (res (%make-specialized domain storage body coeffs indexer
safe? #t #t))) safe? #t #t)))
(array-assign! res array) (array-assign! res array)
(unless mutable? (if mutable? res (array-freeze! res))))))
(%array-setter-set! res #f))
res))))
(define array-copy! array-copy) (define array-copy! array-copy)
@ -422,7 +428,7 @@
(safe? (if (and (pair? o) (pair? (cdr o)) (pair? (cddr o))) (safe? (if (and (pair? o) (pair? (cdr o)) (pair? (cddr o)))
(car (cddr o)) (car (cddr o))
(specialized-array-default-safe?))) (specialized-array-default-safe?)))
(res (make-specialized-array domain storage (storage-class-default storage) safe?))) (res (make-specialized-array/default domain storage safe?)))
(assert (and (interval? domain) (storage-class? storage) (assert (and (interval? domain) (storage-class? storage)
(boolean? mutable?) (boolean? safe?))) (boolean? mutable?) (boolean? safe?)))
(interval-fold (interval-fold
@ -431,7 +437,7 @@
(cdr ls)) (cdr ls))
ls ls
domain) domain)
res)) (if mutable? res (array-freeze! res))))
(define (array->vector array) (define (array->vector array)
(list->vector (array->list array))) (list->vector (array->list array)))
@ -502,10 +508,9 @@
(cond (cond
((reshape-without-copy array new-domain)) ((reshape-without-copy array new-domain))
(copy-on-failure? (copy-on-failure?
(let ((res (make-specialized-array (let ((res (make-specialized-array/default
new-domain new-domain
(array-storage-class array) (array-storage-class array)
(storage-class-default (array-storage-class array))
(array-safe? array)))) (array-safe? array))))
(array-assign! res array) (array-assign! res array)
res)) res))
@ -612,7 +617,7 @@
(vector-ref c-hi axis) (vector-ref c-hi axis)
(cdr arrays))) (cdr arrays)))
(let* ((c-domain (make-interval c-lo c-hi)) (let* ((c-domain (make-interval c-lo c-hi))
(c (make-specialized-array c-domain storage (storage-class-default storage) safe?)) (c (make-specialized-array/default c-domain storage safe?))
(b-trans (make-vector (array-dimension a) 0))) (b-trans (make-vector (array-dimension a) 0)))
(array-assign! (array-assign!
(array-extract c (make-interval c-lo (interval-widths a-domain))) (array-extract c (make-interval c-lo (interval-widths a-domain)))
@ -621,7 +626,7 @@
(b-offset (- (interval-upper-bound a-domain axis) (b-offset (- (interval-upper-bound a-domain axis)
(interval-lower-bound a-domain axis)))) (interval-lower-bound a-domain axis))))
(if (null? arrays) (if (null? arrays)
c (if mutable? c (array-freeze! c))
(let* ((b (car arrays)) (let* ((b (car arrays))
(b-domain (array-domain b)) (b-domain (array-domain b))
(b-offset2 (+ b-offset (interval-width b-domain axis))) (b-offset2 (+ b-offset (interval-width b-domain axis)))
@ -661,10 +666,9 @@
(make-interval (make-interval
`#(,@(take a-lbs axis) 0 ,@(drop a-lbs axis)) `#(,@(take a-lbs axis) 0 ,@(drop a-lbs axis))
`#(,@(take a-ubs axis) ,(length arrays) ,@(drop a-ubs axis)))) `#(,@(take a-ubs axis) ,(length arrays) ,@(drop a-ubs axis))))
(res (make-specialized-array domain (res (make-specialized-array/default domain
(or (array-storage-class a) (or (array-storage-class a)
generic-storage-class) generic-storage-class)
(storage-class-default storage)
safe?)) safe?))
;; Stack by permuting the desired axis to the first ;; Stack by permuting the desired axis to the first
;; dimension and currying on that, assigning the ;; dimension and currying on that, assigning the
@ -675,7 +679,7 @@
(get-view (array-getter curried))) (get-view (array-getter curried)))
(let lp ((ls arrays) (i 0)) (let lp ((ls arrays) (i 0))
(cond (cond
((null? ls) res) ((null? ls) (if mutable? res (array-freeze! res)))
(else (else
(array-assign! (get-view i) (car ls)) (array-assign! (get-view i) (car ls))
(lp (cdr ls) (+ i 1)))))))) (lp (cdr ls) (+ i 1))))))))
@ -722,7 +726,7 @@
(vector-iota (array-dimension a) 0))) (vector-iota (array-dimension a) 0)))
(domain (domain
(make-interval (vector-map vector-last tile-offsets))) (make-interval (vector-map vector-last tile-offsets)))
(res (make-specialized-array domain storage (storage-class-default storage) safe?))) (res (make-specialized-array/default domain storage safe?)))
(interval-for-each (interval-for-each
(lambda multi-index (lambda multi-index
(let* ((multi-index/0 (list->vector (map - multi-index index0))) (let* ((multi-index/0 (list->vector (map - multi-index index0)))
@ -747,7 +751,7 @@
(interval-lower-bounds->vector (interval-lower-bounds->vector
(array-domain subarray))))))) (array-domain subarray)))))))
a-domain) a-domain)
res)))) (if mutable? res (array-freeze! res))))))
(define array-block! array-block) (define array-block! array-block)
@ -763,12 +767,12 @@
(elt0 (apply array-ref a (interval-lower-bounds->list a-domain))) (elt0 (apply array-ref a (interval-lower-bounds->list a-domain)))
(elt-domain (array-domain elt0)) (elt-domain (array-domain elt0))
(domain (interval-cartesian-product a-domain elt-domain)) (domain (interval-cartesian-product a-domain elt-domain))
(res (make-specialized-array domain storage (storage-class-default storage) safe?)) (res (make-specialized-array/default domain storage safe?))
(curried-res (array-curry res (interval-dimension elt-domain)))) (curried-res (array-curry res (interval-dimension elt-domain))))
;; Prepare a res with the flattened domain, create a new curried ;; Prepare a res with the flattened domain, create a new curried
;; view of the res with the same domain as a, and assign each ;; view of the res with the same domain as a, and assign each
;; curried view from a to the res. ;; curried view from a to the res.
(array-for-each array-assign! curried-res a) (array-for-each array-assign! curried-res a)
res)) (if mutable? res (array-freeze! res))))
(define array-decurry! array-decurry) (define array-decurry! array-decurry)