flattening array indexers

This commit is contained in:
Alex Shinn 2021-04-28 22:53:16 +09:00
parent bf03c1cfa1
commit 76284f79f0
3 changed files with 145 additions and 60 deletions

View file

@ -261,26 +261,27 @@
;; Arrays ;; Arrays
(define-record-type Array (define-record-type Array
(%%make-array domain getter setter storage body indexer safe?) (%%make-array domain getter setter storage body coeffs indexer safe?)
array? array?
(domain array-domain) (domain array-domain)
(getter array-getter) (getter array-getter)
(setter array-setter) (setter array-setter)
(storage array-storage-class) (storage array-storage-class)
(body array-body) (body array-body)
(coeffs array-coeffs)
(indexer array-indexer) (indexer array-indexer)
(safe? array-safe?)) (safe? array-safe?))
(define (%make-array domain getter setter storage body indexer safe?) (define (%make-array domain getter setter storage body coeffs indexer safe?)
(assert (interval? domain) (assert (interval? domain)
(procedure? getter) (procedure? getter)
(or (not setter) (procedure? setter)) (or (not setter) (procedure? setter))
(or (not storage) (storage-class? storage))) (or (not storage) (storage-class? storage)))
(%%make-array domain getter setter storage body indexer safe?)) (%%make-array domain getter setter storage body coeffs indexer safe?))
(define (make-array domain getter . o) (define (make-array domain getter . o)
(assert (interval? domain) (procedure? getter)) (assert (interval? domain) (procedure? getter))
(%make-array domain getter (and (pair? o) (car o)) #f #f #f #f)) (%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f))
(define (array-dimension a) (define (array-dimension a)
(interval-dimension (array-domain a))) (interval-dimension (array-domain a)))
@ -308,25 +309,92 @@
(lambda (val . multi-index) (lambda (val . multi-index)
(setter body (apply indexer multi-index) val))) (setter body (apply indexer multi-index) val)))
;; Indexing
(define (indexer->coeffs indexer domain . o)
(let* ((verify? (and (pair? o) (car o)))
(res (make-vector (+ 1 (interval-dimension domain)) 0))
(multi-index (interval-lower-bounds->list domain))
(base (apply indexer multi-index)))
(vector-set! res 0 base)
(let lp ((i 1)
(ls multi-index)
(offset base)
(count 0))
(cond
((null? ls)
(if (and verify? (zero? count))
(lp 1 multi-index offset (+ count 1))
res))
((= (+ 1 (car ls)) (interval-upper-bound domain (- i 1)))
(lp (+ i 1) (cdr ls) offset count))
(else
(set-car! ls (+ 1 (car ls)))
(let* ((offset2 (apply indexer multi-index))
(coeff (- offset2 offset)))
(cond
((> count 0)
(and (= coeff (vector-ref res i))
(lp (+ i 1) (cdr ls) offset2 count)))
(else
(vector-set! res i coeff)
(lp (+ i 1) (cdr ls) offset2 count)))))))))
(define (coeffs->indexer coeffs domain)
(case (vector-length coeffs)
((2)
(let ((a (vector-ref coeffs 0))
(b (vector-ref coeffs 1))
(lo-x (interval-lower-bound domain 0)))
(lambda (x) (+ a (* b (- x lo-x))))))
((3)
(let ((a (vector-ref coeffs 0))
(b (vector-ref coeffs 1))
(c (vector-ref coeffs 2))
(lo-x (interval-lower-bound domain 0))
(lo-y (interval-lower-bound domain 1)))
(lambda (x y) (+ a (* b (- x lo-x)) (* c (- y lo-y))))))
(else
(lambda multi-index
(let ((lim (vector-length coeffs)))
(let lp ((ls multi-index)
(i 1)
(res (vector-ref coeffs 0)))
(cond
((null? ls)
(if (< i lim)
(error "multi-index too short for domain" multi-index domain)
res))
((>= i lim)
(error "multi-index too long for domain" multi-index domain))
(else
(lp (cdr ls)
(+ i 1)
(+ res (* (- (car ls) (interval-lower-bound domain (- i 1)))
(vector-ref coeffs i))))))))))))
(define (default-coeffs domain)
(let* ((dim (interval-dimension domain))
(res (make-vector (+ 1 dim))))
(vector-set! res 0 0)
(vector-set! res dim 1)
(let lp ((i (- dim 1))
(scale 1))
(cond
((< i 0)
res)
((= (+ 1 (interval-lower-bound domain i))
(interval-upper-bound domain i))
(vector-set! res (+ i 1) 0)
(lp (- i 1) scale))
(else
(let ((coeff (* scale (- (interval-upper-bound domain i)
(interval-lower-bound domain i)))))
(vector-set! res (+ i 1) scale)
(lp (- i 1) coeff)))))))
(define (default-indexer domain) (define (default-indexer domain)
(lambda multi-index (coeffs->indexer (default-coeffs domain) domain))
(let ((dim (interval-dimension domain)))
(let lp ((ls multi-index)
(i 0)
(res 0))
(cond
((null? ls)
(if (< i dim)
(error "multi-index too short for domain" multi-index domain)
res))
((>= i dim)
(error "multi-index too long for domain" multi-index domain))
(else
(lp (cdr ls)
(+ i 1)
(+ (- (car ls) (interval-lower-bound domain i))
(* res (- (interval-upper-bound domain i)
(interval-lower-bound domain i)))))))))))
;; converts the raw integer index to the multi-index in domain that ;; converts the raw integer index to the multi-index in domain that
;; would map to it using the default indexer. ;; would map to it using the default indexer.
@ -345,6 +413,8 @@
(* scale width) (* scale width)
(cons (+ elt (interval-lower-bound domain i)) res)))))) (cons (+ elt (interval-lower-bound domain i)) res))))))
;; Specialized arrays
(define (make-specialized-array domain . o) (define (make-specialized-array domain . o)
(let* ((storage (if (pair? o) (car o) generic-storage-class)) (let* ((storage (if (pair? o) (car o) generic-storage-class))
(safe? (if (and (pair? o) (pair? (cdr o))) (safe? (if (and (pair? o) (pair? (cdr o)))
@ -353,7 +423,8 @@
(body ((storage-class-maker storage) (body ((storage-class-maker storage)
(interval-volume domain) (interval-volume domain)
(storage-class-default storage))) (storage-class-default storage)))
(indexer (default-indexer domain))) (coeffs (default-coeffs domain))
(indexer (coeffs->indexer coeffs domain)))
(assert (boolean? safe?)) (assert (boolean? safe?))
(%make-array (%make-array
domain domain
@ -361,6 +432,7 @@
(specialized-setter body indexer (storage-class-setter storage)) (specialized-setter body indexer (storage-class-setter storage))
storage storage
body body
coeffs
indexer indexer
safe?))) safe?)))
@ -385,21 +457,29 @@
(define (specialized-array-share array new-domain project) (define (specialized-array-share array new-domain project)
(assert (specialized-array? array) (interval? new-domain)) (assert (specialized-array? array) (interval? new-domain))
(let ((body (array-body array)) (let* ((body (array-body array))
(indexer (lambda multi-index (coeffs
(call-with-values (indexer->coeffs
(lambda () (apply project multi-index)) (lambda multi-index
(array-indexer array)))) (call-with-values
(storage (array-storage-class array))) (lambda () (apply project multi-index))
(array-indexer array)))
new-domain))
(indexer
(coeffs->indexer coeffs new-domain))
(storage (array-storage-class array)))
(%make-array (%make-array
new-domain new-domain
(specialized-getter body indexer (storage-class-getter storage)) (specialized-getter body indexer (storage-class-getter storage))
(specialized-setter body indexer (storage-class-setter storage)) (specialized-setter body indexer (storage-class-setter storage))
storage storage
body body
coeffs
indexer indexer
(array-safe? array)))) (array-safe? array))))
;; Array transformations
(define (array-copy array . o) (define (array-copy array . o)
(assert (array? array)) (assert (array? array))
(let* ((storage (if (pair? o) (car o) generic-storage-class)) (let* ((storage (if (pair? o) (car o) generic-storage-class))
@ -414,13 +494,14 @@
(let* ((body ((storage-class-maker storage) (let* ((body ((storage-class-maker storage)
(interval-volume new-domain) (interval-volume new-domain)
(storage-class-default storage))) (storage-class-default storage)))
(indexer (default-indexer new-domain)) (coeffs (default-coeffs new-domain))
(getter (specialized-getter body indexer (indexer (coeffs->indexer coeffs new-domain))
(storage-class-getter storage))) (getter (specialized-getter body indexer
(setter (specialized-setter body indexer (storage-class-getter storage)))
(storage-class-setter storage))) (setter (specialized-setter body indexer
(res (%make-array new-domain getter setter (storage-class-setter storage)))
storage body indexer safe?))) (res (%make-array new-domain getter setter
storage body coeffs indexer safe?)))
(array-assign! res array)))) (array-assign! res array))))
(define (array-curry array inner-dimension) (define (array-curry array inner-dimension)
@ -736,10 +817,13 @@
(define (reshape-indexer array new-domain) (define (reshape-indexer array new-domain)
(let ((orig-indexer (array-indexer array)) (let ((orig-indexer (array-indexer array))
(tmp-indexer (default-indexer new-domain))) (tmp-indexer (default-indexer new-domain)))
(lambda multi-index (indexer->coeffs
(apply orig-indexer (lambda multi-index
(invert-default-index (array-domain array) (apply orig-indexer
(apply tmp-indexer multi-index)))))) (invert-default-index (array-domain array)
(apply tmp-indexer multi-index))))
new-domain
#t)))
(define (specialized-array-reshape array new-domain . o) (define (specialized-array-reshape array new-domain . o)
(assert (specialized-array? array) (assert (specialized-array? array)
@ -748,9 +832,10 @@
(let ((copy-on-failure? (and (pair? o) (car o)))) (let ((copy-on-failure? (and (pair? o) (car o))))
(cond (cond
((reshape-indexer array new-domain) ((reshape-indexer array new-domain)
=> (lambda (new-indexer) => (lambda (new-coeffs)
(let ((body (array-body array)) (let* ((new-indexer (coeffs->indexer new-coeffs new-domain))
(storage (array-storage-class array))) (body (array-body array))
(storage (array-storage-class array)))
(%make-array (%make-array
new-domain new-domain
(specialized-getter body (specialized-getter body
@ -761,6 +846,7 @@
(storage-class-setter storage)) (storage-class-setter storage))
storage storage
body body
new-coeffs
new-indexer new-indexer
(array-safe? array))))) (array-safe? array)))))
(copy-on-failure? (copy-on-failure?

View file

@ -3,7 +3,6 @@
(scheme list) (scheme list)
(scheme vector) (scheme vector)
(scheme sort) (scheme sort)
(scheme write) ;
(srfi 160 base) (srfi 160 base)
(chibi assert)) (chibi assert))
(export (export

View file

@ -3434,23 +3434,23 @@ OTHER DEALINGS IN THE SOFTWARE.
;; '#(#f #f #t #t)) ;; '#(#f #f #t #t))
;; (make-interval '#(3 2)))) ;; (make-interval '#(3 2))))
;; (test-error (test-error
;; (specialized-array-reshape (specialized-array-reshape
;; (array-sample (array-sample
;; (array-reverse (array-reverse
;; (array-copy (make-array (make-interval '#(2 1 3 1)) list)) (array-copy (make-array (make-interval '#(2 1 3 1)) list))
;; '#(#f #f #f #t)) '#(#f #f #f #t))
;; '#(1 1 2 1)) '#(1 1 2 1))
;; (make-interval '#(4)))) (make-interval '#(4))))
;; (test-error (test-error
;; (specialized-array-reshape (specialized-array-reshape
;; (array-sample (array-sample
;; (array-reverse (array-reverse
;; (array-copy (make-array (make-interval '#(2 1 4 1)) list)) (array-copy (make-array (make-interval '#(2 1 4 1)) list))
;; '#(#f #f #t #t)) '#(#f #f #t #t))
;; '#(1 1 2 1)) '#(1 1 2 1))
;; (make-interval '#(4)))) (make-interval '#(4))))
) )
(test-group "curry tests" (test-group "curry tests"