add docs and tests for assert, unify with SRFI 145

This commit is contained in:
Alex Shinn 2021-05-06 20:34:03 +09:00
parent e3782e35a5
commit 12ad1d37d8
5 changed files with 152 additions and 81 deletions

30
lib/chibi/assert-test.sld Normal file
View file

@ -0,0 +1,30 @@
(define-library (chibi assert-test)
(import (chibi) (chibi assert) (chibi test))
(export run-tests)
(begin
(define-syntax test-assert
(syntax-rules ()
((test-assert irritants expr)
(protect (exn
(else
(test irritants (exception-irritants exn))))
expr
(error "assertion not triggered")))))
(define (run-tests)
(test-begin "assert")
(test-assert '((= x (+ x 1))
(x 3))
(let ((x 3)) (assert (= x (+ x 1)))))
(test-assert '((= x (+ y 1))
(x 3)
(y 42))
(let ((x 3) (y 42)) (assert (= x (+ y 1)))))
(test-assert '((eq? x 'three)
(x 3))
(let ((x 3)) (assert (eq? x 'three))))
(test-assert '((eq? x 'three)
"expected three: "
3)
(let ((x 3)) (assert (eq? x 'three) "expected three: " x)))
(test-end))))

View file

@ -1,3 +1,39 @@
;;> A nice assert macro.
;;>
;;> Assert macros are common in Scheme, in particular being helpful
;;> for domain checks at the beginning of a procedure to catch errors
;;> as early as possible. Compared to statically typed languages this
;;> has the advantages that the assertions are optional, and that they
;;> are not limited by the type system. SRFI 145 provides the related
;;> notion of assumptions, but the motivation there is to provide
;;> hints to optimizing compilers, and these are not required to
;;> actually signal an error.
;;>
;;> \macro{(assert expr [msg ...])}
;;>
;;> Equivalent to SRFI 145 \code{assume} except that an error is
;;> guaranteed to be raised if \var{expr} is false. Conceptually
;;> shorthand for
;;>
;;> \code{(or \var{expr}
;;> (error "assertion failed" \var{msg} ...))}
;;>
;;> that is, evaluates \var{expr} and returns it if true, but raises
;;> an exception otherwise. The error is augmented to include the
;;> text of the failed \var{expr}. If no additional \var{msg}
;;> arguments are provided then \var{expr} is scanned for free
;;> variables in non-operator positions to report values from, e.g. in
;;>
;;> \code{(let ((x 3))
;;> (assert (= x (+ x 1))))}
;;>
;;> the error would also report the bound value of \code{x}. This
;;> uses the technique from Oleg Kiselyov's \hyperlink[http://okmij.org/ftp/Scheme/assert-syntax-rule.txt]{good assert macro},
;;> which is convenient but fallible. It is thus best to keep the
;;> body of the assertion simple, moving any predicates you need to
;;> external utilities, or provide an explicit \var{msg}.
(define-library (chibi assert) (define-library (chibi assert)
(export assert) (export assert)
(cond-expand (cond-expand
@ -10,11 +46,11 @@
(if (identifier? (cadr expr)) (if (identifier? (cadr expr))
(car (cddr expr)) (car (cddr expr))
(cadr (cddr expr)))))) (cadr (cddr expr))))))
(define-syntax syntax-memq? (define-syntax syntax-id-memq?
(er-macro-transformer (er-macro-transformer
(lambda (expr rename compare) (lambda (expr rename compare)
(let ((expr (cdr expr))) (let ((expr (cdr expr)))
(if (memq (car expr) (cadr expr)) (if (any (lambda (x) (compare x (car expr))) (cadr expr))
(car (cddr expr)) (car (cddr expr))
(cadr (cddr expr))))))))) (cadr (cddr expr)))))))))
(else (else
@ -32,7 +68,7 @@
((sym? x sk fk) sk) ((sym? x sk fk) sk)
((sym? y sk fk) fk)))) ((sym? y sk fk) fk))))
(sym? abracadabra success-k failure-k))))) (sym? abracadabra success-k failure-k)))))
(define-syntax syntax-memq? (define-syntax syntax-id-memq?
(syntax-rules () (syntax-rules ()
((syntax-memq? id (ids ...) sk fk) ((syntax-memq? id (ids ...) sk fk)
(let-syntax (let-syntax
@ -42,35 +78,38 @@
((memq? any-other sk2 fk2) sk2)))) ((memq? any-other sk2 fk2) sk2))))
(memq? random-symbol-to-match sk fk)))))))) (memq? random-symbol-to-match sk fk))))))))
(begin (begin
(define-syntax report-vars (define-syntax extract-vars
(syntax-rules () (syntax-rules ()
((report-vars (op arg0 arg1 ...) (next ...) res) ((report-vars (op arg0 arg1 ...) (next ...) res)
(syntax-memq? op (quote quasiquote lambda let let* letrec letrec* (syntax-id-memq? op (quote quasiquote lambda let let* letrec letrec*
let-syntax letrec-syntax let-values let*-values let-syntax letrec-syntax let-values let*-values
receive match case define define-syntax do) receive match case define define-syntax do)
(next ... res) (next ... res)
(report-vars arg0 (extract-vars arg0
(report-vars (op arg1 ...) (next ...)) (extract-vars (op arg1 ...) (next ...))
res))) res)))
((report-vars (op . x) (next ...) res) ((report-vars (op . x) (next ...) res)
(next ... res)) (next ... res))
((report-vars x (next ...) (res ...)) ((report-vars x (next ...) (res ...))
(syntax-identifier? x (syntax-identifier? x
(syntax-memq? x (res ...) (syntax-id-memq? x (res ...)
(next ... (res ...)) (next ... (res ...))
(next ... (res ... x))) (next ... (res ... x)))
(next ... (res ...)))))) (next ... (res ...))))))
(define-syntax qq-vars
(syntax-rules ()
((qq-vars (next ...) (var ...))
(next ... `(var ,var) ...))))
(define-syntax report-final (define-syntax report-final
(syntax-rules () (syntax-rules ()
((report-final expr (var ...)) ((report-final expr msg ...)
(error "assertion failed" 'expr `(var ,var) ...)))) (error "assertion failed" 'expr msg ...))))
(define-syntax assert (define-syntax assert
(syntax-rules (report:) (syntax-rules ()
((assert test report: msg ...) ((assert test)
(unless test (or test
(error msg ...))) (extract-vars test (qq-vars (report-final test)) ())))
((assert test0 test1 ...) ((assert test msg ...)
(if test0 (or test
(assert test1 ...) (report-final test msg ...)))
(report-vars test0 (report-final test0) ())))
((assert) #t))))) ((assert) #t)))))

View file

@ -1,6 +1,6 @@
(define-library (srfi 145) (define-library (srfi 145)
(export assume) (export assume)
(import (scheme base)) (import (scheme base) (chibi assert))
(cond-expand (cond-expand
((or elide-assumptions ((or elide-assumptions
(and (not assumptions) (and (not assumptions)
@ -17,7 +17,6 @@
(define-syntax assume (define-syntax assume
(syntax-rules () (syntax-rules ()
((assume expression objs ...) ((assume expression objs ...)
(or expression (assert expression objs ...))
(error "invalid assumption" 'expression objs ...)))
((assume) ((assume)
(syntax-error "assume requires an expression")))))))) (syntax-error "assume requires an expression"))))))))

View file

@ -30,12 +30,12 @@
(ub interval-ub)) (ub interval-ub))
(define (%make-interval lo hi) (define (%make-interval lo hi)
(assert (translation? lo) (assert (and (translation? lo)
(translation? hi) (translation? hi)
(not (vector-empty? lo)) (not (vector-empty? lo))
(not (vector-empty? hi)) (not (vector-empty? hi))
(= (vector-length lo) (vector-length hi)) (= (vector-length lo) (vector-length hi))
(vector-every < lo hi)) (vector-every < lo hi)))
(%%make-interval lo hi)) (%%make-interval lo hi))
(define (make-interval x . o) (define (make-interval x . o)
@ -54,7 +54,7 @@
(define (interval-upper-bounds->vector iv) (vector-copy (interval-ub iv))) (define (interval-upper-bounds->vector iv) (vector-copy (interval-ub iv)))
(define (interval= iv1 iv2) (define (interval= iv1 iv2)
(assert (interval? iv1) (interval? iv2)) (assert (and (interval? iv1) (interval? iv2)))
(equal? iv1 iv2)) (equal? iv1 iv2))
(define (interval-volume iv) (define (interval-volume iv)
@ -63,16 +63,16 @@
(interval-lb iv) (interval-ub iv))) (interval-lb iv) (interval-ub iv)))
(define (interval-subset? iv1 iv2) (define (interval-subset? iv1 iv2)
(assert (interval? iv1) (interval? iv2) (assert (and (interval? iv1) (interval? iv2)
(= (interval-dimension iv1) (interval-dimension iv2))) (= (interval-dimension iv1) (interval-dimension iv2))))
(and (vector-every >= (interval-lb iv1) (interval-lb iv2)) (and (vector-every >= (interval-lb iv1) (interval-lb iv2))
(vector-every <= (interval-ub iv1) (interval-ub iv2)))) (vector-every <= (interval-ub iv1) (interval-ub iv2))))
(define (interval-contains-multi-index? iv i0 . o) (define (interval-contains-multi-index? iv i0 . o)
(assert (interval? iv)) (assert (interval? iv))
(let ((i (list->vector (cons i0 o)))) (let ((i (list->vector (cons i0 o))))
(assert (= (interval-dimension iv) (vector-length i)) (assert (and (= (interval-dimension iv) (vector-length i))
(vector-every integer? i)) (vector-every integer? i)))
(and (vector-every >= i (interval-lb iv)) (and (vector-every >= i (interval-lb iv))
(vector-every < i (interval-ub iv))))) (vector-every < i (interval-ub iv)))))
@ -136,8 +136,8 @@
(define (interval-intersect iv0 . o) (define (interval-intersect iv0 . o)
(let ((ls (cons iv0 o))) (let ((ls (cons iv0 o)))
(assert (every interval? ls) (assert (and (every interval? ls)
(or (null? o) (apply = (map interval-dimension ls)))) (or (null? o) (apply = (map interval-dimension ls)))))
(let ((lower (apply vector-map max (map interval-lb ls))) (let ((lower (apply vector-map max (map interval-lb ls)))
(upper (apply vector-map min (map interval-ub ls)))) (upper (apply vector-map min (map interval-ub ls))))
(and (vector-every < lower upper) (and (vector-every < lower upper)
@ -148,7 +148,7 @@
(interval-dilate iv translation translation)) (interval-dilate iv translation translation))
(define (interval-permute iv perm) (define (interval-permute iv perm)
(assert (interval? iv) (permutation? perm)) (assert (and (interval? iv) (permutation? perm)))
(let* ((len (interval-dimension iv)) (let* ((len (interval-dimension iv))
(lower (make-vector len)) (lower (make-vector len))
(upper (make-vector len))) (upper (make-vector len)))
@ -167,11 +167,11 @@
(vector-copy upper 0 dim))))) (vector-copy upper 0 dim)))))
(define (interval-scale iv scales) (define (interval-scale iv scales)
(assert (interval? iv) (assert (and (interval? iv)
(vector? scales) (vector? scales)
(= (interval-dimension iv) (vector-length scales)) (= (interval-dimension iv) (vector-length scales))
(vector-every exact-integer? scales) (vector-every exact-integer? scales)
(vector-every positive? scales)) (vector-every positive? scales)))
(make-interval (make-interval
(vector-map (lambda (u s) (exact (ceiling (/ u s)))) (vector-map (lambda (u s) (exact (ceiling (/ u s))))
(interval-ub iv) (interval-ub iv)
@ -273,14 +273,14 @@
(safe? array-safe?)) (safe? array-safe?))
(define (%make-array domain getter setter storage body coeffs indexer safe?) (define (%make-array domain getter setter storage body coeffs indexer safe?)
(assert (interval? domain) (assert (and (interval? domain)
(procedure? getter) (procedure? getter)
(or (not setter) (procedure? setter)) (or (not setter) (procedure? setter))
(or (not storage) (storage-class? storage))) (or (not storage) (storage-class? storage))))
(%%make-array domain getter setter storage body coeffs indexer safe?)) (%%make-array domain getter setter storage body coeffs indexer safe?))
(define (make-array domain getter . o) (define (make-array domain getter . o)
(assert (interval? domain) (procedure? getter)) (assert (and (interval? domain) (procedure? getter)))
(%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f)) (%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f))
(define (array-dimension a) (define (array-dimension a)
@ -483,7 +483,7 @@
#t)))) #t))))
(define (specialized-array-share array new-domain project) (define (specialized-array-share array new-domain project)
(assert (specialized-array? array) (interval? new-domain)) (assert (and (specialized-array? array) (interval? new-domain)))
(let* ((body (array-body array)) (let* ((body (array-body array))
(coeffs (coeffs
(indexer->coeffs (indexer->coeffs
@ -509,8 +509,8 @@
(mutable? (if (pair? o) (car o) (specialized-array-default-mutable?))) (mutable? (if (pair? o) (car o) (specialized-array-default-mutable?)))
(o (if (pair? o) (cdr o) '())) (o (if (pair? o) (cdr o) '()))
(safe? (if (pair? o) (car o) (specialized-array-default-safe?)))) (safe? (if (pair? o) (car o) (specialized-array-default-safe?))))
(assert (storage-class? storage) (interval? new-domain) (assert (and (storage-class? storage) (interval? new-domain)
(boolean? mutable?) (boolean? safe?)) (boolean? mutable?) (boolean? safe?)))
(let* ((body ((storage-class-maker storage) (let* ((body ((storage-class-maker storage)
(interval-volume new-domain) (interval-volume new-domain)
(storage-class-default storage))) (storage-class-default storage)))
@ -556,9 +556,9 @@
)))))))))) ))))))))))
(define (array-extract array new-domain) (define (array-extract array new-domain)
(assert (array? array) (assert (and (array? array)
(interval? new-domain) (interval? new-domain)
(interval-subset? new-domain (array-domain array))) (interval-subset? new-domain (array-domain array))))
(if (specialized-array? array) (if (specialized-array? array)
(specialized-array-share array new-domain values) (specialized-array-share array new-domain values)
(make-array new-domain (make-array new-domain
@ -566,14 +566,14 @@
(array-setter array)))) (array-setter array))))
(define (array-tile array sizes) (define (array-tile array sizes)
(assert (array? array) (assert (and (array? array)
(vector? sizes) (vector? sizes)
(= (array-dimension array) (vector-length sizes)) (= (array-dimension array) (vector-length sizes))
(vector-every exact-integer? sizes) (vector-every exact-integer? sizes)
(vector-every >= sizes (interval-lower-bounds->vector (vector-every >= sizes (interval-lower-bounds->vector
(array-domain array))) (array-domain array)))
(vector-every < sizes (interval-upper-bounds->vector (vector-every < sizes (interval-upper-bounds->vector
(array-domain array)))) (array-domain array)))))
(let ((domain (make-interval (let ((domain (make-interval
(vector-map (vector-map
(lambda (lo hi s) (exact (ceiling (/ (- hi lo) s)))) (lambda (lo hi s) (exact (ceiling (/ (- hi lo) s))))
@ -662,9 +662,9 @@
(define (array-reverse array . o) (define (array-reverse array . o)
(assert (array? array)) (assert (array? array))
(let ((flip? (if (pair? o) (car o) (make-vector (array-dimension array) #t)))) (let ((flip? (if (pair? o) (car o) (make-vector (array-dimension array) #t))))
(assert (vector? flip?) (assert (and (vector? flip?)
(= (array-dimension array) (vector-length flip?)) (= (array-dimension array) (vector-length flip?))
(vector-every boolean? flip?)) (vector-every boolean? flip?)))
(let* ((flips (vector->list flip?)) (let* ((flips (vector->list flip?))
(domain (array-domain array)) (domain (array-domain array))
(lowers (interval-lower-bounds->list domain)) (lowers (interval-lower-bounds->list domain))
@ -715,7 +715,7 @@
(apply array-set! array val (map * multi-index scales-ls)))))))) (apply array-set! array val (map * multi-index scales-ls))))))))
(define (array-outer-product op array1 array2) (define (array-outer-product op array1 array2)
(assert (procedure? op) (array? array1) (array? array2)) (assert (and (procedure? op) (array? array1) (array? array2)))
(make-array (interval-cartesian-product (array-domain array1) (make-array (interval-cartesian-product (array-domain array1)
(array-domain array2)) (array-domain array2))
(let ((getter1 (array-getter array1)) (let ((getter1 (array-getter array1))
@ -800,8 +800,8 @@
(car (cddr o)) (car (cddr o))
(specialized-array-default-safe?))) (specialized-array-default-safe?)))
(res (make-specialized-array domain storage safe?))) (res (make-specialized-array domain storage safe?)))
(assert (interval? domain) (storage-class? storage) (assert (and (interval? domain) (storage-class? storage)
(boolean? mutable?) (boolean? safe?)) (boolean? mutable?) (boolean? safe?)))
(interval-fold (interval-fold
(lambda (ls . multi-index) (lambda (ls . multi-index)
(apply array-set! res (car ls) multi-index) (apply array-set! res (car ls) multi-index)
@ -811,13 +811,14 @@
res)) res))
(define (array-assign! destination source) (define (array-assign! destination source)
(assert (array? destination) (assert
(and (array? destination)
(mutable-array? destination) (mutable-array? destination)
(array? source) (array? source)
(or (equal? (array-domain destination) (array-domain source)) (or (equal? (array-domain destination) (array-domain source))
(and (array-elements-in-order? destination) (and (array-elements-in-order? destination)
(equal? (interval-volume (array-domain destination)) (equal? (interval-volume (array-domain destination))
(interval-volume (array-domain source)))))) (interval-volume (array-domain source)))))))
(let ((getter (array-getter source)) (let ((getter (array-getter source))
(setter (array-setter destination))) (setter (array-setter destination)))
(if (equal? (array-domain destination) (array-domain source)) (if (equal? (array-domain destination) (array-domain source))
@ -878,9 +879,9 @@
(lp (+ i 1) (cdr ls))))))))) (lp (+ i 1) (cdr ls)))))))))
(define (specialized-array-reshape array new-domain . o) (define (specialized-array-reshape array new-domain . o)
(assert (specialized-array? array) (assert (and (specialized-array? array)
(= (interval-volume (array-domain array)) (= (interval-volume (array-domain array))
(interval-volume new-domain))) (interval-volume new-domain))))
(let ((copy-on-failure? (and (pair? o) (car o)))) (let ((copy-on-failure? (and (pair? o) (car o))))
(cond (cond
((reshape-without-copy array new-domain)) ((reshape-without-copy array new-domain))

View file

@ -36,6 +36,7 @@
(rename (srfi 166 test) (run-tests run-srfi-166-tests)) (rename (srfi 166 test) (run-tests run-srfi-166-tests))
(rename (srfi 219 test) (run-tests run-srfi-219-tests)) (rename (srfi 219 test) (run-tests run-srfi-219-tests))
(rename (scheme bytevector-test) (run-tests run-scheme-bytevector-tests)) (rename (scheme bytevector-test) (run-tests run-scheme-bytevector-tests))
(rename (chibi assert-test) (run-tests run-assert-tests))
(rename (chibi base64-test) (run-tests run-base64-tests)) (rename (chibi base64-test) (run-tests run-base64-tests))
(rename (chibi bytevector-test) (run-tests run-bytevector-tests)) (rename (chibi bytevector-test) (run-tests run-bytevector-tests))
(rename (chibi crypto md5-test) (run-tests run-md5-tests)) (rename (chibi crypto md5-test) (run-tests run-md5-tests))
@ -106,6 +107,7 @@
(run-srfi-166-tests) (run-srfi-166-tests)
(run-srfi-219-tests) (run-srfi-219-tests)
(run-scheme-bytevector-tests) (run-scheme-bytevector-tests)
(run-assert-tests)
(run-base64-tests) (run-base64-tests)
(run-bytevector-tests) (run-bytevector-tests)
(run-doc-tests) (run-doc-tests)