array simplification and performance tweaks

This commit is contained in:
Alex Shinn 2021-05-07 16:15:48 +09:00
parent 12ad1d37d8
commit ca47a41ccf
2 changed files with 99 additions and 104 deletions

View file

@ -1,7 +1,7 @@
;; Miscellaneous Functions
(define (translation? x)
(and (vector? x) (vector-every exact-integer? x)))
(and (vector? x) (not (vector-empty? x)) (vector-every exact-integer? x)))
(define (permutation? x)
(and (translation? x)
@ -15,11 +15,11 @@
(u1vector-set! seen (vector-ref x i) 1)
(lp (+ i 1)))))))))
(define (all-equal? ls)
(define (same-dimensions? ls)
(or (null? ls)
(null? (cdr ls))
(and (equal? (car ls) (cadr ls))
(all-equal? (cdr ls)))))
(and (equal? (array-dimension (car ls)) (array-dimension (cadr ls)))
(same-dimensions? (cdr ls)))))
;; Intervals
@ -32,8 +32,6 @@
(define (%make-interval lo hi)
(assert (and (translation? lo)
(translation? hi)
(not (vector-empty? lo))
(not (vector-empty? hi))
(= (vector-length lo) (vector-length hi))
(vector-every < lo hi)))
(%%make-interval lo hi))
@ -92,20 +90,18 @@
(set-car! (car rev-index) (car rev-lowers))
(rev-index-next! (cdr rev-index) (cdr rev-lowers) (cdr rev-uppers)))))
(define (interval-fold kons knil iv . o)
(define (interval-fold kons knil iv)
(case (interval-dimension iv)
((1)
(let ((end (interval-upper-bound iv 0)))
(do ((i (if (pair? o) (car o) (interval-lower-bound iv 0))
(+ i 1))
(do ((i (interval-lower-bound iv 0) (+ i 1))
(acc knil (kons acc i)))
((>= i end) acc))))
((2)
(let ((end0 (interval-upper-bound iv 0))
(start1 (if (pair? o) (cadr o) (interval-lower-bound iv 1)))
(start1 (interval-lower-bound iv 1))
(end1 (interval-upper-bound iv 1)))
(do ((i (if (pair? o) (car o) (interval-lower-bound iv 0))
(+ i 1))
(do ((i (interval-lower-bound iv 0) (+ i 1))
(acc knil
(do ((j start1 (+ j 1))
(acc acc (kons acc i j)))
@ -114,8 +110,7 @@
(else
(let* ((rev-lowers (reverse (interval-lower-bounds->list iv)))
(rev-uppers (reverse (interval-upper-bounds->list iv)))
(multi-index
(list-copy (if (pair? o) o (interval-lower-bounds->list iv))))
(multi-index (interval-lower-bounds->list iv))
(rev-index (pair-fold cons '() multi-index)))
(let lp ((acc knil))
(let ((acc (apply kons acc multi-index)))
@ -261,7 +256,7 @@
;; Arrays
(define-record-type Array
(%%make-array domain getter setter storage body coeffs indexer safe?)
(%%make-array domain getter setter storage body coeffs indexer safe? adjacent?)
array?
(domain array-domain)
(getter array-getter)
@ -270,18 +265,21 @@
(body array-body)
(coeffs array-coeffs)
(indexer array-indexer)
(safe? array-safe?))
(safe? array-safe?)
(adjacent? array-adjacent? array-adjacent?-set!))
(define (%make-array domain getter setter storage body coeffs indexer safe?)
(define (%make-array domain getter setter storage body coeffs
indexer safe? adjacent?)
(assert (and (interval? domain)
(procedure? getter)
(or (not setter) (procedure? setter))
(or (not storage) (storage-class? storage))))
(%%make-array domain getter setter storage body coeffs indexer safe?))
(%%make-array
domain getter setter storage body coeffs indexer safe? adjacent?))
(define (make-array domain getter . o)
(assert (and (interval? domain) (procedure? getter)))
(%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f))
(%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f #f))
(define (array-dimension a)
(interval-dimension (array-domain a)))
@ -438,7 +436,8 @@
;; Specialized arrays
(define (%make-specialized domain storage body coeffs indexer safe? mutable?)
(define (%make-specialized domain storage body coeffs indexer
safe? mutable? adjacent?)
(%make-array
domain
(specialized-getter body indexer (storage-class-getter storage))
@ -448,7 +447,8 @@
body
coeffs
indexer
safe?))
safe?
adjacent?))
(define (make-specialized-array domain . o)
(let* ((storage (if (pair? o) (car o) generic-storage-class))
@ -461,14 +461,12 @@
(coeffs (default-coeffs domain))
(indexer (coeffs->indexer coeffs domain)))
(assert (boolean? safe?))
(%make-specialized domain storage body coeffs indexer safe? #t)))
(%make-specialized domain storage body coeffs indexer safe? #t #t)))
(define (specialized-array? x)
(and (array? x) (array-storage-class x) #t))
(define (array-elements-in-order? array)
(assert (specialized-array? array))
;; TODO: speed this up and/or cache it
(define (compute-array-elements-in-order? array)
(let ((indexer (array-indexer array)))
(call-with-current-continuation
(lambda (return)
@ -482,6 +480,14 @@
(array-domain array))
#t))))
(define (array-elements-in-order? array)
(assert (specialized-array? array))
(let ((res (array-adjacent? array)))
(when (eq? res 'unknown)
(set! res (compute-array-elements-in-order? array))
(array-adjacent?-set! array res))
res))
(define (specialized-array-share array new-domain project)
(assert (and (specialized-array? array) (interval? new-domain)))
(let* ((body (array-body array))
@ -496,7 +502,7 @@
(coeffs->indexer coeffs new-domain))
(storage (array-storage-class array)))
(%make-specialized new-domain storage body coeffs indexer
(array-safe? array) (array-setter array))))
(array-safe? array) (array-setter array) 'unknown)))
;; Array transformations
@ -521,7 +527,7 @@
(setter (specialized-setter body indexer
(storage-class-setter storage)))
(res (%make-specialized new-domain storage body coeffs indexer
safe? #t)))
safe? #t #t)))
(array-assign! res array)
(unless mutable?
(%array-setter-set! res #f))
@ -561,9 +567,7 @@
(interval-subset? new-domain (array-domain array))))
(if (specialized-array? array)
(specialized-array-share array new-domain values)
(make-array new-domain
(array-getter array)
(array-setter array))))
(make-array new-domain (array-getter array) (array-setter array))))
(define (array-tile array sizes)
(assert (and (array? array)
@ -595,8 +599,8 @@
(lambda (i lo hi s)
(min hi (+ lo (* (+ i 1) s))))
multi-index
(interval-lower-bound (array-domain array))
(interval-upper-bound (array-domain array))
(interval-lb (array-domain array))
(interval-ub (array-domain array))
sizes)))))))
(define (array-translate array translation)
@ -626,13 +630,11 @@
(define (inverse-permutation permutation)
(list->vector
(map
car
(list-sort
(lambda (a b) (< (cdr a) (cdr b)))
(map cons
(iota (vector-length permutation))
(vector->list permutation))))))
(map car
(list-sort (lambda (a b) (< (cdr a) (cdr b)))
(map cons
(iota (vector-length permutation))
(vector->list permutation))))))
(define (array-permute array permutation)
(assert (permutation? permutation))
@ -729,7 +731,7 @@
(make-array (array-domain array)
(let* ((ls (cons array arrays))
(getters (map array-getter ls)))
(assert (all-equal? (map array-dimension ls)))
(assert (same-dimensions? ls))
(lambda multi-index
(apply f (map (lambda (g) (apply g multi-index)) getters))))))
@ -737,7 +739,7 @@
(interval-for-each
(let* ((ls (cons array arrays))
(getters (map array-getter ls)))
(assert (all-equal? (map array-dimension ls)))
(assert (same-dimensions? ls))
(lambda multi-index
(apply f (map (lambda (g) (apply g multi-index)) getters))))
(array-domain array)))
@ -752,41 +754,39 @@
(fold-right kons knil (array->list array)))
(define (array-reduce op array)
;; (let* ((domain (array-domain array))
;; (init-index (interval-lower-bounds->list domain))
;; (knil (apply array-ref array init-index)))
;; (if (rev-index-next! (pair-fold cons '() init-index)
;; (reverse (interval-lower-bounds->list domain))
;; (reverse (interval-upper-bounds->list domain)))
;; (apply interval-fold
;; (lambda (acc . multi-index)
;; (op acc (apply array-ref array multi-index)))
;; knil
;; domain
;; init-index)
;; knil))
(reduce (lambda (elt acc) (op acc elt)) 'never-used (array->list array)))
(let* ((domain (array-domain array))
(init-index (interval-lower-bounds->list domain))
(knil (list 'first-element)))
(interval-fold
(lambda (acc . multi-index)
(if (eq? acc knil)
(apply array-ref array multi-index)
(op acc (apply array-ref array multi-index))))
knil
domain)))
(define (array-any pred array . arrays)
(assert (all-equal? (map array-dimension (cons array arrays))))
(assert (same-dimensions? (cons array arrays)))
(call-with-current-continuation
(lambda (return)
(apply array-for-each
(lambda args (if (apply pred args) (return #t)))
(lambda args (cond ((apply pred args) => return)))
#f
array
arrays)
#f)))
(define (array-every pred array . arrays)
(assert (all-equal? (map array-dimension (cons array arrays))))
(assert (same-dimensions? (cons array arrays)))
(call-with-current-continuation
(lambda (return)
;; TODO: return last value
(apply array-for-each
(lambda args (if (not (apply pred args)) (return #f)))
array
arrays)
#t)))
(interval-fold
(let ((getters (map array-getter (cons array arrays))))
(lambda (acc . multi-index)
(or (apply pred (map (lambda (g) (apply g multi-index)) getters))
(return #f))))
#t
(array-domain array)))))
(define (array->list array)
(reverse (array-fold cons '() array)))
@ -811,31 +811,34 @@
res))
(define (array-assign! destination source)
(assert
(and (array? destination)
(mutable-array? destination)
(array? source)
(or (equal? (array-domain destination) (array-domain source))
(and (array-elements-in-order? destination)
(equal? (interval-volume (array-domain destination))
(interval-volume (array-domain source)))))))
(assert (and (mutable-array? destination) (array? source)))
(let ((getter (array-getter source))
(setter (array-setter destination)))
(if (equal? (array-domain destination) (array-domain source))
(cond
((equal? (array-domain destination) (array-domain source))
(interval-for-each
(case (array-dimension destination)
((1) (lambda (i) (setter (getter i) i)))
((2) (lambda (i j) (setter (getter i j) i j)))
((3) (lambda (i j k) (setter (getter i j k) i j k)))
(else
(lambda multi-index
(apply setter (apply getter multi-index) multi-index))))
(array-domain source)))
(else
(assert (and (array-elements-in-order? destination)
(equal? (interval-volume (array-domain destination))
(interval-volume (array-domain source)))))
(let* ((dst-domain (array-domain destination))
(rev-lowers (reverse (interval-lower-bounds->list dst-domain)))
(rev-uppers (reverse (interval-upper-bounds->list dst-domain)))
(dst-index (list-copy (interval-lower-bounds->list dst-domain)))
(rev-index (pair-fold cons '() dst-index)))
(interval-for-each
(lambda multi-index
(apply setter (apply getter multi-index) multi-index))
(array-domain source))
(let* ((dst-domain (array-domain destination))
(rev-lowers (reverse (interval-lower-bounds->list dst-domain)))
(rev-uppers (reverse (interval-upper-bounds->list dst-domain)))
(dst-index (list-copy (interval-lower-bounds->list dst-domain)))
(rev-index (pair-fold cons '() dst-index)))
(interval-for-each
(lambda multi-index
(apply setter (apply getter multi-index) dst-index)
(rev-index-next! rev-index rev-lowers rev-uppers))
(array-domain source))))
(apply setter (apply getter multi-index) dst-index)
(rev-index-next! rev-index rev-lowers rev-uppers))
(array-domain source)))))
destination))
(define (reshape-without-copy array new-domain)
@ -847,16 +850,15 @@
(apply orig-indexer
(invert-default-index domain
(apply tmp-indexer multi-index)))))
(new-coeffs
(indexer->coeffs new-indexer new-domain #t))
(flat-indexer
(coeffs->indexer new-coeffs new-domain))
(new-coeffs (indexer->coeffs new-indexer new-domain #t))
(flat-indexer (coeffs->indexer new-coeffs new-domain))
(new-indexer (coeffs->indexer new-coeffs new-domain))
(body (array-body array))
(storage (array-storage-class array))
(res
(%make-specialized new-domain storage body new-coeffs flat-indexer
(array-safe? array) (array-setter array))))
(array-safe? array) (array-setter array)
(array-adjacent? array))))
(let ((multi-index (interval-lower-bounds->list domain))
(orig-default-indexer (default-indexer domain)))
(let lp ((i 0)
@ -886,20 +888,11 @@
(cond
((reshape-without-copy array new-domain))
(copy-on-failure?
(let* ((res (make-specialized-array
new-domain
(array-storage-class array)
(array-safe? array)))
(setter (array-setter res))
(multi-index (interval-lower-bounds->list new-domain))
(rev-index (pair-fold cons '() multi-index))
(rev-lowers (reverse (interval-lower-bounds->list new-domain)))
(rev-uppers (reverse (interval-upper-bounds->list new-domain))))
(array-for-each
(lambda (x)
(apply setter x multi-index)
(rev-index-next! rev-index rev-lowers rev-uppers))
array)
(let ((res (make-specialized-array
new-domain
(array-storage-class array)
(array-safe? array))))
(array-assign! res array)
res))
(else
(error "can't reshape" array new-domain)))))

View file

@ -890,6 +890,8 @@ OTHER DEALINGS IN THE SOFTWARE.
(define (run-tests)
(random-source-pseudo-randomize! default-random-source 7 23)
(test-begin "srfi-179: nonempty intervals and generalized arrays")
(test-group "interval tests"