Fix array-stack, interval folds and list*->array.

Issue #900.
This commit is contained in:
Alex Shinn 2023-03-19 23:56:01 +09:00
parent e6d7e4fffb
commit 3b8f07b12e
5 changed files with 132 additions and 63 deletions

View file

@ -17,6 +17,7 @@
interval-upper-bounds->list interval-lower-bounds->vector
interval-upper-bounds->vector interval= interval-volume
interval-subset? interval-contains-multi-index? interval-projections
interval-fold-left interval-fold-right
interval-for-each interval-dilate interval-intersect
interval-translate interval-permute
interval-scale interval-cartesian-product
@ -49,5 +50,6 @@
array->list* list*->array array->vector* vector*->array
array-assign! array-ref array-set! array-decurry
specialized-array-reshape
array-copy! array-stack! array-decurry! array-append! array-block!
)
(include "231/transforms.scm"))

View file

@ -133,12 +133,12 @@
(vector-ref ivc 3))
(values ivc (vector-ref ivc 0)))))
(define (interval-fold kons knil iv)
(define (interval-fold-left f kons knil iv)
(case (interval-dimension iv)
((1)
(let ((end (interval-upper-bound iv 0)))
(do ((i (interval-lower-bound iv 0) (+ i 1))
(acc knil (kons acc i)))
(acc knil (kons acc (f i))))
((>= i end) acc))))
((2)
(let ((end0 (interval-upper-bound iv 0))
@ -147,17 +147,28 @@
(do ((i (interval-lower-bound iv 0) (+ i 1))
(acc knil
(do ((j start1 (+ j 1))
(acc acc (kons acc i j)))
(acc acc (kons acc (f i j))))
((>= j end1) acc))))
((>= i end0) acc))))
(else
(let ((ivc (interval-cursor iv)))
(let lp ((acc knil))
(let ((acc (apply kons acc (interval-cursor-get ivc))))
(let ((acc (kons acc (apply f (interval-cursor-get ivc)))))
(if (interval-cursor-next! ivc)
(lp acc)
acc)))))))
(define (interval-fold kons knil iv)
(interval-fold-left list (lambda (acc idx) (apply kons acc idx)) knil iv))
(define (interval-fold-right f kons knil iv)
(let ((ivc (interval-cursor iv)))
(let lp ()
(let ((item (apply f (interval-cursor-get ivc))))
(if (interval-cursor-next! ivc)
(kons item (lp))
(kons item knil))))))
(define (interval-for-each f iv)
(interval-fold (lambda (acc . multi-index) (apply f multi-index)) #f iv)
(if #f #f))

View file

@ -16,6 +16,7 @@
interval-upper-bounds->list interval-lower-bounds->vector
interval-upper-bounds->vector interval= interval-volume
interval-subset? interval-contains-multi-index? interval-projections
interval-fold-left interval-fold-right
interval-for-each interval-dilate interval-intersect
interval-translate interval-permute
interval-scale interval-cartesian-product

View file

@ -1,35 +1,33 @@
#|
Adapted from original SRFI reference test suite:
;; Adapted from original SRFI reference test suite:
SRFI 179: Nonempty Intervals and Generalized Arrays (Updated)
;; SRFI 179: Nonempty Intervals and Generalized Arrays (Updated)
Copyright 2016, 2018, 2020 Bradley J Lucier.
All Rights Reserved.
;; Copyright 2016, 2018, 2020 Bradley J Lucier.
;; All Rights Reserved.
Permission is hereby granted, free of charge,
to any person obtaining a copy of this software
and associated documentation files (the "Software"),
to deal in the Software without restriction,
including without limitation the rights to use, copy,
modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so,
subject to the following conditions:
;; Permission is hereby granted, free of charge,
;; to any person obtaining a copy of this software
;; and associated documentation files (the "Software"),
;; to deal in the Software without restriction,
;; including without limitation the rights to use, copy,
;; modify, merge, publish, distribute, sublicense,
;; and/or sell copies of the Software, and to permit
;; persons to whom the Software is furnished to do so,
;; subject to the following conditions:
The above copyright notice and this permission notice
(including the next paragraph) shall be included in
all copies or substantial portions of the Software.
;; The above copyright notice and this permission notice
;; (including the next paragraph) shall be included in
;; all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO
EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN
AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.
|#
;; THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
;; ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT
;; LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
;; FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO
;; EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
;; FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN
;; AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
;; OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
;; OTHER DEALINGS IN THE SOFTWARE.
;;; A test program for SRFI 179:
;;; Nonempty Intervals and Generalized Arrays (Updated)
@ -3052,6 +3050,40 @@ OTHER DEALINGS IN THE SOFTWARE.
(make-array (make-interval '#(2 3)) list)))
)
(test-group "stack/block"
(let* ((a
(make-array (make-interval '#(4 10)) list))
(a-column
(array-getter ;; the getter of ...
(array-curry ;; a 1-D array of the columns of A
(array-permute a '#(1 0))
1))))
(test '(((0 1) (0 2) (0 5) (0 8))
((1 1) (1 2) (1 5) (1 8))
((2 1) (2 2) (2 5) (2 8))
((3 1) (3 2) (3 5) (3 8)))
(array->list*
(array-stack ;; stack into a new 2-D array ...
1 ;; along axis 1 (i.e., columns) ...
(map a-column '(1 2 5 8)))) ;; the columns of A you want
))
'(test '((0 1 4 6 7 8)
(2 3 5 9 10 11)
(12 13 14 15 16 17))
(array->list*
(array-block (list*->array
2
(list (list (list*->array 2 '((0 1)
(2 3)))
(list*->array 2 '((4)
(5)))
(list*->array 2 '((6 7 8)
(9 10 11))))
(list (list*->array 2 '((12 13)))
(list*->array 2 '((14)))
(list*->array 2 '((15 16 17)))))))))
)
(test-group "assign/product"
(do ((d 1 (fx+ d 1)))
((= d 6))

View file

@ -102,6 +102,8 @@
(%array-setter-set! res #f))
res))))
(define array-copy! array-copy)
(define (array-curry array inner-dimension)
(call-with-values
(lambda () (interval-projections (array-domain array) inner-dimension))
@ -492,14 +494,15 @@
(append-map flatten ls)
ls))
(define (list*->array nested-ls . o)
(let lp ((ls nested-ls) (lens '()))
(define (list*->array dimension nested-ls . o)
(let lp ((ls nested-ls) (lens '()) (d dimension))
(cond
((pair? ls) (lp (car ls) (cons (length ls) lens)))
((positive? d)
(lp (car ls) (cons (length ls) lens) (- d 1)))
(else
(apply list->array
(flatten nested-ls)
(make-interval (list->vector (reverse lens)))
(flatten nested-ls)
o)))))
(define (array->list* a)
@ -543,14 +546,15 @@
(append-map flatten-vec vec)
(vector->list vec)))
(define (vector*->array nested-vec . o)
(let lp ((vec nested-vec) (lens '()))
(define (vector*->array dimension nested-vec . o)
(let lp ((vec nested-vec) (lens '()) (d dimension))
(cond
((vector? vec) (lp (vector-ref vec 0) (cons (vector-length vec) lens)))
((positive? d)
(lp (vector-ref vec 0) (cons (vector-length vec) lens) (- d 1)))
(else
(apply list->array
(flatten-vec nested-vec)
(make-interval (list->vector (reverse lens)))
(flatten-vec nested-vec)
o)))))
(define (dimensions-compatible? a-domain b-domain axis)
@ -609,31 +613,46 @@
(array-assign! view b)
(lp (cdr arrays) b-offset2)))))))))
(define (array-stack axis a . o)
(define array-append! array-append)
(define (array-stack axis arrays . o)
(assert (and (exact-integer? axis)
(array? a)
(< -1 axis (array-dimension a))
(every array? o)
(every (lambda (b) (interval= (array-domain a) (array-domain b))) o)))
(let* ((a-lbs (interval-lower-bounds->list (array-domain a)))
(a-ubs (interval-upper-bounds->list (array-domain a)))
(domain
(make-interval
`#(,@(take a-lbs axis) 0 ,@(drop a-lbs axis))
`#(,@(take a-ubs axis) ,(+ 1 (length o)) ,@(drop a-ubs axis))))
(res (make-specialized-array domain
(or (array-storage-class a)
generic-storage-class)))
(perm `#(,axis ,@(delete axis (iota (+ 1 (array-dimension a))))))
(permed (if (zero? axis) res (array-permute res perm)))
(curried (array-curry permed 1))
(get-view (array-getter curried)))
(let lp ((ls (cons a o)) (i 0))
(cond
((null? ls) res)
(else
(array-assign! (get-view i) (car ls))
(lp (cdr ls) (+ i 1)))))))
(pair? arrays)
(every array? arrays)
(<= 0 axis (array-dimension (car arrays)))))
(let ((a (car arrays))
(storage (if (pair? o) (car o) generic-storage-class))
(mutable? (if (and (pair? o) (pair? (cdr o)))
(cadr o)
(specialized-array-default-mutable?)))
(safe? (if (and (pair? o) (pair? (cdr o)) (pair? (cddr o)))
(car (cddr o))
(specialized-array-default-safe?))))
(assert (every (lambda (b)
(interval= (array-domain a)
(array-domain b)))
(cdr arrays)))
(let* ((a-lbs (interval-lower-bounds->list (array-domain a)))
(a-ubs (interval-upper-bounds->list (array-domain a)))
(domain
(make-interval
`#(,@(take a-lbs axis) 0 ,@(drop a-lbs axis))
`#(,@(take a-ubs axis) ,(length arrays) ,@(drop a-ubs axis))))
(res (make-specialized-array domain
(or (array-storage-class a)
generic-storage-class)))
(perm `#(,axis ,@(delete axis (iota (+ 1 (array-dimension a))))))
(permed (if (zero? axis) res (array-permute res perm)))
(curried (array-curry permed 1))
(get-view (array-getter curried)))
(let lp ((ls arrays) (i 0))
(cond
((null? ls) res)
(else
(array-assign! (get-view i) (car ls))
(lp (cdr ls) (+ i 1))))))))
(define array-stack! array-stack)
(define (array-block a . o)
(let ((storage (if (pair? o) (car o) generic-storage-class))
@ -656,6 +675,8 @@
(error "TODO: array-block copy data unimplemented")
res))))
(define array-block! array-block)
(define (array-decurry a . o)
(let* ((storage (if (pair? o) (car o) generic-storage-class))
(mutable? (if (and (pair? o) (pair? (cdr o)))
@ -675,3 +696,5 @@
;; curried view from a to the res.
(array-for-each array-assign! curried-res a)
res))
(define array-decurry! array-decurry)