flattening array indexers

This commit is contained in:
Alex Shinn 2021-04-28 22:53:16 +09:00
parent bf03c1cfa1
commit 76284f79f0
3 changed files with 145 additions and 60 deletions

View file

@ -261,26 +261,27 @@
;; Arrays ;; Arrays
(define-record-type Array (define-record-type Array
(%%make-array domain getter setter storage body indexer safe?) (%%make-array domain getter setter storage body coeffs indexer safe?)
array? array?
(domain array-domain) (domain array-domain)
(getter array-getter) (getter array-getter)
(setter array-setter) (setter array-setter)
(storage array-storage-class) (storage array-storage-class)
(body array-body) (body array-body)
(coeffs array-coeffs)
(indexer array-indexer) (indexer array-indexer)
(safe? array-safe?)) (safe? array-safe?))
(define (%make-array domain getter setter storage body indexer safe?) (define (%make-array domain getter setter storage body coeffs indexer safe?)
(assert (interval? domain) (assert (interval? domain)
(procedure? getter) (procedure? getter)
(or (not setter) (procedure? setter)) (or (not setter) (procedure? setter))
(or (not storage) (storage-class? storage))) (or (not storage) (storage-class? storage)))
(%%make-array domain getter setter storage body indexer safe?)) (%%make-array domain getter setter storage body coeffs indexer safe?))
(define (make-array domain getter . o) (define (make-array domain getter . o)
(assert (interval? domain) (procedure? getter)) (assert (interval? domain) (procedure? getter))
(%make-array domain getter (and (pair? o) (car o)) #f #f #f #f)) (%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f))
(define (array-dimension a) (define (array-dimension a)
(interval-dimension (array-domain a))) (interval-dimension (array-domain a)))
@ -308,25 +309,92 @@
(lambda (val . multi-index) (lambda (val . multi-index)
(setter body (apply indexer multi-index) val))) (setter body (apply indexer multi-index) val)))
(define (default-indexer domain) ;; Indexing
(lambda multi-index
(let ((dim (interval-dimension domain))) (define (indexer->coeffs indexer domain . o)
(let lp ((ls multi-index) (let* ((verify? (and (pair? o) (car o)))
(i 0) (res (make-vector (+ 1 (interval-dimension domain)) 0))
(res 0)) (multi-index (interval-lower-bounds->list domain))
(base (apply indexer multi-index)))
(vector-set! res 0 base)
(let lp ((i 1)
(ls multi-index)
(offset base)
(count 0))
(cond (cond
((null? ls) ((null? ls)
(if (< i dim) (if (and verify? (zero? count))
(lp 1 multi-index offset (+ count 1))
res))
((= (+ 1 (car ls)) (interval-upper-bound domain (- i 1)))
(lp (+ i 1) (cdr ls) offset count))
(else
(set-car! ls (+ 1 (car ls)))
(let* ((offset2 (apply indexer multi-index))
(coeff (- offset2 offset)))
(cond
((> count 0)
(and (= coeff (vector-ref res i))
(lp (+ i 1) (cdr ls) offset2 count)))
(else
(vector-set! res i coeff)
(lp (+ i 1) (cdr ls) offset2 count)))))))))
(define (coeffs->indexer coeffs domain)
(case (vector-length coeffs)
((2)
(let ((a (vector-ref coeffs 0))
(b (vector-ref coeffs 1))
(lo-x (interval-lower-bound domain 0)))
(lambda (x) (+ a (* b (- x lo-x))))))
((3)
(let ((a (vector-ref coeffs 0))
(b (vector-ref coeffs 1))
(c (vector-ref coeffs 2))
(lo-x (interval-lower-bound domain 0))
(lo-y (interval-lower-bound domain 1)))
(lambda (x y) (+ a (* b (- x lo-x)) (* c (- y lo-y))))))
(else
(lambda multi-index
(let ((lim (vector-length coeffs)))
(let lp ((ls multi-index)
(i 1)
(res (vector-ref coeffs 0)))
(cond
((null? ls)
(if (< i lim)
(error "multi-index too short for domain" multi-index domain) (error "multi-index too short for domain" multi-index domain)
res)) res))
((>= i dim) ((>= i lim)
(error "multi-index too long for domain" multi-index domain)) (error "multi-index too long for domain" multi-index domain))
(else (else
(lp (cdr ls) (lp (cdr ls)
(+ i 1) (+ i 1)
(+ (- (car ls) (interval-lower-bound domain i)) (+ res (* (- (car ls) (interval-lower-bound domain (- i 1)))
(* res (- (interval-upper-bound domain i) (vector-ref coeffs i))))))))))))
(interval-lower-bound domain i)))))))))))
(define (default-coeffs domain)
(let* ((dim (interval-dimension domain))
(res (make-vector (+ 1 dim))))
(vector-set! res 0 0)
(vector-set! res dim 1)
(let lp ((i (- dim 1))
(scale 1))
(cond
((< i 0)
res)
((= (+ 1 (interval-lower-bound domain i))
(interval-upper-bound domain i))
(vector-set! res (+ i 1) 0)
(lp (- i 1) scale))
(else
(let ((coeff (* scale (- (interval-upper-bound domain i)
(interval-lower-bound domain i)))))
(vector-set! res (+ i 1) scale)
(lp (- i 1) coeff)))))))
(define (default-indexer domain)
(coeffs->indexer (default-coeffs domain) domain))
;; converts the raw integer index to the multi-index in domain that ;; converts the raw integer index to the multi-index in domain that
;; would map to it using the default indexer. ;; would map to it using the default indexer.
@ -345,6 +413,8 @@
(* scale width) (* scale width)
(cons (+ elt (interval-lower-bound domain i)) res)))))) (cons (+ elt (interval-lower-bound domain i)) res))))))
;; Specialized arrays
(define (make-specialized-array domain . o) (define (make-specialized-array domain . o)
(let* ((storage (if (pair? o) (car o) generic-storage-class)) (let* ((storage (if (pair? o) (car o) generic-storage-class))
(safe? (if (and (pair? o) (pair? (cdr o))) (safe? (if (and (pair? o) (pair? (cdr o)))
@ -353,7 +423,8 @@
(body ((storage-class-maker storage) (body ((storage-class-maker storage)
(interval-volume domain) (interval-volume domain)
(storage-class-default storage))) (storage-class-default storage)))
(indexer (default-indexer domain))) (coeffs (default-coeffs domain))
(indexer (coeffs->indexer coeffs domain)))
(assert (boolean? safe?)) (assert (boolean? safe?))
(%make-array (%make-array
domain domain
@ -361,6 +432,7 @@
(specialized-setter body indexer (storage-class-setter storage)) (specialized-setter body indexer (storage-class-setter storage))
storage storage
body body
coeffs
indexer indexer
safe?))) safe?)))
@ -385,11 +457,16 @@
(define (specialized-array-share array new-domain project) (define (specialized-array-share array new-domain project)
(assert (specialized-array? array) (interval? new-domain)) (assert (specialized-array? array) (interval? new-domain))
(let ((body (array-body array)) (let* ((body (array-body array))
(indexer (lambda multi-index (coeffs
(indexer->coeffs
(lambda multi-index
(call-with-values (call-with-values
(lambda () (apply project multi-index)) (lambda () (apply project multi-index))
(array-indexer array)))) (array-indexer array)))
new-domain))
(indexer
(coeffs->indexer coeffs new-domain))
(storage (array-storage-class array))) (storage (array-storage-class array)))
(%make-array (%make-array
new-domain new-domain
@ -397,9 +474,12 @@
(specialized-setter body indexer (storage-class-setter storage)) (specialized-setter body indexer (storage-class-setter storage))
storage storage
body body
coeffs
indexer indexer
(array-safe? array)))) (array-safe? array))))
;; Array transformations
(define (array-copy array . o) (define (array-copy array . o)
(assert (array? array)) (assert (array? array))
(let* ((storage (if (pair? o) (car o) generic-storage-class)) (let* ((storage (if (pair? o) (car o) generic-storage-class))
@ -414,13 +494,14 @@
(let* ((body ((storage-class-maker storage) (let* ((body ((storage-class-maker storage)
(interval-volume new-domain) (interval-volume new-domain)
(storage-class-default storage))) (storage-class-default storage)))
(indexer (default-indexer new-domain)) (coeffs (default-coeffs new-domain))
(indexer (coeffs->indexer coeffs new-domain))
(getter (specialized-getter body indexer (getter (specialized-getter body indexer
(storage-class-getter storage))) (storage-class-getter storage)))
(setter (specialized-setter body indexer (setter (specialized-setter body indexer
(storage-class-setter storage))) (storage-class-setter storage)))
(res (%make-array new-domain getter setter (res (%make-array new-domain getter setter
storage body indexer safe?))) storage body coeffs indexer safe?)))
(array-assign! res array)))) (array-assign! res array))))
(define (array-curry array inner-dimension) (define (array-curry array inner-dimension)
@ -736,10 +817,13 @@
(define (reshape-indexer array new-domain) (define (reshape-indexer array new-domain)
(let ((orig-indexer (array-indexer array)) (let ((orig-indexer (array-indexer array))
(tmp-indexer (default-indexer new-domain))) (tmp-indexer (default-indexer new-domain)))
(indexer->coeffs
(lambda multi-index (lambda multi-index
(apply orig-indexer (apply orig-indexer
(invert-default-index (array-domain array) (invert-default-index (array-domain array)
(apply tmp-indexer multi-index)))))) (apply tmp-indexer multi-index))))
new-domain
#t)))
(define (specialized-array-reshape array new-domain . o) (define (specialized-array-reshape array new-domain . o)
(assert (specialized-array? array) (assert (specialized-array? array)
@ -748,8 +832,9 @@
(let ((copy-on-failure? (and (pair? o) (car o)))) (let ((copy-on-failure? (and (pair? o) (car o))))
(cond (cond
((reshape-indexer array new-domain) ((reshape-indexer array new-domain)
=> (lambda (new-indexer) => (lambda (new-coeffs)
(let ((body (array-body array)) (let* ((new-indexer (coeffs->indexer new-coeffs new-domain))
(body (array-body array))
(storage (array-storage-class array))) (storage (array-storage-class array)))
(%make-array (%make-array
new-domain new-domain
@ -761,6 +846,7 @@
(storage-class-setter storage)) (storage-class-setter storage))
storage storage
body body
new-coeffs
new-indexer new-indexer
(array-safe? array))))) (array-safe? array)))))
(copy-on-failure? (copy-on-failure?

View file

@ -3,7 +3,6 @@
(scheme list) (scheme list)
(scheme vector) (scheme vector)
(scheme sort) (scheme sort)
(scheme write) ;
(srfi 160 base) (srfi 160 base)
(chibi assert)) (chibi assert))
(export (export

View file

@ -3434,23 +3434,23 @@ OTHER DEALINGS IN THE SOFTWARE.
;; '#(#f #f #t #t)) ;; '#(#f #f #t #t))
;; (make-interval '#(3 2)))) ;; (make-interval '#(3 2))))
;; (test-error (test-error
;; (specialized-array-reshape (specialized-array-reshape
;; (array-sample (array-sample
;; (array-reverse (array-reverse
;; (array-copy (make-array (make-interval '#(2 1 3 1)) list)) (array-copy (make-array (make-interval '#(2 1 3 1)) list))
;; '#(#f #f #f #t)) '#(#f #f #f #t))
;; '#(1 1 2 1)) '#(1 1 2 1))
;; (make-interval '#(4)))) (make-interval '#(4))))
;; (test-error (test-error
;; (specialized-array-reshape (specialized-array-reshape
;; (array-sample (array-sample
;; (array-reverse (array-reverse
;; (array-copy (make-array (make-interval '#(2 1 4 1)) list)) (array-copy (make-array (make-interval '#(2 1 4 1)) list))
;; '#(#f #f #t #t)) '#(#f #f #t #t))
;; '#(1 1 2 1)) '#(1 1 2 1))
;; (make-interval '#(4)))) (make-interval '#(4))))
) )
(test-group "curry tests" (test-group "curry tests"