diff --git a/lib/srfi/179.scm b/lib/srfi/179.scm index 65382c31..0b60192d 100644 --- a/lib/srfi/179.scm +++ b/lib/srfi/179.scm @@ -1,7 +1,7 @@ ;; Miscellaneous Functions (define (translation? x) - (and (vector? x) (vector-every exact-integer? x))) + (and (vector? x) (not (vector-empty? x)) (vector-every exact-integer? x))) (define (permutation? x) (and (translation? x) @@ -15,11 +15,11 @@ (u1vector-set! seen (vector-ref x i) 1) (lp (+ i 1))))))))) -(define (all-equal? ls) +(define (same-dimensions? ls) (or (null? ls) (null? (cdr ls)) - (and (equal? (car ls) (cadr ls)) - (all-equal? (cdr ls))))) + (and (equal? (array-dimension (car ls)) (array-dimension (cadr ls))) + (same-dimensions? (cdr ls))))) ;; Intervals @@ -32,8 +32,6 @@ (define (%make-interval lo hi) (assert (and (translation? lo) (translation? hi) - (not (vector-empty? lo)) - (not (vector-empty? hi)) (= (vector-length lo) (vector-length hi)) (vector-every < lo hi))) (%%make-interval lo hi)) @@ -92,20 +90,18 @@ (set-car! (car rev-index) (car rev-lowers)) (rev-index-next! (cdr rev-index) (cdr rev-lowers) (cdr rev-uppers))))) -(define (interval-fold kons knil iv . o) +(define (interval-fold kons knil iv) (case (interval-dimension iv) ((1) (let ((end (interval-upper-bound iv 0))) - (do ((i (if (pair? o) (car o) (interval-lower-bound iv 0)) - (+ i 1)) + (do ((i (interval-lower-bound iv 0) (+ i 1)) (acc knil (kons acc i))) ((>= i end) acc)))) ((2) (let ((end0 (interval-upper-bound iv 0)) - (start1 (if (pair? o) (cadr o) (interval-lower-bound iv 1))) + (start1 (interval-lower-bound iv 1)) (end1 (interval-upper-bound iv 1))) - (do ((i (if (pair? o) (car o) (interval-lower-bound iv 0)) - (+ i 1)) + (do ((i (interval-lower-bound iv 0) (+ i 1)) (acc knil (do ((j start1 (+ j 1)) (acc acc (kons acc i j))) @@ -114,8 +110,7 @@ (else (let* ((rev-lowers (reverse (interval-lower-bounds->list iv))) (rev-uppers (reverse (interval-upper-bounds->list iv))) - (multi-index - (list-copy (if (pair? o) o (interval-lower-bounds->list iv)))) + (multi-index (interval-lower-bounds->list iv)) (rev-index (pair-fold cons '() multi-index))) (let lp ((acc knil)) (let ((acc (apply kons acc multi-index))) @@ -261,7 +256,7 @@ ;; Arrays (define-record-type Array - (%%make-array domain getter setter storage body coeffs indexer safe?) + (%%make-array domain getter setter storage body coeffs indexer safe? adjacent?) array? (domain array-domain) (getter array-getter) @@ -270,18 +265,21 @@ (body array-body) (coeffs array-coeffs) (indexer array-indexer) - (safe? array-safe?)) + (safe? array-safe?) + (adjacent? array-adjacent? array-adjacent?-set!)) -(define (%make-array domain getter setter storage body coeffs indexer safe?) +(define (%make-array domain getter setter storage body coeffs + indexer safe? adjacent?) (assert (and (interval? domain) (procedure? getter) (or (not setter) (procedure? setter)) (or (not storage) (storage-class? storage)))) - (%%make-array domain getter setter storage body coeffs indexer safe?)) + (%%make-array + domain getter setter storage body coeffs indexer safe? adjacent?)) (define (make-array domain getter . o) (assert (and (interval? domain) (procedure? getter))) - (%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f)) + (%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f #f)) (define (array-dimension a) (interval-dimension (array-domain a))) @@ -438,7 +436,8 @@ ;; Specialized arrays -(define (%make-specialized domain storage body coeffs indexer safe? mutable?) +(define (%make-specialized domain storage body coeffs indexer + safe? mutable? adjacent?) (%make-array domain (specialized-getter body indexer (storage-class-getter storage)) @@ -448,7 +447,8 @@ body coeffs indexer - safe?)) + safe? + adjacent?)) (define (make-specialized-array domain . o) (let* ((storage (if (pair? o) (car o) generic-storage-class)) @@ -461,14 +461,12 @@ (coeffs (default-coeffs domain)) (indexer (coeffs->indexer coeffs domain))) (assert (boolean? safe?)) - (%make-specialized domain storage body coeffs indexer safe? #t))) + (%make-specialized domain storage body coeffs indexer safe? #t #t))) (define (specialized-array? x) (and (array? x) (array-storage-class x) #t)) -(define (array-elements-in-order? array) - (assert (specialized-array? array)) - ;; TODO: speed this up and/or cache it +(define (compute-array-elements-in-order? array) (let ((indexer (array-indexer array))) (call-with-current-continuation (lambda (return) @@ -482,6 +480,14 @@ (array-domain array)) #t)))) +(define (array-elements-in-order? array) + (assert (specialized-array? array)) + (let ((res (array-adjacent? array))) + (when (eq? res 'unknown) + (set! res (compute-array-elements-in-order? array)) + (array-adjacent?-set! array res)) + res)) + (define (specialized-array-share array new-domain project) (assert (and (specialized-array? array) (interval? new-domain))) (let* ((body (array-body array)) @@ -496,7 +502,7 @@ (coeffs->indexer coeffs new-domain)) (storage (array-storage-class array))) (%make-specialized new-domain storage body coeffs indexer - (array-safe? array) (array-setter array)))) + (array-safe? array) (array-setter array) 'unknown))) ;; Array transformations @@ -521,7 +527,7 @@ (setter (specialized-setter body indexer (storage-class-setter storage))) (res (%make-specialized new-domain storage body coeffs indexer - safe? #t))) + safe? #t #t))) (array-assign! res array) (unless mutable? (%array-setter-set! res #f)) @@ -561,9 +567,7 @@ (interval-subset? new-domain (array-domain array)))) (if (specialized-array? array) (specialized-array-share array new-domain values) - (make-array new-domain - (array-getter array) - (array-setter array)))) + (make-array new-domain (array-getter array) (array-setter array)))) (define (array-tile array sizes) (assert (and (array? array) @@ -595,8 +599,8 @@ (lambda (i lo hi s) (min hi (+ lo (* (+ i 1) s)))) multi-index - (interval-lower-bound (array-domain array)) - (interval-upper-bound (array-domain array)) + (interval-lb (array-domain array)) + (interval-ub (array-domain array)) sizes))))))) (define (array-translate array translation) @@ -626,13 +630,11 @@ (define (inverse-permutation permutation) (list->vector - (map - car - (list-sort - (lambda (a b) (< (cdr a) (cdr b))) - (map cons - (iota (vector-length permutation)) - (vector->list permutation)))))) + (map car + (list-sort (lambda (a b) (< (cdr a) (cdr b))) + (map cons + (iota (vector-length permutation)) + (vector->list permutation)))))) (define (array-permute array permutation) (assert (permutation? permutation)) @@ -729,7 +731,7 @@ (make-array (array-domain array) (let* ((ls (cons array arrays)) (getters (map array-getter ls))) - (assert (all-equal? (map array-dimension ls))) + (assert (same-dimensions? ls)) (lambda multi-index (apply f (map (lambda (g) (apply g multi-index)) getters)))))) @@ -737,7 +739,7 @@ (interval-for-each (let* ((ls (cons array arrays)) (getters (map array-getter ls))) - (assert (all-equal? (map array-dimension ls))) + (assert (same-dimensions? ls)) (lambda multi-index (apply f (map (lambda (g) (apply g multi-index)) getters)))) (array-domain array))) @@ -752,41 +754,39 @@ (fold-right kons knil (array->list array))) (define (array-reduce op array) - ;; (let* ((domain (array-domain array)) - ;; (init-index (interval-lower-bounds->list domain)) - ;; (knil (apply array-ref array init-index))) - ;; (if (rev-index-next! (pair-fold cons '() init-index) - ;; (reverse (interval-lower-bounds->list domain)) - ;; (reverse (interval-upper-bounds->list domain))) - ;; (apply interval-fold - ;; (lambda (acc . multi-index) - ;; (op acc (apply array-ref array multi-index))) - ;; knil - ;; domain - ;; init-index) - ;; knil)) - (reduce (lambda (elt acc) (op acc elt)) 'never-used (array->list array))) + (let* ((domain (array-domain array)) + (init-index (interval-lower-bounds->list domain)) + (knil (list 'first-element))) + (interval-fold + (lambda (acc . multi-index) + (if (eq? acc knil) + (apply array-ref array multi-index) + (op acc (apply array-ref array multi-index)))) + knil + domain))) (define (array-any pred array . arrays) - (assert (all-equal? (map array-dimension (cons array arrays)))) + (assert (same-dimensions? (cons array arrays))) (call-with-current-continuation (lambda (return) (apply array-for-each - (lambda args (if (apply pred args) (return #t))) + (lambda args (cond ((apply pred args) => return))) + #f array arrays) #f))) (define (array-every pred array . arrays) - (assert (all-equal? (map array-dimension (cons array arrays)))) + (assert (same-dimensions? (cons array arrays))) (call-with-current-continuation (lambda (return) - ;; TODO: return last value - (apply array-for-each - (lambda args (if (not (apply pred args)) (return #f))) - array - arrays) - #t))) + (interval-fold + (let ((getters (map array-getter (cons array arrays)))) + (lambda (acc . multi-index) + (or (apply pred (map (lambda (g) (apply g multi-index)) getters)) + (return #f)))) + #t + (array-domain array))))) (define (array->list array) (reverse (array-fold cons '() array))) @@ -811,31 +811,34 @@ res)) (define (array-assign! destination source) - (assert - (and (array? destination) - (mutable-array? destination) - (array? source) - (or (equal? (array-domain destination) (array-domain source)) - (and (array-elements-in-order? destination) - (equal? (interval-volume (array-domain destination)) - (interval-volume (array-domain source))))))) + (assert (and (mutable-array? destination) (array? source))) (let ((getter (array-getter source)) (setter (array-setter destination))) - (if (equal? (array-domain destination) (array-domain source)) + (cond + ((equal? (array-domain destination) (array-domain source)) + (interval-for-each + (case (array-dimension destination) + ((1) (lambda (i) (setter (getter i) i))) + ((2) (lambda (i j) (setter (getter i j) i j))) + ((3) (lambda (i j k) (setter (getter i j k) i j k))) + (else + (lambda multi-index + (apply setter (apply getter multi-index) multi-index)))) + (array-domain source))) + (else + (assert (and (array-elements-in-order? destination) + (equal? (interval-volume (array-domain destination)) + (interval-volume (array-domain source))))) + (let* ((dst-domain (array-domain destination)) + (rev-lowers (reverse (interval-lower-bounds->list dst-domain))) + (rev-uppers (reverse (interval-upper-bounds->list dst-domain))) + (dst-index (list-copy (interval-lower-bounds->list dst-domain))) + (rev-index (pair-fold cons '() dst-index))) (interval-for-each (lambda multi-index - (apply setter (apply getter multi-index) multi-index)) - (array-domain source)) - (let* ((dst-domain (array-domain destination)) - (rev-lowers (reverse (interval-lower-bounds->list dst-domain))) - (rev-uppers (reverse (interval-upper-bounds->list dst-domain))) - (dst-index (list-copy (interval-lower-bounds->list dst-domain))) - (rev-index (pair-fold cons '() dst-index))) - (interval-for-each - (lambda multi-index - (apply setter (apply getter multi-index) dst-index) - (rev-index-next! rev-index rev-lowers rev-uppers)) - (array-domain source)))) + (apply setter (apply getter multi-index) dst-index) + (rev-index-next! rev-index rev-lowers rev-uppers)) + (array-domain source))))) destination)) (define (reshape-without-copy array new-domain) @@ -847,16 +850,15 @@ (apply orig-indexer (invert-default-index domain (apply tmp-indexer multi-index))))) - (new-coeffs - (indexer->coeffs new-indexer new-domain #t)) - (flat-indexer - (coeffs->indexer new-coeffs new-domain)) + (new-coeffs (indexer->coeffs new-indexer new-domain #t)) + (flat-indexer (coeffs->indexer new-coeffs new-domain)) (new-indexer (coeffs->indexer new-coeffs new-domain)) (body (array-body array)) (storage (array-storage-class array)) (res (%make-specialized new-domain storage body new-coeffs flat-indexer - (array-safe? array) (array-setter array)))) + (array-safe? array) (array-setter array) + (array-adjacent? array)))) (let ((multi-index (interval-lower-bounds->list domain)) (orig-default-indexer (default-indexer domain))) (let lp ((i 0) @@ -886,20 +888,11 @@ (cond ((reshape-without-copy array new-domain)) (copy-on-failure? - (let* ((res (make-specialized-array - new-domain - (array-storage-class array) - (array-safe? array))) - (setter (array-setter res)) - (multi-index (interval-lower-bounds->list new-domain)) - (rev-index (pair-fold cons '() multi-index)) - (rev-lowers (reverse (interval-lower-bounds->list new-domain))) - (rev-uppers (reverse (interval-upper-bounds->list new-domain)))) - (array-for-each - (lambda (x) - (apply setter x multi-index) - (rev-index-next! rev-index rev-lowers rev-uppers)) - array) + (let ((res (make-specialized-array + new-domain + (array-storage-class array) + (array-safe? array)))) + (array-assign! res array) res)) (else (error "can't reshape" array new-domain))))) diff --git a/lib/srfi/179/test.sld b/lib/srfi/179/test.sld index deb4e039..e32073e1 100644 --- a/lib/srfi/179/test.sld +++ b/lib/srfi/179/test.sld @@ -890,6 +890,8 @@ OTHER DEALINGS IN THE SOFTWARE. (define (run-tests) + (random-source-pseudo-randomize! default-random-source 7 23) + (test-begin "srfi-179: nonempty intervals and generalized arrays") (test-group "interval tests"