add docs and tests for assert, unify with SRFI 145

This commit is contained in:
Alex Shinn 2021-05-06 20:34:03 +09:00
parent e3782e35a5
commit 12ad1d37d8
5 changed files with 152 additions and 81 deletions

30
lib/chibi/assert-test.sld Normal file
View file

@ -0,0 +1,30 @@
(define-library (chibi assert-test)
(import (chibi) (chibi assert) (chibi test))
(export run-tests)
(begin
(define-syntax test-assert
(syntax-rules ()
((test-assert irritants expr)
(protect (exn
(else
(test irritants (exception-irritants exn))))
expr
(error "assertion not triggered")))))
(define (run-tests)
(test-begin "assert")
(test-assert '((= x (+ x 1))
(x 3))
(let ((x 3)) (assert (= x (+ x 1)))))
(test-assert '((= x (+ y 1))
(x 3)
(y 42))
(let ((x 3) (y 42)) (assert (= x (+ y 1)))))
(test-assert '((eq? x 'three)
(x 3))
(let ((x 3)) (assert (eq? x 'three))))
(test-assert '((eq? x 'three)
"expected three: "
3)
(let ((x 3)) (assert (eq? x 'three) "expected three: " x)))
(test-end))))

View file

@ -1,3 +1,39 @@
;;> A nice assert macro.
;;>
;;> Assert macros are common in Scheme, in particular being helpful
;;> for domain checks at the beginning of a procedure to catch errors
;;> as early as possible. Compared to statically typed languages this
;;> has the advantages that the assertions are optional, and that they
;;> are not limited by the type system. SRFI 145 provides the related
;;> notion of assumptions, but the motivation there is to provide
;;> hints to optimizing compilers, and these are not required to
;;> actually signal an error.
;;>
;;> \macro{(assert expr [msg ...])}
;;>
;;> Equivalent to SRFI 145 \code{assume} except that an error is
;;> guaranteed to be raised if \var{expr} is false. Conceptually
;;> shorthand for
;;>
;;> \code{(or \var{expr}
;;> (error "assertion failed" \var{msg} ...))}
;;>
;;> that is, evaluates \var{expr} and returns it if true, but raises
;;> an exception otherwise. The error is augmented to include the
;;> text of the failed \var{expr}. If no additional \var{msg}
;;> arguments are provided then \var{expr} is scanned for free
;;> variables in non-operator positions to report values from, e.g. in
;;>
;;> \code{(let ((x 3))
;;> (assert (= x (+ x 1))))}
;;>
;;> the error would also report the bound value of \code{x}. This
;;> uses the technique from Oleg Kiselyov's \hyperlink[http://okmij.org/ftp/Scheme/assert-syntax-rule.txt]{good assert macro},
;;> which is convenient but fallible. It is thus best to keep the
;;> body of the assertion simple, moving any predicates you need to
;;> external utilities, or provide an explicit \var{msg}.
(define-library (chibi assert)
(export assert)
(cond-expand
@ -10,11 +46,11 @@
(if (identifier? (cadr expr))
(car (cddr expr))
(cadr (cddr expr))))))
(define-syntax syntax-memq?
(define-syntax syntax-id-memq?
(er-macro-transformer
(lambda (expr rename compare)
(let ((expr (cdr expr)))
(if (memq (car expr) (cadr expr))
(if (any (lambda (x) (compare x (car expr))) (cadr expr))
(car (cddr expr))
(cadr (cddr expr)))))))))
(else
@ -32,7 +68,7 @@
((sym? x sk fk) sk)
((sym? y sk fk) fk))))
(sym? abracadabra success-k failure-k)))))
(define-syntax syntax-memq?
(define-syntax syntax-id-memq?
(syntax-rules ()
((syntax-memq? id (ids ...) sk fk)
(let-syntax
@ -42,35 +78,38 @@
((memq? any-other sk2 fk2) sk2))))
(memq? random-symbol-to-match sk fk))))))))
(begin
(define-syntax report-vars
(define-syntax extract-vars
(syntax-rules ()
((report-vars (op arg0 arg1 ...) (next ...) res)
(syntax-memq? op (quote quasiquote lambda let let* letrec letrec*
let-syntax letrec-syntax let-values let*-values
receive match case define define-syntax do)
(next ... res)
(report-vars arg0
(report-vars (op arg1 ...) (next ...))
res)))
(syntax-id-memq? op (quote quasiquote lambda let let* letrec letrec*
let-syntax letrec-syntax let-values let*-values
receive match case define define-syntax do)
(next ... res)
(extract-vars arg0
(extract-vars (op arg1 ...) (next ...))
res)))
((report-vars (op . x) (next ...) res)
(next ... res))
((report-vars x (next ...) (res ...))
(syntax-identifier? x
(syntax-memq? x (res ...)
(next ... (res ...))
(next ... (res ... x)))
(syntax-id-memq? x (res ...)
(next ... (res ...))
(next ... (res ... x)))
(next ... (res ...))))))
(define-syntax qq-vars
(syntax-rules ()
((qq-vars (next ...) (var ...))
(next ... `(var ,var) ...))))
(define-syntax report-final
(syntax-rules ()
((report-final expr (var ...))
(error "assertion failed" 'expr `(var ,var) ...))))
((report-final expr msg ...)
(error "assertion failed" 'expr msg ...))))
(define-syntax assert
(syntax-rules (report:)
((assert test report: msg ...)
(unless test
(error msg ...)))
((assert test0 test1 ...)
(if test0
(assert test1 ...)
(report-vars test0 (report-final test0) ())))
(syntax-rules ()
((assert test)
(or test
(extract-vars test (qq-vars (report-final test)) ())))
((assert test msg ...)
(or test
(report-final test msg ...)))
((assert) #t)))))

View file

@ -1,6 +1,6 @@
(define-library (srfi 145)
(export assume)
(import (scheme base))
(import (scheme base) (chibi assert))
(cond-expand
((or elide-assumptions
(and (not assumptions)
@ -17,7 +17,6 @@
(define-syntax assume
(syntax-rules ()
((assume expression objs ...)
(or expression
(error "invalid assumption" 'expression objs ...)))
(assert expression objs ...))
((assume)
(syntax-error "assume requires an expression"))))))))

View file

@ -30,12 +30,12 @@
(ub interval-ub))
(define (%make-interval lo hi)
(assert (translation? lo)
(translation? hi)
(not (vector-empty? lo))
(not (vector-empty? hi))
(= (vector-length lo) (vector-length hi))
(vector-every < lo hi))
(assert (and (translation? lo)
(translation? hi)
(not (vector-empty? lo))
(not (vector-empty? hi))
(= (vector-length lo) (vector-length hi))
(vector-every < lo hi)))
(%%make-interval lo hi))
(define (make-interval x . o)
@ -54,7 +54,7 @@
(define (interval-upper-bounds->vector iv) (vector-copy (interval-ub iv)))
(define (interval= iv1 iv2)
(assert (interval? iv1) (interval? iv2))
(assert (and (interval? iv1) (interval? iv2)))
(equal? iv1 iv2))
(define (interval-volume iv)
@ -63,16 +63,16 @@
(interval-lb iv) (interval-ub iv)))
(define (interval-subset? iv1 iv2)
(assert (interval? iv1) (interval? iv2)
(= (interval-dimension iv1) (interval-dimension iv2)))
(assert (and (interval? iv1) (interval? iv2)
(= (interval-dimension iv1) (interval-dimension iv2))))
(and (vector-every >= (interval-lb iv1) (interval-lb iv2))
(vector-every <= (interval-ub iv1) (interval-ub iv2))))
(define (interval-contains-multi-index? iv i0 . o)
(assert (interval? iv))
(let ((i (list->vector (cons i0 o))))
(assert (= (interval-dimension iv) (vector-length i))
(vector-every integer? i))
(assert (and (= (interval-dimension iv) (vector-length i))
(vector-every integer? i)))
(and (vector-every >= i (interval-lb iv))
(vector-every < i (interval-ub iv)))))
@ -136,8 +136,8 @@
(define (interval-intersect iv0 . o)
(let ((ls (cons iv0 o)))
(assert (every interval? ls)
(or (null? o) (apply = (map interval-dimension ls))))
(assert (and (every interval? ls)
(or (null? o) (apply = (map interval-dimension ls)))))
(let ((lower (apply vector-map max (map interval-lb ls)))
(upper (apply vector-map min (map interval-ub ls))))
(and (vector-every < lower upper)
@ -148,7 +148,7 @@
(interval-dilate iv translation translation))
(define (interval-permute iv perm)
(assert (interval? iv) (permutation? perm))
(assert (and (interval? iv) (permutation? perm)))
(let* ((len (interval-dimension iv))
(lower (make-vector len))
(upper (make-vector len)))
@ -167,11 +167,11 @@
(vector-copy upper 0 dim)))))
(define (interval-scale iv scales)
(assert (interval? iv)
(vector? scales)
(= (interval-dimension iv) (vector-length scales))
(vector-every exact-integer? scales)
(vector-every positive? scales))
(assert (and (interval? iv)
(vector? scales)
(= (interval-dimension iv) (vector-length scales))
(vector-every exact-integer? scales)
(vector-every positive? scales)))
(make-interval
(vector-map (lambda (u s) (exact (ceiling (/ u s))))
(interval-ub iv)
@ -273,14 +273,14 @@
(safe? array-safe?))
(define (%make-array domain getter setter storage body coeffs indexer safe?)
(assert (interval? domain)
(procedure? getter)
(or (not setter) (procedure? setter))
(or (not storage) (storage-class? storage)))
(assert (and (interval? domain)
(procedure? getter)
(or (not setter) (procedure? setter))
(or (not storage) (storage-class? storage))))
(%%make-array domain getter setter storage body coeffs indexer safe?))
(define (make-array domain getter . o)
(assert (interval? domain) (procedure? getter))
(assert (and (interval? domain) (procedure? getter)))
(%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f))
(define (array-dimension a)
@ -483,7 +483,7 @@
#t))))
(define (specialized-array-share array new-domain project)
(assert (specialized-array? array) (interval? new-domain))
(assert (and (specialized-array? array) (interval? new-domain)))
(let* ((body (array-body array))
(coeffs
(indexer->coeffs
@ -509,8 +509,8 @@
(mutable? (if (pair? o) (car o) (specialized-array-default-mutable?)))
(o (if (pair? o) (cdr o) '()))
(safe? (if (pair? o) (car o) (specialized-array-default-safe?))))
(assert (storage-class? storage) (interval? new-domain)
(boolean? mutable?) (boolean? safe?))
(assert (and (storage-class? storage) (interval? new-domain)
(boolean? mutable?) (boolean? safe?)))
(let* ((body ((storage-class-maker storage)
(interval-volume new-domain)
(storage-class-default storage)))
@ -556,9 +556,9 @@
))))))))))
(define (array-extract array new-domain)
(assert (array? array)
(interval? new-domain)
(interval-subset? new-domain (array-domain array)))
(assert (and (array? array)
(interval? new-domain)
(interval-subset? new-domain (array-domain array))))
(if (specialized-array? array)
(specialized-array-share array new-domain values)
(make-array new-domain
@ -566,14 +566,14 @@
(array-setter array))))
(define (array-tile array sizes)
(assert (array? array)
(vector? sizes)
(= (array-dimension array) (vector-length sizes))
(vector-every exact-integer? sizes)
(vector-every >= sizes (interval-lower-bounds->vector
(array-domain array)))
(vector-every < sizes (interval-upper-bounds->vector
(array-domain array))))
(assert (and (array? array)
(vector? sizes)
(= (array-dimension array) (vector-length sizes))
(vector-every exact-integer? sizes)
(vector-every >= sizes (interval-lower-bounds->vector
(array-domain array)))
(vector-every < sizes (interval-upper-bounds->vector
(array-domain array)))))
(let ((domain (make-interval
(vector-map
(lambda (lo hi s) (exact (ceiling (/ (- hi lo) s))))
@ -662,9 +662,9 @@
(define (array-reverse array . o)
(assert (array? array))
(let ((flip? (if (pair? o) (car o) (make-vector (array-dimension array) #t))))
(assert (vector? flip?)
(= (array-dimension array) (vector-length flip?))
(vector-every boolean? flip?))
(assert (and (vector? flip?)
(= (array-dimension array) (vector-length flip?))
(vector-every boolean? flip?)))
(let* ((flips (vector->list flip?))
(domain (array-domain array))
(lowers (interval-lower-bounds->list domain))
@ -715,7 +715,7 @@
(apply array-set! array val (map * multi-index scales-ls))))))))
(define (array-outer-product op array1 array2)
(assert (procedure? op) (array? array1) (array? array2))
(assert (and (procedure? op) (array? array1) (array? array2)))
(make-array (interval-cartesian-product (array-domain array1)
(array-domain array2))
(let ((getter1 (array-getter array1))
@ -800,8 +800,8 @@
(car (cddr o))
(specialized-array-default-safe?)))
(res (make-specialized-array domain storage safe?)))
(assert (interval? domain) (storage-class? storage)
(boolean? mutable?) (boolean? safe?))
(assert (and (interval? domain) (storage-class? storage)
(boolean? mutable?) (boolean? safe?)))
(interval-fold
(lambda (ls . multi-index)
(apply array-set! res (car ls) multi-index)
@ -811,13 +811,14 @@
res))
(define (array-assign! destination source)
(assert (array? destination)
(mutable-array? destination)
(array? source)
(or (equal? (array-domain destination) (array-domain source))
(and (array-elements-in-order? destination)
(equal? (interval-volume (array-domain destination))
(interval-volume (array-domain source))))))
(assert
(and (array? destination)
(mutable-array? destination)
(array? source)
(or (equal? (array-domain destination) (array-domain source))
(and (array-elements-in-order? destination)
(equal? (interval-volume (array-domain destination))
(interval-volume (array-domain source)))))))
(let ((getter (array-getter source))
(setter (array-setter destination)))
(if (equal? (array-domain destination) (array-domain source))
@ -878,9 +879,9 @@
(lp (+ i 1) (cdr ls)))))))))
(define (specialized-array-reshape array new-domain . o)
(assert (specialized-array? array)
(= (interval-volume (array-domain array))
(interval-volume new-domain)))
(assert (and (specialized-array? array)
(= (interval-volume (array-domain array))
(interval-volume new-domain))))
(let ((copy-on-failure? (and (pair? o) (car o))))
(cond
((reshape-without-copy array new-domain))

View file

@ -36,6 +36,7 @@
(rename (srfi 166 test) (run-tests run-srfi-166-tests))
(rename (srfi 219 test) (run-tests run-srfi-219-tests))
(rename (scheme bytevector-test) (run-tests run-scheme-bytevector-tests))
(rename (chibi assert-test) (run-tests run-assert-tests))
(rename (chibi base64-test) (run-tests run-base64-tests))
(rename (chibi bytevector-test) (run-tests run-bytevector-tests))
(rename (chibi crypto md5-test) (run-tests run-md5-tests))
@ -106,6 +107,7 @@
(run-srfi-166-tests)
(run-srfi-219-tests)
(run-scheme-bytevector-tests)
(run-assert-tests)
(run-base64-tests)
(run-bytevector-tests)
(run-doc-tests)