chibi-scheme/lib/srfi/113/bags.scm
2018-01-24 23:58:30 +09:00

347 lines
11 KiB
Scheme

;; Bag record: a multiset backed by a hash table.
;;   table      - hash table mapping each distinct element to its number
;;                of occurrences (see bag-adjoin!).
;;   comparator - the comparator the table was created with (see `bag`).
;; make-bag is the raw constructor, bag? the type predicate.
(define-record-type Bag (make-bag table comparator) bag?
(table bag-table)
(comparator bag-comparator))
;; Construct a fresh bag using COMPARATOR and add each of ELTS to it,
;; one occurrence per argument.
(define (bag comparator . elts)
  (let ((result (make-bag (make-hash-table comparator) comparator)))
    (let loop ((remaining elts))
      (cond ((null? remaining) result)
            (else
             (bag-adjoin! result (car remaining))
             (loop (cdr remaining)))))))
;; Build a bag by unfolding SEED: while (stop? seed) is false, add
;; (mapper seed) to the bag and continue with (successor seed).
;;
;; Elements are added through bag-adjoin! so that a value produced more
;; than once is counted once per production.  (Storing a fixed count of 1
;; per key, as hash-table-unfold would, collapses duplicates.)
(define (bag-unfold comparator stop? mapper successor seed)
  (let ((res (make-bag (make-hash-table comparator) comparator)))
    (let lp ((seed seed))
      (if (stop? seed)
          res
          (begin
            (bag-adjoin! res (mapper seed))
            (lp (successor seed)))))))
;; True when ELEMENT is a key of BAG's underlying table.
(define (bag-contains? bag element)
  (let ((table (bag-table bag)))
    (hash-table-contains? table element)))
;; True when BAG holds no elements, i.e. its total size is zero.
(define (bag-empty? bag)
  (let ((total (bag-size bag)))
    (zero? total)))
