Update to n-ary array-fold-* with correct arg order.

Closes #973.
This commit is contained in:
Alex Shinn 2024-05-28 22:35:31 +09:00
parent 96f17293f2
commit 36d7881763
3 changed files with 83 additions and 51 deletions

View file

@ -45,8 +45,8 @@
array-safe? array-packed? specialized-array-share array-safe? array-packed? specialized-array-share
array-copy array-curry array-extract array-tile array-translate array-copy array-curry array-extract array-tile array-translate
array-permute array-reverse array-sample array-permute array-reverse array-sample
array-outer-product array-map array-for-each array-foldl array-outer-product array-map array-for-each array-fold-left
array-foldr array-reduce array-any array-every array-fold-right array-reduce array-any array-every
array-inner-product array-stack array-append array-block array-inner-product array-stack array-append array-block
array->list list->array array->vector vector->array array->list list->array array->vector vector->array
array->list* list*->array array->vector* vector*->array array->list* list*->array array->vector* vector*->array

View file

@ -204,12 +204,12 @@
(error "array-display can't handle > 2 dimensions: " A)))) (error "array-display can't handle > 2 dimensions: " A))))
(define (myindexer= indexer1 indexer2 interval) (define (myindexer= indexer1 indexer2 interval)
(array-foldl (lambda (x y) (and x y)) (array-fold-left (lambda (x y) (and x y))
#t #t
(make-array interval (make-array interval
(lambda args (lambda args
(= (apply indexer1 args) (= (apply indexer1 args)
(apply indexer2 args)))))) (apply indexer2 args))))))
(define (my-indexer base lower-bounds increments) (define (my-indexer base lower-bounds increments)
(lambda indices (lambda indices
@ -221,12 +221,12 @@
(define (myarray= array1 array2) (define (myarray= array1 array2)
(and (interval= (array-domain array1) (and (interval= (array-domain array1)
(array-domain array2)) (array-domain array2))
(array-foldl (lambda (vs result) (array-fold-left (lambda (result vs)
(and (equal? (car vs) (and (equal? (car vs)
(cadr vs)) (cadr vs))
result)) result))
#t #t
(array-map list array1 array2)))) (array-map list array1 array2))))
(define random-storage-class-and-initializer (define random-storage-class-and-initializer
(let* ((storage-classes (let* ((storage-classes
@ -677,13 +677,13 @@
;;(define test-pgm (read-pgm "girl.pgm")) ;;(define test-pgm (read-pgm "girl.pgm"))
(define (array-dot-product a b) (define (array-dot-product a b)
(array-foldl (lambda (x y) (array-fold-left (lambda (x y)
(+ x y)) (+ x y))
0 0
(array-map (array-map
(lambda (x y) (lambda (x y)
(* x y)) (* x y))
a b))) a b)))
(define (array-convolve source filter) (define (array-convolve source filter)
(let* ((source-domain (let* ((source-domain
@ -708,7 +708,7 @@
(make-array (make-array
result-domain result-domain
(lambda (i j) (lambda (i j)
(array-foldl (array-fold-left
(lambda (p q) (lambda (p q)
(+ p q)) (+ p q))
0 0
@ -737,9 +737,9 @@
(max 0 (min (exact (round pixel)) max-grey))) (max 0 (min (exact (round pixel)) max-grey)))
(define (array-sum a) (define (array-sum a)
(array-foldl + 0 a)) (array-fold-left + 0 a))
(define (array-max a) (define (array-max a)
(array-foldl max -inf.0 a)) (array-fold-left max -inf.0 a))
(define (max-norm a) (define (max-norm a)
(array-max (array-map abs a))) (array-max (array-map abs a)))
@ -1921,10 +1921,10 @@
;; (test-assert (indices-in-proper-order (reverse arguments-2))) ;; (test-assert (indices-in-proper-order (reverse arguments-2)))
;; )) ;; ))
(test-error (array-foldl 1 1 1)) (test-error (array-fold-left 1 1 1))
(test-error (array-foldl list 1 1)) (test-error (array-fold-left list 1 1))
(test-error (array-foldr 1 1 1)) (test-error (array-fold-right 1 1 1))
(test-error (array-foldr list 1 1)) (test-error (array-fold-right list 1 1))
(test-error (array-for-each 1 #f)) (test-error (array-for-each 1 #f))
(test-error (array-for-each list 1 (make-array (make-interval '#(3) '#(4)) (test-error (array-for-each list 1 (make-array (make-interval '#(3) '#(4))
list))) list)))
@ -2112,10 +2112,10 @@
0 1) 0 1)
(matrix 1 0 (matrix 1 0
i 1)))))) i 1))))))
(test (array-foldr x2x2-multiply (matrix 1 0 0 1) A) (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A)
(array-reduce x2x2-multiply A)) (array-reduce x2x2-multiply A))
(test-not (equal? (array-reduce x2x2-multiply A) (test (array-reduce x2x2-multiply A)
(array-foldl x2x2-multiply (matrix 1 0 0 1) A)))) (array-fold-left x2x2-multiply (matrix 1 0 0 1) A)))
(let ((A_2 (make-array (make-interval '#(1 1) '#(3 7)) (let ((A_2 (make-array (make-interval '#(1 1) '#(3 7))
(lambda (i j) (lambda (i j)
@ -2124,10 +2124,10 @@
j 1) j 1)
(matrix 1 j (matrix 1 j
i -1)))))) i -1))))))
(test (array-foldr x2x2-multiply (matrix 1 0 0 1) A_2) (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_2)
(array-reduce x2x2-multiply A_2)) (array-reduce x2x2-multiply A_2))
(test-not (equal? (array-reduce x2x2-multiply A_2) (test (array-reduce x2x2-multiply A_2)
(array-foldl x2x2-multiply (matrix 1 0 0 1) A_2))) (array-fold-left x2x2-multiply (matrix 1 0 0 1) A_2))
(test-not (equal? (array-reduce x2x2-multiply A_2) (test-not (equal? (array-reduce x2x2-multiply A_2)
(array-reduce x2x2-multiply (array-rotate A_2 1))))) (array-reduce x2x2-multiply (array-rotate A_2 1)))))
@ -2138,10 +2138,10 @@
j k) j k)
(matrix k j (matrix k j
i -1)))))) i -1))))))
(test (array-foldr x2x2-multiply (matrix 1 0 0 1) A_3) (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_3)
(array-reduce x2x2-multiply A_3)) (array-reduce x2x2-multiply A_3))
(test-not (equal? (array-reduce x2x2-multiply A_3) (test (array-reduce x2x2-multiply A_3)
(array-foldl x2x2-multiply (matrix 1 0 0 1) A_3))) (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_3))
(test-not (equal? (array-reduce x2x2-multiply A_3) (test-not (equal? (array-reduce x2x2-multiply A_3)
(array-reduce x2x2-multiply (array-rotate A_3 1))))) (array-reduce x2x2-multiply (array-rotate A_3 1)))))
@ -2152,10 +2152,10 @@
j k) j k)
(matrix l k (matrix l k
i j)))))) i j))))))
(test (array-foldr x2x2-multiply (matrix 1 0 0 1) A_4) (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_4)
(array-reduce x2x2-multiply A_4)) (array-reduce x2x2-multiply A_4))
(test-not (equal? (array-reduce x2x2-multiply A_4) (test (array-reduce x2x2-multiply A_4)
(array-foldl x2x2-multiply (matrix 1 0 0 1) A_4))) (array-fold-left x2x2-multiply (matrix 1 0 0 1) A_4))
(test-not (equal? (array-reduce x2x2-multiply A_4) (test-not (equal? (array-reduce x2x2-multiply A_4)
(array-reduce x2x2-multiply (array-rotate A_4 1))))) (array-reduce x2x2-multiply (array-rotate A_4 1)))))
@ -2166,10 +2166,10 @@
j k) j k)
(matrix (- l m) k (matrix (- l m) k
i j)))))) i j))))))
(test (array-foldr x2x2-multiply (matrix 1 0 0 1) A_5) (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_5)
(array-reduce x2x2-multiply A_5)) (array-reduce x2x2-multiply A_5))
(test-not (equal? (array-reduce x2x2-multiply A_5) (test (array-reduce x2x2-multiply A_5)
(array-foldl x2x2-multiply (matrix 1 0 0 1) A_5))) (array-fold-left x2x2-multiply (matrix 1 0 0 1) A_5))
(test-not (equal? (array-reduce x2x2-multiply A_5) (test-not (equal? (array-reduce x2x2-multiply A_5)
(array-reduce x2x2-multiply (array-rotate A_5 1))))) (array-reduce x2x2-multiply (array-rotate A_5 1)))))

