array simplification and performance tweaks

This commit is contained in:
Alex Shinn 2021-05-07 16:15:48 +09:00
parent 12ad1d37d8
commit ca47a41ccf
2 changed files with 99 additions and 104 deletions

View file

@ -1,7 +1,7 @@
;; Miscellaneous Functions ;; Miscellaneous Functions
(define (translation? x) (define (translation? x)
(and (vector? x) (vector-every exact-integer? x))) (and (vector? x) (not (vector-empty? x)) (vector-every exact-integer? x)))
(define (permutation? x) (define (permutation? x)
(and (translation? x) (and (translation? x)
@ -15,11 +15,11 @@
(u1vector-set! seen (vector-ref x i) 1) (u1vector-set! seen (vector-ref x i) 1)
(lp (+ i 1))))))))) (lp (+ i 1)))))))))
(define (all-equal? ls) (define (same-dimensions? ls)
(or (null? ls) (or (null? ls)
(null? (cdr ls)) (null? (cdr ls))
(and (equal? (car ls) (cadr ls)) (and (equal? (array-dimension (car ls)) (array-dimension (cadr ls)))
(all-equal? (cdr ls))))) (same-dimensions? (cdr ls)))))
;; Intervals ;; Intervals
@ -32,8 +32,6 @@
(define (%make-interval lo hi) (define (%make-interval lo hi)
(assert (and (translation? lo) (assert (and (translation? lo)
(translation? hi) (translation? hi)
(not (vector-empty? lo))
(not (vector-empty? hi))
(= (vector-length lo) (vector-length hi)) (= (vector-length lo) (vector-length hi))
(vector-every < lo hi))) (vector-every < lo hi)))
(%%make-interval lo hi)) (%%make-interval lo hi))
@ -92,20 +90,18 @@
(set-car! (car rev-index) (car rev-lowers)) (set-car! (car rev-index) (car rev-lowers))
(rev-index-next! (cdr rev-index) (cdr rev-lowers) (cdr rev-uppers))))) (rev-index-next! (cdr rev-index) (cdr rev-lowers) (cdr rev-uppers)))))
(define (interval-fold kons knil iv . o) (define (interval-fold kons knil iv)
(case (interval-dimension iv) (case (interval-dimension iv)
((1) ((1)
(let ((end (interval-upper-bound iv 0))) (let ((end (interval-upper-bound iv 0)))
(do ((i (if (pair? o) (car o) (interval-lower-bound iv 0)) (do ((i (interval-lower-bound iv 0) (+ i 1))
(+ i 1))
(acc knil (kons acc i))) (acc knil (kons acc i)))
((>= i end) acc)))) ((>= i end) acc))))
((2) ((2)
(let ((end0 (interval-upper-bound iv 0)) (let ((end0 (interval-upper-bound iv 0))
(start1 (if (pair? o) (cadr o) (interval-lower-bound iv 1))) (start1 (interval-lower-bound iv 1))
(end1 (interval-upper-bound iv 1))) (end1 (interval-upper-bound iv 1)))
(do ((i (if (pair? o) (car o) (interval-lower-bound iv 0)) (do ((i (interval-lower-bound iv 0) (+ i 1))
(+ i 1))
(acc knil (acc knil
(do ((j start1 (+ j 1)) (do ((j start1 (+ j 1))
(acc acc (kons acc i j))) (acc acc (kons acc i j)))
@ -114,8 +110,7 @@
(else (else
(let* ((rev-lowers (reverse (interval-lower-bounds->list iv))) (let* ((rev-lowers (reverse (interval-lower-bounds->list iv)))
(rev-uppers (reverse (interval-upper-bounds->list iv))) (rev-uppers (reverse (interval-upper-bounds->list iv)))
(multi-index (multi-index (interval-lower-bounds->list iv))
(list-copy (if (pair? o) o (interval-lower-bounds->list iv))))
(rev-index (pair-fold cons '() multi-index))) (rev-index (pair-fold cons '() multi-index)))
(let lp ((acc knil)) (let lp ((acc knil))
(let ((acc (apply kons acc multi-index))) (let ((acc (apply kons acc multi-index)))
@ -261,7 +256,7 @@
;; Arrays ;; Arrays
(define-record-type Array (define-record-type Array
(%%make-array domain getter setter storage body coeffs indexer safe?) (%%make-array domain getter setter storage body coeffs indexer safe? adjacent?)
array? array?
(domain array-domain) (domain array-domain)
(getter array-getter) (getter array-getter)
@ -270,18 +265,21 @@
(body array-body) (body array-body)
(coeffs array-coeffs) (coeffs array-coeffs)
(indexer array-indexer) (indexer array-indexer)
(safe? array-safe?)) (safe? array-safe?)
(adjacent? array-adjacent? array-adjacent?-set!))
(define (%make-array domain getter setter storage body coeffs indexer safe?) (define (%make-array domain getter setter storage body coeffs
indexer safe? adjacent?)
(assert (and (interval? domain) (assert (and (interval? domain)
(procedure? getter) (procedure? getter)
(or (not setter) (procedure? setter)) (or (not setter) (procedure? setter))
(or (not storage) (storage-class? storage)))) (or (not storage) (storage-class? storage))))
(%%make-array domain getter setter storage body coeffs indexer safe?)) (%%make-array
domain getter setter storage body coeffs indexer safe? adjacent?))
(define (make-array domain getter . o) (define (make-array domain getter . o)
(assert (and (interval? domain) (procedure? getter))) (assert (and (interval? domain) (procedure? getter)))
(%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f)) (%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f #f))
(define (array-dimension a) (define (array-dimension a)
(interval-dimension (array-domain a))) (interval-dimension (array-domain a)))
@ -438,7 +436,8 @@
;; Specialized arrays ;; Specialized arrays
(define (%make-specialized domain storage body coeffs indexer safe? mutable?) (define (%make-specialized domain storage body coeffs indexer
safe? mutable? adjacent?)
(%make-array (%make-array
domain domain
(specialized-getter body indexer (storage-class-getter storage)) (specialized-getter body indexer (storage-class-getter storage))
@ -448,7 +447,8 @@
body body
coeffs coeffs
indexer indexer
safe?)) safe?
adjacent?))
(define (make-specialized-array domain . o) (define (make-specialized-array domain . o)
(let* ((storage (if (pair? o) (car o) generic-storage-class)) (let* ((storage (if (pair? o) (car o) generic-storage-class))
@ -461,14 +461,12 @@
(coeffs (default-coeffs domain)) (coeffs (default-coeffs domain))
(indexer (coeffs->indexer coeffs domain))) (indexer (coeffs->indexer coeffs domain)))
(assert (boolean? safe?)) (assert (boolean? safe?))
(%make-specialized domain storage body coeffs indexer safe? #t))) (%make-specialized domain storage body coeffs indexer safe? #t #t)))
(define (specialized-array? x) (define (specialized-array? x)
(and (array? x) (array-storage-class x) #t)) (and (array? x) (array-storage-class x) #t))
(define (array-elements-in-order? array) (define (compute-array-elements-in-order? array)
(assert (specialized-array? array))
;; TODO: speed this up and/or cache it
(let ((indexer (array-indexer array))) (let ((indexer (array-indexer array)))
(call-with-current-continuation (call-with-current-continuation
(lambda (return) (lambda (return)
@ -482,6 +480,14 @@
(array-domain array)) (array-domain array))
#t)))) #t))))
(define (array-elements-in-order? array)
(assert (specialized-array? array))
(let ((res (array-adjacent? array)))
(when (eq? res 'unknown)
(set! res (compute-array-elements-in-order? array))
(array-adjacent?-set! array res))
res))
(define (specialized-array-share array new-domain project) (define (specialized-array-share array new-domain project)
(assert (and (specialized-array? array) (interval? new-domain))) (assert (and (specialized-array? array) (interval? new-domain)))
(let* ((body (array-body array)) (let* ((body (array-body array))
@ -496,7 +502,7 @@
(coeffs->indexer coeffs new-domain)) (coeffs->indexer coeffs new-domain))
(storage (array-storage-class array))) (storage (array-storage-class array)))
(%make-specialized new-domain storage body coeffs indexer (%make-specialized new-domain storage body coeffs indexer
(array-safe? array) (array-setter array)))) (array-safe? array) (array-setter array) 'unknown)))
;; Array transformations ;; Array transformations
@ -521,7 +527,7 @@
(setter (specialized-setter body indexer (setter (specialized-setter body indexer
(storage-class-setter storage))) (storage-class-setter storage)))
(res (%make-specialized new-domain storage body coeffs indexer (res (%make-specialized new-domain storage body coeffs indexer
safe? #t))) safe? #t #t)))
(array-assign! res array) (array-assign! res array)
(unless mutable? (unless mutable?
(%array-setter-set! res #f)) (%array-setter-set! res #f))
@ -561,9 +567,7 @@
(interval-subset? new-domain (array-domain array)))) (interval-subset? new-domain (array-domain array))))
(if (specialized-array? array) (if (specialized-array? array)
(specialized-array-share array new-domain values) (specialized-array-share array new-domain values)
(make-array new-domain (make-array new-domain (array-getter array) (array-setter array))))
(array-getter array)
(array-setter array))))
(define (array-tile array sizes) (define (array-tile array sizes)
(assert (and (array? array) (assert (and (array? array)
@ -595,8 +599,8 @@
(lambda (i lo hi s) (lambda (i lo hi s)
(min hi (+ lo (* (+ i 1) s)))) (min hi (+ lo (* (+ i 1) s))))
multi-index multi-index
(interval-lower-bound (array-domain array)) (interval-lb (array-domain array))
(interval-upper-bound (array-domain array)) (interval-ub (array-domain array))
sizes))))))) sizes)))))))
(define (array-translate array translation) (define (array-translate array translation)
@ -626,10 +630,8 @@
(define (inverse-permutation permutation) (define (inverse-permutation permutation)
(list->vector (list->vector
(map (map car
car (list-sort (lambda (a b) (< (cdr a) (cdr b)))
(list-sort
(lambda (a b) (< (cdr a) (cdr b)))
(map cons (map cons
(iota (vector-length permutation)) (iota (vector-length permutation))
(vector->list permutation)))))) (vector->list permutation))))))
@ -729,7 +731,7 @@
(make-array (array-domain array) (make-array (array-domain array)
(let* ((ls (cons array arrays)) (let* ((ls (cons array arrays))
(getters (map array-getter ls))) (getters (map array-getter ls)))
(assert (all-equal? (map array-dimension ls))) (assert (same-dimensions? ls))
(lambda multi-index (lambda multi-index
(apply f (map (lambda (g) (apply g multi-index)) getters)))))) (apply f (map (lambda (g) (apply g multi-index)) getters))))))
@ -737,7 +739,7 @@
(interval-for-each (interval-for-each
(let* ((ls (cons array arrays)) (let* ((ls (cons array arrays))
(getters (map array-getter ls))) (getters (map array-getter ls)))
(assert (all-equal? (map array-dimension ls))) (assert (same-dimensions? ls))
(lambda multi-index (lambda multi-index
(apply f (map (lambda (g) (apply g multi-index)) getters)))) (apply f (map (lambda (g) (apply g multi-index)) getters))))
(array-domain array))) (array-domain array)))
@ -752,41 +754,39 @@
(fold-right kons knil (array->list array))) (fold-right kons knil (array->list array)))
(define (array-reduce op array) (define (array-reduce op array)
;; (let* ((domain (array-domain array)) (let* ((domain (array-domain array))
;; (init-index (interval-lower-bounds->list domain)) (init-index (interval-lower-bounds->list domain))
;; (knil (apply array-ref array init-index))) (knil (list 'first-element)))
;; (if (rev-index-next! (pair-fold cons '() init-index) (interval-fold
;; (reverse (interval-lower-bounds->list domain)) (lambda (acc . multi-index)
;; (reverse (interval-upper-bounds->list domain))) (if (eq? acc knil)
;; (apply interval-fold (apply array-ref array multi-index)
;; (lambda (acc . multi-index) (op acc (apply array-ref array multi-index))))
;; (op acc (apply array-ref array multi-index))) knil
;; knil domain)))
;; domain
;; init-index)
;; knil))
(reduce (lambda (elt acc) (op acc elt)) 'never-used (array->list array)))
(define (array-any pred array . arrays) (define (array-any pred array . arrays)
(assert (all-equal? (map array-dimension (cons array arrays)))) (assert (same-dimensions? (cons array arrays)))
(call-with-current-continuation (call-with-current-continuation
(lambda (return) (lambda (return)
(apply array-for-each (apply array-for-each
(lambda args (if (apply pred args) (return #t))) (lambda args (cond ((apply pred args) => return)))
#f
array array
arrays) arrays)
#f))) #f)))
(define (array-every pred array . arrays) (define (array-every pred array . arrays)
(assert (all-equal? (map array-dimension (cons array arrays)))) (assert (same-dimensions? (cons array arrays)))
(call-with-current-continuation (call-with-current-continuation
(lambda (return) (lambda (return)
;; TODO: return last value (interval-fold
(apply array-for-each (let ((getters (map array-getter (cons array arrays))))
(lambda args (if (not (apply pred args)) (return #f))) (lambda (acc . multi-index)
array (or (apply pred (map (lambda (g) (apply g multi-index)) getters))
arrays) (return #f))))
#t))) #t
(array-domain array)))))
(define (array->list array) (define (array->list array)
(reverse (array-fold cons '() array))) (reverse (array-fold cons '() array)))
@ -811,21 +811,24 @@
res)) res))
(define (array-assign! destination source) (define (array-assign! destination source)
(assert (assert (and (mutable-array? destination) (array? source)))
(and (array? destination)
(mutable-array? destination)
(array? source)
(or (equal? (array-domain destination) (array-domain source))
(and (array-elements-in-order? destination)
(equal? (interval-volume (array-domain destination))
(interval-volume (array-domain source)))))))
(let ((getter (array-getter source)) (let ((getter (array-getter source))
(setter (array-setter destination))) (setter (array-setter destination)))
(if (equal? (array-domain destination) (array-domain source)) (cond
((equal? (array-domain destination) (array-domain source))
(interval-for-each (interval-for-each
(case (array-dimension destination)
((1) (lambda (i) (setter (getter i) i)))
((2) (lambda (i j) (setter (getter i j) i j)))
((3) (lambda (i j k) (setter (getter i j k) i j k)))
(else
(lambda multi-index (lambda multi-index
(apply setter (apply getter multi-index) multi-index)) (apply setter (apply getter multi-index) multi-index))))
(array-domain source)) (array-domain source)))
(else
(assert (and (array-elements-in-order? destination)
(equal? (interval-volume (array-domain destination))
(interval-volume (array-domain source)))))
(let* ((dst-domain (array-domain destination)) (let* ((dst-domain (array-domain destination))
(rev-lowers (reverse (interval-lower-bounds->list dst-domain))) (rev-lowers (reverse (interval-lower-bounds->list dst-domain)))
(rev-uppers (reverse (interval-upper-bounds->list dst-domain))) (rev-uppers (reverse (interval-upper-bounds->list dst-domain)))
@ -835,7 +838,7 @@
(lambda multi-index (lambda multi-index
(apply setter (apply getter multi-index) dst-index) (apply setter (apply getter multi-index) dst-index)
(rev-index-next! rev-index rev-lowers rev-uppers)) (rev-index-next! rev-index rev-lowers rev-uppers))
(array-domain source)))) (array-domain source)))))
destination)) destination))
(define (reshape-without-copy array new-domain) (define (reshape-without-copy array new-domain)
@ -847,16 +850,15 @@
(apply orig-indexer (apply orig-indexer
(invert-default-index domain (invert-default-index domain
(apply tmp-indexer multi-index))))) (apply tmp-indexer multi-index)))))
(new-coeffs (new-coeffs (indexer->coeffs new-indexer new-domain #t))
(indexer->coeffs new-indexer new-domain #t)) (flat-indexer (coeffs->indexer new-coeffs new-domain))
(flat-indexer
(coeffs->indexer new-coeffs new-domain))
(new-indexer (coeffs->indexer new-coeffs new-domain)) (new-indexer (coeffs->indexer new-coeffs new-domain))
(body (array-body array)) (body (array-body array))
(storage (array-storage-class array)) (storage (array-storage-class array))
(res (res
(%make-specialized new-domain storage body new-coeffs flat-indexer (%make-specialized new-domain storage body new-coeffs flat-indexer
(array-safe? array) (array-setter array)))) (array-safe? array) (array-setter array)
(array-adjacent? array))))
(let ((multi-index (interval-lower-bounds->list domain)) (let ((multi-index (interval-lower-bounds->list domain))
(orig-default-indexer (default-indexer domain))) (orig-default-indexer (default-indexer domain)))
(let lp ((i 0) (let lp ((i 0)
@ -886,20 +888,11 @@
(cond (cond
((reshape-without-copy array new-domain)) ((reshape-without-copy array new-domain))
(copy-on-failure? (copy-on-failure?
(let* ((res (make-specialized-array (let ((res (make-specialized-array
new-domain new-domain
(array-storage-class array) (array-storage-class array)
(array-safe? array))) (array-safe? array))))
(setter (array-setter res)) (array-assign! res array)
(multi-index (interval-lower-bounds->list new-domain))
(rev-index (pair-fold cons '() multi-index))
(rev-lowers (reverse (interval-lower-bounds->list new-domain)))
(rev-uppers (reverse (interval-upper-bounds->list new-domain))))
(array-for-each
(lambda (x)
(apply setter x multi-index)
(rev-index-next! rev-index rev-lowers rev-uppers))
array)
res)) res))
(else (else
(error "can't reshape" array new-domain))))) (error "can't reshape" array new-domain)))))

View file

@ -890,6 +890,8 @@ OTHER DEALINGS IN THE SOFTWARE.
(define (run-tests) (define (run-tests)
(random-source-pseudo-randomize! default-random-source 7 23)
(test-begin "srfi-179: nonempty intervals and generalized arrays") (test-begin "srfi-179: nonempty intervals and generalized arrays")
(test-group "interval tests" (test-group "interval tests"