;; True when BAG1 and BAG2 share no element.
;; The table with fewer distinct keys is the one that gets scanned; each
;; of its keys is probed in the other table.
(define (bag-disjoint? bag1 bag2)
  (let* ((table1 (bag-table bag1))
         (table2 (bag-table bag2))
         (swap? (< (hash-table-size table2) (hash-table-size table1)))
         (scanned (if swap? table2 table1))
         (probed (if swap? table1 table2)))
    (not (hash-table-find
          (lambda (key copies) (hash-table-contains? probed key))
          scanned
          (lambda () #f)))))
;; Return ELEMENT if BAG contains it, otherwise DEFAULT.
;; NOTE: the value returned is the ELEMENT argument itself, not the key
;; object stored in the table; returning the stored key would need a
;; cell-level lookup (e.g. a hash-table-cell style accessor).
(define (bag-member bag element default)
  (cond ((hash-table-contains? (bag-table bag) element) element)
        (else default)))
;; Public accessor for the comparator stored in BAG.
(define (bag-element-comparator bag)
  (let ((cmp (bag-comparator bag)))
    cmp))
;; Non-destructive adjoin: copy BAG, then add ELTS to the copy.
(define (bag-adjoin bag . elts)
  (let ((fresh (bag-copy bag)))
    (apply bag-adjoin! fresh elts)))
;; Add one occurrence of each of ELTS to BAG in place and return BAG.
;; An element not yet present starts from a count of 0 before the bump.
(define (bag-adjoin! bag . elts)
  (define (bump n) (+ 1 n))
  (let loop ((remaining elts))
    (cond ((null? remaining) bag)
          (else
           (hash-table-update!/default (bag-table bag)
                                       (car remaining)
                                       bump
                                       0)
           (loop (cdr remaining))))))
;; Non-destructive replace: copy BAG and run bag-replace! on the copy.
(define (bag-replace bag element)
  (let ((fresh (bag-copy bag)))
    (bag-replace! fresh element)))
;; If BAG contains a member equal to ELEMENT, replace the stored key with
;; ELEMENT while keeping its multiplicity; otherwise leave BAG unchanged.
;; Returns BAG.
;;
;; The key is deleted and re-inserted so the table holds the new object.
;; The existing count is carried over rather than being reset to 1, so
;; replacing never changes the size of the bag.
(define (bag-replace! bag element)
  (let* ((ht (bag-table bag))
         (count (hash-table-ref/default ht element #f)))
    (when count
      (hash-table-delete! ht element)
      (hash-table-set! ht element count))
    bag))
;; Variadic front end to bag-delete-all: the rest arguments are the
;; elements to delete.
(define (bag-delete bag . elts)
  (let ((doomed elts))
    (bag-delete-all bag doomed)))
;; Variadic front end to bag-delete-all!: the rest arguments are the
;; elements to delete from BAG in place.
(define (bag-delete! bag . elts)
  (let ((doomed elts))
    (bag-delete-all! bag doomed)))
;; Non-destructive delete: copy BAG and delete ELEMENT-LIST from the copy.
(define (bag-delete-all bag element-list)
  (let ((fresh (bag-copy bag)))
    (bag-delete-all! fresh element-list)))
;; Remove one occurrence from BAG for every item of ELEMENT-LIST and
;; return BAG.  A key whose count reaches zero is dropped from the table;
;; elements that are not present are ignored.
(define (bag-delete-all! bag element-list)
  (let ((table (bag-table bag)))
    (let loop ((remaining element-list))
      (unless (null? remaining)
        (let* ((elt (car remaining))
               (left (- (hash-table-ref/default table elt 0) 1)))
          (cond ((zero? left) (hash-table-delete! table elt))
                ((positive? left) (hash-table-set! table elt left))))
        (loop (cdr remaining))))
    bag))
;; Look ELEMENT up in BAG and hand control to one of two continuations:
;;
;;   (failure insert ignore)          when ELEMENT is absent
;;     (insert obj)  - add ELEMENT with a count of 1
;;     (ignore obj)  - leave BAG untouched
;;
;;   (success element update remove)  when ELEMENT is present
;;     (update new-element obj) - drop ELEMENT's entry, adjoin NEW-ELEMENT
;;     (remove obj)             - drop ELEMENT's entry
;;
;; Every helper returns two values: BAG and OBJ.
;;
;; `success` receives the element itself.  The table's value for the key
;; is the occurrence count, which is not what the caller asked to find.
(define (bag-search! bag element failure success)
  (let ((table (bag-table bag)))
    (if (hash-table-contains? table element)
        (success element
                 (lambda (new-element obj)
                   (hash-table-delete! table element)
                   (bag-adjoin! bag new-element)
                   (values bag obj))
                 (lambda (obj)
                   (hash-table-delete! table element)
                   (values bag obj)))
        (failure (lambda (obj)
                   (hash-table-set! table element 1)
                   (values bag obj))
                 (lambda (obj)
                   (values bag obj))))))
;; Total number of elements in BAG, counting every occurrence: the sum of
;; all counts stored in the table.
;;
;; Uses the (proc seed table) argument order like every other
;; hash-table-fold call in this file; the table-first order is only a
;; deprecated compatibility form.
(define (bag-size bag)
  (hash-table-fold (lambda (elt count acc) (+ count acc))
                   0
                   (bag-table bag)))
;; Return an element of BAG satisfying PREDICATE, escaping from the table
;; walk as soon as one is found.  If the walk finishes without a match,
;; return the result of calling the thunk FAILURE.
(define (bag-find predicate bag failure)
  (call/cc
   (lambda (found)
     (hash-table-for-each
      (lambda (candidate copies)
        (when (predicate candidate)
          (found candidate)))
      (bag-table bag))
     (failure))))
;; Number of elements of BAG (with multiplicity) that satisfy PREDICATE.
(define (bag-count predicate bag)
  (define (tally elt copies total)
    (+ total (if (predicate elt) copies 0)))
  (hash-table-fold tally 0 (bag-table bag)))
;; #t if PREDICATE returns true for some distinct element of BAG,
;; otherwise #f.  The result is always a strict boolean.
(define (bag-any? predicate bag)
  (if (hash-table-find (lambda (elt copies) (predicate elt))
                       (bag-table bag)
                       (lambda () #f))
      #t
      #f))
;; #t if no element of BAG fails PREDICATE; expressed as the negation of
;; bag-any? on the complemented predicate.
(define (bag-every? predicate bag)
  (let ((fails? (lambda (elt) (not (predicate elt)))))
    (not (bag-any? fails? bag))))
;; Build a new bag with COMPARATOR holding (proc elt) for every element
;; occurrence of S.
(define (bag-map comparator proc s)
  (define (add-mapped elt result)
    (bag-adjoin! result (proc elt)))
  (bag-fold add-mapped (bag comparator) s))
;; Call PROC on every element of BAG, once per occurrence (an element
;; with count n is visited n times in a row).
(define (bag-for-each proc bag)
  (hash-table-for-each
   (lambda (elt count)
     (do ((remaining count (- remaining 1)))
         ((not (positive? remaining)))
       (proc elt)))
   (bag-table bag)))
;; Fold PROC over every element occurrence of BAG, starting from NIL.
;; PROC is called as (proc element accumulator), once per occurrence.
(define (bag-fold proc nil bag)
  (hash-table-fold
   (lambda (elt count acc)
     (do ((left count (- left 1))
          (result acc (proc elt result)))
         ((zero? left) result)))
   nil
   (bag-table bag)))
;; Return a new bag, sharing ST's comparator, that holds the element
;; occurrences of ST satisfying PREDICATE.  ST is not modified.
(define (bag-filter predicate st)
  (define (keep elt kept)
    (cond ((predicate elt) (bag-adjoin! kept elt))
          (else kept)))
  (bag-fold keep (bag (bag-comparator st)) st))
;; Linear-update name for bag-filter.  It is bound to the very same
;; procedure, so the argument bag is left unmodified and a new bag is
;; returned.
(define bag-filter! bag-filter)
;; Complement of bag-filter: return a new bag with the elements of BAG
;; that do NOT satisfy PREDICATE.
(define (bag-remove predicate bag)
  (define (rejected? elt)
    (not (predicate elt)))
  (bag-filter rejected? bag))
;; Linear-update name for bag-remove.  It is bound to the very same
;; procedure, so the argument bag is left unmodified and a new bag is
;; returned.
(define bag-remove! bag-remove)
;; Return two values: a bag of the elements of BAG satisfying PREDICATE
;; and a bag of those that do not.  BAG itself is not modified.
(define (bag-partition predicate bag)
  (values (bag-filter predicate bag)
          ;; second value: same as (bag-remove predicate bag)
          (bag-filter (lambda (elt) (not (predicate elt))) bag)))
;; Linear-update name for bag-partition.  It is bound to the very same
;; procedure, so the argument bag is left unmodified and two new bags
;; are returned.
(define bag-partition! bag-partition)
;; Return a new bag with the same comparator as BAG and an independent
;; copy of its table.
;;
;; The copy is explicitly requested as mutable (second argument #t).
;; Without it hash-table-copy is allowed to return an immutable table,
;; while every non-destructive operation in this file (bag-adjoin,
;; bag-union, ...) mutates the result of bag-copy.
(define (bag-copy bag)
  (make-bag (hash-table-copy (bag-table bag) #t)
            (bag-comparator bag)))
;; Return a newly allocated list of the members of BAG, in unspecified
;; order.  An element with multiplicity n appears n times; listing only
;; the table's keys would silently drop the duplicates.
(define (bag->list bag)
  (bag-fold cons '() bag))
;; Build a new bag with COMPARATOR containing every item of LIST.
(define (list->bag comparator list)
  (let ((result (bag comparator)))
    (for-each (lambda (elt) (bag-adjoin! result elt)) list)
    result))
;; Add every item of LIST to BAG in place and return BAG.
(define (list->bag! bag list)
  (for-each (lambda (elt) (bag-adjoin! bag elt)) list)
  bag)
;; Guard used by the multi-bag operations: returns #t when both bags
;; share the very same (eq?) comparator and signals an error otherwise.
(define (comparable-bags? bag1 bag2)
  (if (eq? (bag-comparator bag1) (bag-comparator bag2))
      #t
      (error "can't compare bags with different comparators" bag1 bag2)))
;; True when every adjacent pair of the argument bags is equal, i.e.
;; they contain the same elements with the same multiplicities.
;; Signals an error (via comparable-bags?) on differing comparators.
;;
;; Comparing only total sizes and membership is not enough: {a a b} and
;; {a b b} agree on both.  Two bags are equal when they have the same
;; number of distinct keys and every key of one has the identical count
;; in the other.
(define (bag=? bag1 . bags)
  (or (null? bags)
      (let ((bag2 (car bags)))
        (and (comparable-bags? bag1 bag2)
             (= (hash-table-size (bag-table bag1))
                (hash-table-size (bag-table bag2)))
             (let ((ht1 (bag-table bag1)))
               ;; look for a key of bag2 whose count differs in bag1
               (not (hash-table-find
                     (lambda (elt count)
                       (not (= count (hash-table-ref/default ht1 elt 0))))
                     (bag-table bag2)
                     (lambda () #f))))
             (apply bag=? bags)))))
;; True when each argument bag is a proper sub-bag of the next one.
;; Signals an error (via comparable-bags?) on differing comparators.
;;
;; BAG1 is a proper sub-bag of BAG2 when no element occurs more often in
;; BAG1 than in BAG2 and BAG1 is strictly smaller overall.  Checking
;; membership alone would wrongly accept {a a a} < {a b b b}.
(define (bag<? bag1 . bags)
  (or (null? bags)
      (let ((bag2 (car bags)))
        (and (comparable-bags? bag1 bag2)
             (< (bag-size bag1) (bag-size bag2))
             (let ((ht2 (bag-table bag2)))
               ;; look for a key of bag1 that bag2 has fewer copies of
               (not (hash-table-find
                     (lambda (elt count)
                       (> count (hash-table-ref/default ht2 elt 0)))
                     (bag-table bag1)
                     (lambda () #f))))
             (apply bag<? bags)))))
;; Proper super-bag test: same as applying the proper sub-bag test to
;; the arguments in reverse order.
(define (bag>? . bags)
  (let ((reversed (reverse bags)))
    (apply bag<? reversed)))
;; True when each argument bag is a sub-bag of (or equal to) the next.
;; Signals an error (via comparable-bags?) on differing comparators.
;;
;; BAG1 is a sub-bag of BAG2 when no element occurs more often in BAG1
;; than in BAG2.  Checking membership alone would wrongly accept
;; {a a a} <= {a b b}.
(define (bag<=? bag1 . bags)
  (or (null? bags)
      (let ((bag2 (car bags)))
        (and (comparable-bags? bag1 bag2)
             (let ((ht2 (bag-table bag2)))
               ;; look for a key of bag1 that bag2 has fewer copies of
               (not (hash-table-find
                     (lambda (elt count)
                       (> count (hash-table-ref/default ht2 elt 0)))
                     (bag-table bag1)
                     (lambda () #f))))
             (apply bag<=? bags)))))
;; Super-bag test: same as applying the sub-bag test to the arguments in
;; reverse order.
(define (bag>=? . bags)
  (let ((reversed (reverse bags)))
    (apply bag<=? reversed)))
;; Non-destructive union: copy BAG1 and merge BAGS into the copy.
(define (bag-union bag1 . bags)
  (let ((fresh (bag-copy bag1)))
    (apply bag-union! fresh bags)))
;; Non-destructive intersection: copy BAG1 and intersect the copy with
;; BAGS.
(define (bag-intersection bag1 . bags)
  (let ((fresh (bag-copy bag1)))
    (apply bag-intersection! fresh bags)))
;; Non-destructive difference: copy BAG1 and subtract BAGS from the copy.
(define (bag-difference bag1 . bags)
  (let ((fresh (bag-copy bag1)))
    (apply bag-difference! fresh bags)))
;; Non-destructive xor: copy BAG1 and xor the copy with BAG2.
(define (bag-xor bag1 bag2)
  (let ((fresh (bag-copy bag1)))
    (bag-xor! fresh bag2)))
;; Merge each of BAGS into BAG1 in place and return BAG1.  After merging
;; a bag, every element's count in BAG1 is the larger of its previous
;; count and its count in that bag.  Each bag is checked with
;; comparable-bags? before it is merged.
(define (bag-union! bag1 . bags)
  (let loop ((remaining bags))
    (cond ((null? remaining) bag1)
          (else
           (let ((other (car remaining)))
             (comparable-bags? bag1 other) ; raises on comparator mismatch
             (hash-table-for-each
              (lambda (elt count)
                (hash-table-update!/default (bag-table bag1)
                                            elt
                                            (lambda (existing) (max existing count))
                                            count))
              (bag-table other))
             (loop (cdr remaining)))))))
;; Intersect BAG1 in place with each of BAGS and return BAG1.  After
;; processing a bag, every element's count in BAG1 is the smaller of its
;; previous count and its count in that bag; elements whose count drops
;; to zero are removed.  Signals an error (via comparable-bags?) on
;; differing comparators.
;;
;; The entries of BAG1 are snapshotted with hash-table->alist before the
;; loop so that keys are never deleted from a table that is currently
;; being traversed (the same precaution bag-product! takes).
(define (bag-intersection! bag1 . bags)
  (if (null? bags)
      bag1
      (and (comparable-bags? bag1 (car bags))
           (let ((ht1 (bag-table bag1))
                 (ht2 (bag-table (car bags))))
             (for-each
              (lambda (entry)
                (let* ((elt (car entry))
                       (count2 (min (cdr entry)
                                    (hash-table-ref/default ht2 elt 0))))
                  (if (positive? count2)
                      (hash-table-set! ht1 elt count2)
                      (hash-table-delete! ht1 elt))))
              (hash-table->alist ht1))
             (apply bag-intersection! bag1 (cdr bags))))))
;; Subtract each of BAGS from BAG1 in place and return BAG1.  For every
;; element of BAG1 the count found in the other bag is subtracted;
;; elements whose count drops to zero or below are removed.  Signals an
;; error (via comparable-bags?) on differing comparators.
;;
;; The entries of BAG1 are snapshotted with hash-table->alist before the
;; loop so that keys are never deleted from a table that is currently
;; being traversed (the same precaution bag-product! takes).
(define (bag-difference! bag1 . bags)
  (if (null? bags)
      bag1
      (and (comparable-bags? bag1 (car bags))
           (let ((ht1 (bag-table bag1))
                 (ht2 (bag-table (car bags))))
             (for-each
              (lambda (entry)
                (let* ((elt (car entry))
                       (count2 (- (cdr entry)
                                  (hash-table-ref/default ht2 elt 0))))
                  (if (positive? count2)
                      (hash-table-set! ht1 elt count2)
                      (hash-table-delete! ht1 elt))))
              (hash-table->alist ht1))
             (apply bag-difference! bag1 (cdr bags))))))
;; Symmetric difference in place: for every element of BAG2, BAG1's count
;; becomes the absolute difference of the two counts (the entry is
;; removed when that difference is zero).  Elements found only in BAG1
;; are untouched.  Returns BAG1; comparable-bags? raises on a comparator
;; mismatch.
(define (bag-xor! bag1 bag2)
  (and (comparable-bags? bag1 bag2)
       (let ((target (bag-table bag1)))
         (hash-table-for-each
          (lambda (elt other-count)
            (let* ((own-count (hash-table-ref/default target elt 0))
                   (delta (abs (- other-count own-count))))
              (cond ((positive? delta) (hash-table-set! target elt delta))
                    (else (hash-table-delete! target elt)))))
          (bag-table bag2))
         bag1)))
;; Non-destructive sum: copy BAG1 and add BAGS into the copy.
(define (bag-sum bag1 . bags)
  (let ((fresh (bag-copy bag1)))
    (apply bag-sum! fresh bags)))
;; Add each of BAGS into BAG1 in place and return BAG1: every element's
;; count in BAG1 grows by its count in each added bag.  Each bag is
;; checked with comparable-bags? before it is added.
(define (bag-sum! bag1 . bags)
  (let loop ((remaining bags))
    (cond ((null? remaining) bag1)
          (else
           (let ((other (car remaining)))
             (comparable-bags? bag1 other) ; raises on comparator mismatch
             (hash-table-for-each
              (lambda (elt count)
                (hash-table-update!/default (bag-table bag1)
                                            elt
                                            (lambda (existing) (+ existing count))
                                            count))
              (bag-table other))
             (loop (cdr remaining)))))))
;; Non-destructive product: copy BAG and scale the copy's counts by N.
(define (bag-product n bag)
  (let ((fresh (bag-copy bag)))
    (bag-product! n fresh)))
;; Multiply the count of every element of BAG by N in place and return
;; BAG.
;;
;; The keys are snapshotted first so the table is not modified while it
;; is being traversed.  When the scaled count is not positive (N = 0) the
;; entry is deleted instead of being stored with a count of zero, which
;; would leave the element visible to bag-contains? and bag-unique-size.
(define (bag-product! n bag)
  (let ((ht (bag-table bag)))
    (for-each
     (lambda (elt)
       (let ((count (* n (hash-table-ref/default ht elt 0))))
         (if (positive? count)
             (hash-table-set! ht elt count)
             (hash-table-delete! ht elt))))
     (hash-table-keys ht))
    bag))
;; Number of distinct elements in BAG (the number of keys in its table).
(define (bag-unique-size bag)
  (let ((table (bag-table bag)))
    (hash-table-size table)))
;; Multiplicity of ELEMENT in BAG; 0 when it is not present.
(define (bag-element-count bag element)
  (let ((table (bag-table bag)))
    (hash-table-ref/default table element 0)))
;; Call (proc element count) once for each distinct element of BAG.
(define (bag-for-each-unique proc bag)
  (let ((table (bag-table bag)))
    (hash-table-for-each proc table)))
;; Fold (proc element count accumulator) over the distinct elements of
;; BAG, starting from NIL.
(define (bag-fold-unique proc nil bag)
  (let ((table (bag-table bag)))
    (hash-table-fold proc nil table)))
;; Change the multiplicity of ELEMENT in BAG by COUNT (which may be
;; negative) and return BAG.  If the resulting count is not positive the
;; element is removed, so a count never drops below zero.
;;
;; The bag is returned explicitly: this is a linear-update procedure and
;; callers are entitled to use its result, whereas the result of
;; hash-table-set! / hash-table-delete! is not the bag.
(define (bag-increment! bag element count)
  (let* ((ht (bag-table bag))
         (count2 (+ count (hash-table-ref/default ht element 0))))
    (if (positive? count2)
        (hash-table-set! ht element count2)
        (hash-table-delete! ht element))
    bag))
;; Lower the multiplicity of ELEMENT in BAG by COUNT: delegates to
;; bag-increment! with the negated amount.
(define (bag-decrement! bag element count)
  (let ((delta (- count)))
    (bag-increment! bag element delta)))
;; Return a set with BAG's comparator containing the distinct elements
;; of BAG.  The bag's table is copied and every value (a count) is
;; replaced by its key before the table is handed to make-set.
;;
;; The copy is explicitly requested as mutable (second argument #t):
;; hash-table-map! modifies it in place, and hash-table-copy is otherwise
;; allowed to return an immutable table.
(define (bag->set bag)
  (let ((ht (hash-table-copy (bag-table bag) #t)))
    (hash-table-map! (lambda (key count) key) ht)
    (make-set ht (bag-comparator bag))))
;; Return a new bag, using SET's comparator, filled from SET via
;; set->bag!.
(define (set->bag set)
  (let ((fresh (bag (set-comparator set))))
    (set->bag! fresh set)))
;; Adjoin every element visited by set-for-each on SET to BAG, in place,
;; and return BAG.
(define (set->bag! bag set)
  (define (add! member)
    (bag-adjoin! bag member))
  (set-for-each add! set)
  bag)
;; Return BAG's table as an association list of (element . count) pairs.
(define (bag->alist bag)
  (let ((table (bag-table bag)))
    (hash-table->alist table)))
;; Build a new bag with COMPARATOR from ALIST, whose entries are
;; (element . count) pairs; each entry is applied with bag-increment!.
(define (alist->bag comparator alist)
  (let ((result (bag comparator)))
    (let loop ((entries alist))
      (cond ((null? entries) result)
            (else
             (let ((entry (car entries)))
               (bag-increment! result (car entry) (cdr entry)))
             (loop (cdr entries)))))))
;; Comparator object for bags themselves, built by passing bag? (type
;; test), bag=? (equality), bag<? (ordering) and `hash` (hash function)
;; to make-comparator.
(define the-bag-comparator
(make-comparator bag? bag=? bag<? hash))