Fix validation on specialized-array-reshape.

This commit is contained in:
Alex Shinn 2021-04-30 13:38:53 +09:00
parent 76284f79f0
commit 3c138dc808
2 changed files with 126 additions and 98 deletions

View file

@ -265,7 +265,7 @@
array?
(domain array-domain)
(getter array-getter)
(setter array-setter)
(setter array-setter %array-setter-set!)
(storage array-storage-class)
(body array-body)
(coeffs array-coeffs)
@ -326,19 +326,25 @@
(if (and verify? (zero? count))
(lp 1 multi-index offset (+ count 1))
res))
((= (+ 1 (car ls)) (interval-upper-bound domain (- i 1)))
((= (+ 1 (interval-lower-bound domain (- i 1)))
(interval-upper-bound domain (- i 1)))
(lp (+ i 1) (cdr ls) offset count))
(else
(set-car! ls (+ 1 (car ls)))
(let* ((offset2 (apply indexer multi-index))
(coeff (- offset2 offset)))
(cond
((> count 0)
(and (= coeff (vector-ref res i))
(lp (+ i 1) (cdr ls) offset2 count)))
(else
(vector-set! res i coeff)
(lp (+ i 1) (cdr ls) offset2 count)))))))))
(let ((dir (if (and (> count 0)
(= (+ (car ls) 1)
(interval-upper-bound domain (- i 1))))
-1
1)))
(set-car! ls (+ (car ls) dir))
(let* ((offset2 (apply indexer multi-index))
(coeff (* dir (- offset2 offset))))
(cond
((> count 0)
(and (= coeff (vector-ref res i))
(lp (+ i 1) (cdr ls) offset2 count)))
(else
(vector-set! res i coeff)
(lp (+ i 1) (cdr ls) offset2 count))))))))))
(define (coeffs->indexer coeffs domain)
(case (vector-length coeffs)
@ -396,25 +402,47 @@
(define (default-indexer domain)
(coeffs->indexer (default-coeffs domain) domain))
;; converts the raw integer index to the multi-index in domain that
;; would map to it using the default indexer.
;; Converts the raw integer index to the multi-index in domain that
;; would map to it using the default indexer (i.e. iterating over the
;; possible multi-indices in domain in lexicographic order would
;; produce 0 through volume-1).
(define (invert-default-index domain raw-index)
(let lp ((index raw-index)
(i (- (interval-dimension domain) 1))
(scale 1)
(i 0)
(scale (/ (interval-volume domain)
(max 1
(- (interval-upper-bound domain 0)
(interval-lower-bound domain 0)))))
(res '()))
(if (negative? i)
res
(let* ((width (- (interval-upper-bound domain i)
(interval-lower-bound domain i)))
(elt (modulo index width)))
(lp (quotient (- index elt) scale)
(- i 1)
(* scale width)
(cons (+ elt (interval-lower-bound domain i)) res))))))
(cond
((>= (+ i 1) (interval-dimension domain))
(reverse (cons (+ index (interval-lower-bound domain i)) res)))
(else
(let ((digit (quotient index scale)))
(lp (- index (* digit scale))
(+ i 1)
(/ scale
(max 1
(- (interval-upper-bound domain (+ i 1))
(interval-lower-bound domain (+ i 1)))))
(cons (+ digit
(interval-lower-bound domain i))
res)))))))
;; Specialized arrays
(define (%make-specialized domain storage body coeffs indexer safe? mutable?)
(%make-array
domain
(specialized-getter body indexer (storage-class-getter storage))
(and mutable?
(specialized-setter body indexer (storage-class-setter storage)))
storage
body
coeffs
indexer
safe?))
(define (make-specialized-array domain . o)
(let* ((storage (if (pair? o) (car o) generic-storage-class))
(safe? (if (and (pair? o) (pair? (cdr o)))
@ -426,15 +454,7 @@
(coeffs (default-coeffs domain))
(indexer (coeffs->indexer coeffs domain)))
(assert (boolean? safe?))
(%make-array
domain
(specialized-getter body indexer (storage-class-getter storage))
(specialized-setter body indexer (storage-class-setter storage))
storage
body
coeffs
indexer
safe?)))
(%make-specialized domain storage body coeffs indexer safe? #t)))
(define (specialized-array? x)
(and (array? x) (array-storage-class x) #t))
@ -468,15 +488,8 @@
(indexer
(coeffs->indexer coeffs new-domain))
(storage (array-storage-class array)))
(%make-array
new-domain
(specialized-getter body indexer (storage-class-getter storage))
(specialized-setter body indexer (storage-class-setter storage))
storage
body
coeffs
indexer
(array-safe? array))))
(%make-specialized new-domain storage body coeffs indexer
(array-safe? array) (array-setter array))))
;; Array transformations
@ -500,9 +513,12 @@
(storage-class-getter storage)))
(setter (specialized-setter body indexer
(storage-class-setter storage)))
(res (%make-array new-domain getter setter
storage body coeffs indexer safe?)))
(array-assign! res array))))
(res (%make-specialized new-domain storage body coeffs indexer
safe? #t)))
(array-assign! res array)
(unless mutable?
(%array-setter-set! res #f))
res)))
(define (array-curry array inner-dimension)
(call-with-values
@ -814,16 +830,45 @@
(array-domain source))))
destination))
(define (reshape-indexer array new-domain)
(let ((orig-indexer (array-indexer array))
(tmp-indexer (default-indexer new-domain)))
(indexer->coeffs
(lambda multi-index
(apply orig-indexer
(invert-default-index (array-domain array)
(apply tmp-indexer multi-index))))
new-domain
#t)))
(define (reshape-without-copy array new-domain)
(let* ((domain (array-domain array))
(orig-indexer (array-indexer array))
(tmp-indexer (default-indexer new-domain))
(new-indexer
(lambda multi-index
(apply orig-indexer
(invert-default-index domain
(apply tmp-indexer multi-index)))))
(new-coeffs
(indexer->coeffs new-indexer new-domain #t))
(flat-indexer
(coeffs->indexer new-coeffs new-domain))
(new-indexer (coeffs->indexer new-coeffs new-domain))
(body (array-body array))
(storage (array-storage-class array))
(res
(%make-specialized new-domain storage body new-coeffs flat-indexer
(array-safe? array) (array-setter array))))
(let ((multi-index (interval-lower-bounds->list domain))
(orig-default-indexer (default-indexer domain)))
(let lp ((i 0)
(ls multi-index))
(let ((reshaped-index
(invert-default-index
new-domain
(apply orig-default-indexer multi-index))))
(cond
((not (equal? (apply flat-indexer reshaped-index)
(apply orig-indexer multi-index)))
#f)
((null? ls)
res)
((= (+ 1 (interval-lower-bound domain i))
(interval-upper-bound domain i))
(lp (+ i 1) (cdr ls)))
(else
(set-car! ls (+ 1 (car ls)))
(lp (+ i 1) (cdr ls)))))))))
(define (specialized-array-reshape array new-domain . o)
(assert (specialized-array? array)
@ -831,24 +876,7 @@
(interval-volume new-domain)))
(let ((copy-on-failure? (and (pair? o) (car o))))
(cond
((reshape-indexer array new-domain)
=> (lambda (new-coeffs)
(let* ((new-indexer (coeffs->indexer new-coeffs new-domain))
(body (array-body array))
(storage (array-storage-class array)))
(%make-array
new-domain
(specialized-getter body
new-indexer
(storage-class-getter storage))
(specialized-setter body
new-indexer
(storage-class-setter storage))
storage
body
new-coeffs
new-indexer
(array-safe? array)))))
((reshape-without-copy array new-domain))
(copy-on-failure?
(let* ((res (make-specialized-array
new-domain

View file

@ -3406,33 +3406,33 @@ OTHER DEALINGS IN THE SOFTWARE.
(specialized-array-reshape array (make-interval '#(4))))
(array->list array)))
;; (test-error
;; (specialized-array-reshape
;; (array-reverse
;; (array-copy (make-array (make-interval '#(2 1 3 1)) list))
;; '#(#t #f #f #f))
;; (make-interval '#(6))))
(test-error
(specialized-array-reshape
(array-reverse
(array-copy (make-array (make-interval '#(2 1 3 1)) list))
'#(#t #f #f #f))
(make-interval '#(6))))
;; (test-error
;; (specialized-array-reshape
;; (array-reverse
;; (array-copy (make-array (make-interval '#(2 1 3 1)) list))
;; '#(#t #f #f #f))
;; (make-interval '#(3 2))))
(test-error
(specialized-array-reshape
(array-reverse
(array-copy (make-array (make-interval '#(2 1 3 1)) list))
'#(#t #f #f #f))
(make-interval '#(3 2))))
;; (test-error
;; (specialized-array-reshape
;; (array-reverse
;; (array-copy (make-array (make-interval '#(2 1 3 1)) list))
;; '#(#f #f #t #f))
;; (make-interval '#(6))))
(test-error
(specialized-array-reshape
(array-reverse
(array-copy (make-array (make-interval '#(2 1 3 1)) list))
'#(#f #f #t #f))
(make-interval '#(6))))
;; (test-error
;; (specialized-array-reshape
;; (array-reverse
;; (array-copy (make-array (make-interval '#(2 1 3 1)) list))
;; '#(#f #f #t #t))
;; (make-interval '#(3 2))))
(test-error
(specialized-array-reshape
(array-reverse
(array-copy (make-array (make-interval '#(2 1 3 1)) list))
'#(#f #f #t #t))
(make-interval '#(3 2))))
(test-error
(specialized-array-reshape