View file

@ -379,14 +379,46 @@
(apply f (map (lambda (g) (apply g multi-index)) getters)))) (apply f (map (lambda (g) (apply g multi-index)) getters))))
(array-domain array)))) (array-domain array))))
(define (array-foldl kons knil array) (define (array-fold-left operator identity array . arrays)
(interval-fold (lambda (acc . multi-index) (assert (and (procedure? operator)
(kons (apply array-ref array multi-index) acc)) (array? array)
knil (every array? arrays)
(array-domain array))) (every (lambda (a)
(interval= (array-domain array)
(array-domain a)))
arrays)))
(if (null? arrays)
(interval-fold-left (array-getter array)
(lambda (accumulator array-element)
(operator accumulator array-element))
identity
(array-domain array))
(interval-fold-left (array-getter (apply array-map list array arrays))
(lambda (accumulator array-elements)
(apply operator accumulator array-elements))
identity
(array-domain array))))
(define (array-foldr kons knil array) (define (array-fold-right operator identity array . arrays)
(fold-right kons knil (array->list array))) (assert (and (procedure? operator)
(array? array)
(every array? arrays)
(every (lambda (a)
(interval= (array-domain array)
(array-domain a)))
arrays)))
(if (null? arrays)
(interval-fold-right (array-getter array)
(lambda (array-element accumulator)
(operator array-element accumulator))
identity
(array-domain array))
(interval-fold-right
(array-getter (apply array-map list array arrays))
(lambda (array-elements accumulator)
(apply operator (append array-elements (list accumulator))))
identity
(array-domain array))))
(define (array-reduce op array) (define (array-reduce op array)
(let* ((domain (array-domain array)) (let* ((domain (array-domain array))
@ -424,7 +456,7 @@
(array-domain array))))) (array-domain array)))))
(define (array->list array) (define (array->list array)
(reverse (array-foldl cons '() array))) (reverse (array-fold-left xcons '() array)))
(define (list->array domain ls . o) (define (list->array domain ls . o)
(let* ((storage (if (pair? o) (car o) generic-storage-class)) (let* ((storage (if (pair? o) (car o) generic-storage-class))