From 36d7881763353865f6f9f6d6d611182fb2c8f0d7 Mon Sep 17 00:00:00 2001 From: Alex Shinn Date: Tue, 28 May 2024 22:35:31 +0900 Subject: [PATCH] Update to n-ary array-fold-* with correct arg order. Closes #973. --- lib/srfi/231.sld | 4 +- lib/srfi/231/test.sld | 82 ++++++++++++++++++------------------- lib/srfi/231/transforms.scm | 48 ++++++++++++++++++---- 3 files changed, 83 insertions(+), 51 deletions(-) diff --git a/lib/srfi/231.sld b/lib/srfi/231.sld index fbb1101c..3ed5d57f 100644 --- a/lib/srfi/231.sld +++ b/lib/srfi/231.sld @@ -45,8 +45,8 @@ array-safe? array-packed? specialized-array-share array-copy array-curry array-extract array-tile array-translate array-permute array-reverse array-sample - array-outer-product array-map array-for-each array-foldl - array-foldr array-reduce array-any array-every + array-outer-product array-map array-for-each array-fold-left + array-fold-right array-reduce array-any array-every array-inner-product array-stack array-append array-block array->list list->array array->vector vector->array array->list* list*->array array->vector* vector*->array diff --git a/lib/srfi/231/test.sld b/lib/srfi/231/test.sld index 9e427178..afff5f4c 100644 --- a/lib/srfi/231/test.sld +++ b/lib/srfi/231/test.sld @@ -204,12 +204,12 @@ (error "array-display can't handle > 2 dimensions: " A)))) (define (myindexer= indexer1 indexer2 interval) - (array-foldl (lambda (x y) (and x y)) - #t - (make-array interval - (lambda args - (= (apply indexer1 args) - (apply indexer2 args)))))) + (array-fold-left (lambda (x y) (and x y)) + #t + (make-array interval + (lambda args + (= (apply indexer1 args) + (apply indexer2 args)))))) (define (my-indexer base lower-bounds increments) (lambda indices @@ -221,12 +221,12 @@ (define (myarray= array1 array2) (and (interval= (array-domain array1) (array-domain array2)) - (array-foldl (lambda (vs result) - (and (equal? (car vs) - (cadr vs)) - result)) - #t - (array-map list array1 array2)))) + (array-fold-left (lambda (result vs) + (and (equal? (car vs) + (cadr vs)) + result)) + #t + (array-map list array1 array2)))) (define random-storage-class-and-initializer (let* ((storage-classes @@ -677,13 +677,13 @@ ;;(define test-pgm (read-pgm "girl.pgm")) (define (array-dot-product a b) - (array-foldl (lambda (x y) - (+ x y)) - 0 - (array-map - (lambda (x y) - (* x y)) - a b))) + (array-fold-left (lambda (x y) + (+ x y)) + 0 + (array-map + (lambda (x y) + (* x y)) + a b))) (define (array-convolve source filter) (let* ((source-domain @@ -708,7 +708,7 @@ (make-array result-domain (lambda (i j) - (array-foldl + (array-fold-left (lambda (p q) (+ p q)) 0 @@ -737,9 +737,9 @@ (max 0 (min (exact (round pixel)) max-grey))) (define (array-sum a) - (array-foldl + 0 a)) + (array-fold-left + 0 a)) (define (array-max a) - (array-foldl max -inf.0 a)) + (array-fold-left max -inf.0 a)) (define (max-norm a) (array-max (array-map abs a))) @@ -1921,10 +1921,10 @@ ;; (test-assert (indices-in-proper-order (reverse arguments-2))) ;; )) - (test-error (array-foldl 1 1 1)) - (test-error (array-foldl list 1 1)) - (test-error (array-foldr 1 1 1)) - (test-error (array-foldr list 1 1)) + (test-error (array-fold-left 1 1 1)) + (test-error (array-fold-left list 1 1)) + (test-error (array-fold-right 1 1 1)) + (test-error (array-fold-right list 1 1)) (test-error (array-for-each 1 #f)) (test-error (array-for-each list 1 (make-array (make-interval '#(3) '#(4)) list))) @@ -2112,10 +2112,10 @@ 0 1) (matrix 1 0 i 1)))))) - (test (array-foldr x2x2-multiply (matrix 1 0 0 1) A) + (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A) (array-reduce x2x2-multiply A)) - (test-not (equal? (array-reduce x2x2-multiply A) - (array-foldl x2x2-multiply (matrix 1 0 0 1) A)))) + (test (array-reduce x2x2-multiply A) + (array-fold-left x2x2-multiply (matrix 1 0 0 1) A))) (let ((A_2 (make-array (make-interval '#(1 1) '#(3 7)) (lambda (i j) @@ -2124,10 +2124,10 @@ j 1) (matrix 1 j i -1)))))) - (test (array-foldr x2x2-multiply (matrix 1 0 0 1) A_2) + (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_2) (array-reduce x2x2-multiply A_2)) - (test-not (equal? (array-reduce x2x2-multiply A_2) - (array-foldl x2x2-multiply (matrix 1 0 0 1) A_2))) + (test (array-reduce x2x2-multiply A_2) + (array-fold-left x2x2-multiply (matrix 1 0 0 1) A_2)) (test-not (equal? (array-reduce x2x2-multiply A_2) (array-reduce x2x2-multiply (array-rotate A_2 1))))) @@ -2138,10 +2138,10 @@ j k) (matrix k j i -1)))))) - (test (array-foldr x2x2-multiply (matrix 1 0 0 1) A_3) + (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_3) (array-reduce x2x2-multiply A_3)) - (test-not (equal? (array-reduce x2x2-multiply A_3) - (array-foldl x2x2-multiply (matrix 1 0 0 1) A_3))) + (test (array-reduce x2x2-multiply A_3) + (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_3)) (test-not (equal? (array-reduce x2x2-multiply A_3) (array-reduce x2x2-multiply (array-rotate A_3 1))))) @@ -2152,10 +2152,10 @@ j k) (matrix l k i j)))))) - (test (array-foldr x2x2-multiply (matrix 1 0 0 1) A_4) + (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_4) (array-reduce x2x2-multiply A_4)) - (test-not (equal? (array-reduce x2x2-multiply A_4) - (array-foldl x2x2-multiply (matrix 1 0 0 1) A_4))) + (test (array-reduce x2x2-multiply A_4) + (array-fold-left x2x2-multiply (matrix 1 0 0 1) A_4)) (test-not (equal? (array-reduce x2x2-multiply A_4) (array-reduce x2x2-multiply (array-rotate A_4 1))))) @@ -2166,10 +2166,10 @@ j k) (matrix (- l m) k i j)))))) - (test (array-foldr x2x2-multiply (matrix 1 0 0 1) A_5) + (test (array-fold-right x2x2-multiply (matrix 1 0 0 1) A_5) (array-reduce x2x2-multiply A_5)) - (test-not (equal? (array-reduce x2x2-multiply A_5) - (array-foldl x2x2-multiply (matrix 1 0 0 1) A_5))) + (test (array-reduce x2x2-multiply A_5) + (array-fold-left x2x2-multiply (matrix 1 0 0 1) A_5)) (test-not (equal? (array-reduce x2x2-multiply A_5) (array-reduce x2x2-multiply (array-rotate A_5 1))))) diff --git a/lib/srfi/231/transforms.scm b/lib/srfi/231/transforms.scm index 8c0686ab..e3a97955 100644 --- a/lib/srfi/231/transforms.scm +++ b/lib/srfi/231/transforms.scm @@ -379,14 +379,46 @@ (apply f (map (lambda (g) (apply g multi-index)) getters)))) (array-domain array)))) -(define (array-foldl kons knil array) - (interval-fold (lambda (acc . multi-index) - (kons (apply array-ref array multi-index) acc)) - knil - (array-domain array))) +(define (array-fold-left operator identity array . arrays) + (assert (and (procedure? operator) + (array? array) + (every array? arrays) + (every (lambda (a) + (interval= (array-domain array) + (array-domain a))) + arrays))) + (if (null? arrays) + (interval-fold-left (array-getter array) + (lambda (accumulator array-element) + (operator accumulator array-element)) + identity + (array-domain array)) + (interval-fold-left (array-getter (apply array-map list array arrays)) + (lambda (accumulator array-elements) + (apply operator accumulator array-elements)) + identity + (array-domain array)))) -(define (array-foldr kons knil array) - (fold-right kons knil (array->list array))) +(define (array-fold-right operator identity array . arrays) + (assert (and (procedure? operator) + (array? array) + (every array? arrays) + (every (lambda (a) + (interval= (array-domain array) + (array-domain a))) + arrays))) + (if (null? arrays) + (interval-fold-right (array-getter array) + (lambda (array-element accumulator) + (operator array-element accumulator)) + identity + (array-domain array)) + (interval-fold-right + (array-getter (apply array-map list array arrays)) + (lambda (array-elements accumulator) + (apply operator (append array-elements (list accumulator)))) + identity + (array-domain array)))) (define (array-reduce op array) (let* ((domain (array-domain array)) @@ -424,7 +456,7 @@ (array-domain array))))) (define (array->list array) - (reverse (array-foldl cons '() array))) + (reverse (array-fold-left xcons '() array))) (define (list->array domain ls . o) (let* ((storage (if (pair? o) (car o) generic-storage-class))