From 12ad1d37d830774b98cab693f51048f6507093d0 Mon Sep 17 00:00:00 2001 From: Alex Shinn Date: Thu, 6 May 2021 20:34:03 +0900 Subject: [PATCH] add docs and tests for assert, unify with SRFI 145 --- lib/chibi/assert-test.sld | 30 +++++++++++ lib/chibi/assert.sld | 87 +++++++++++++++++++++--------- lib/srfi/145.sld | 5 +- lib/srfi/179.scm | 109 +++++++++++++++++++------------------- tests/lib-tests.scm | 2 + 5 files changed, 152 insertions(+), 81 deletions(-) create mode 100644 lib/chibi/assert-test.sld diff --git a/lib/chibi/assert-test.sld b/lib/chibi/assert-test.sld new file mode 100644 index 00000000..b98ec2ac --- /dev/null +++ b/lib/chibi/assert-test.sld @@ -0,0 +1,30 @@ + +(define-library (chibi assert-test) + (import (chibi) (chibi assert) (chibi test)) + (export run-tests) + (begin + (define-syntax test-assert + (syntax-rules () + ((test-assert irritants expr) + (protect (exn + (else + (test irritants (exception-irritants exn)))) + expr + (error "assertion not triggered"))))) + (define (run-tests) + (test-begin "assert") + (test-assert '((= x (+ x 1)) + (x 3)) + (let ((x 3)) (assert (= x (+ x 1))))) + (test-assert '((= x (+ y 1)) + (x 3) + (y 42)) + (let ((x 3) (y 42)) (assert (= x (+ y 1))))) + (test-assert '((eq? x 'three) + (x 3)) + (let ((x 3)) (assert (eq? x 'three)))) + (test-assert '((eq? x 'three) + "expected three: " + 3) + (let ((x 3)) (assert (eq? x 'three) "expected three: " x))) + (test-end)))) diff --git a/lib/chibi/assert.sld b/lib/chibi/assert.sld index af558db3..d5a3b27d 100644 --- a/lib/chibi/assert.sld +++ b/lib/chibi/assert.sld @@ -1,3 +1,39 @@ + +;;> A nice assert macro. +;;> +;;> Assert macros are common in Scheme, in particular being helpful +;;> for domain checks at the beginning of a procedure to catch errors +;;> as early as possible. Compared to statically typed languages this +;;> has the advantages that the assertions are optional, and that they +;;> are not limited by the type system. SRFI 145 provides the related +;;> notion of assumptions, but the motivation there is to provide +;;> hints to optimizing compilers, and these are not required to +;;> actually signal an error. +;;> +;;> \macro{(assert expr [msg ...])} +;;> +;;> Equivalent to SRFI 145 \code{assume} except that an error is +;;> guaranteed to be raised if \var{expr} is false. Conceptually +;;> shorthand for +;;> +;;> \code{(or \var{expr} +;;> (error "assertion failed" \var{msg} ...))} +;;> +;;> that is, evaluates \var{expr} and returns it if true, but raises +;;> an exception otherwise. The error is augmented to include the +;;> text of the failed \var{expr}. If no additional \var{msg} +;;> arguments are provided then \var{expr} is scanned for free +;;> variables in non-operator positions to report values from, e.g. in +;;> +;;> \code{(let ((x 3)) +;;> (assert (= x (+ x 1))))} +;;> +;;> the error would also report the bound value of \code{x}. This +;;> uses the technique from Oleg Kiselyov's \hyperlink[http://okmij.org/ftp/Scheme/assert-syntax-rule.txt]{good assert macro}, +;;> which is convenient but fallible. It is thus best to keep the +;;> body of the assertion simple, moving any predicates you need to +;;> external utilities, or provide an explicit \var{msg}. + (define-library (chibi assert) (export assert) (cond-expand @@ -10,11 +46,11 @@ (if (identifier? (cadr expr)) (car (cddr expr)) (cadr (cddr expr)))))) - (define-syntax syntax-memq? + (define-syntax syntax-id-memq? (er-macro-transformer (lambda (expr rename compare) (let ((expr (cdr expr))) - (if (memq (car expr) (cadr expr)) + (if (any (lambda (x) (compare x (car expr))) (cadr expr)) (car (cddr expr)) (cadr (cddr expr))))))))) (else @@ -32,7 +68,7 @@ ((sym? x sk fk) sk) ((sym? y sk fk) fk)))) (sym? abracadabra success-k failure-k))))) - (define-syntax syntax-memq? + (define-syntax syntax-id-memq? (syntax-rules () ((syntax-memq? id (ids ...) sk fk) (let-syntax @@ -42,35 +78,38 @@ ((memq? any-other sk2 fk2) sk2)))) (memq? random-symbol-to-match sk fk)))))))) (begin - (define-syntax report-vars + (define-syntax extract-vars (syntax-rules () ((report-vars (op arg0 arg1 ...) (next ...) res) - (syntax-memq? op (quote quasiquote lambda let let* letrec letrec* - let-syntax letrec-syntax let-values let*-values - receive match case define define-syntax do) - (next ... res) - (report-vars arg0 - (report-vars (op arg1 ...) (next ...)) - res))) + (syntax-id-memq? op (quote quasiquote lambda let let* letrec letrec* + let-syntax letrec-syntax let-values let*-values + receive match case define define-syntax do) + (next ... res) + (extract-vars arg0 + (extract-vars (op arg1 ...) (next ...)) + res))) ((report-vars (op . x) (next ...) res) (next ... res)) ((report-vars x (next ...) (res ...)) (syntax-identifier? x - (syntax-memq? x (res ...) - (next ... (res ...)) - (next ... (res ... x))) + (syntax-id-memq? x (res ...) + (next ... (res ...)) + (next ... (res ... x))) (next ... (res ...)))))) + (define-syntax qq-vars + (syntax-rules () + ((qq-vars (next ...) (var ...)) + (next ... `(var ,var) ...)))) (define-syntax report-final (syntax-rules () - ((report-final expr (var ...)) - (error "assertion failed" 'expr `(var ,var) ...)))) + ((report-final expr msg ...) + (error "assertion failed" 'expr msg ...)))) (define-syntax assert - (syntax-rules (report:) - ((assert test report: msg ...) - (unless test - (error msg ...))) - ((assert test0 test1 ...) - (if test0 - (assert test1 ...) - (report-vars test0 (report-final test0) ()))) + (syntax-rules () + ((assert test) + (or test + (extract-vars test (qq-vars (report-final test)) ()))) + ((assert test msg ...) + (or test + (report-final test msg ...))) ((assert) #t))))) diff --git a/lib/srfi/145.sld b/lib/srfi/145.sld index 8023dc89..b35d4985 100644 --- a/lib/srfi/145.sld +++ b/lib/srfi/145.sld @@ -1,6 +1,6 @@ (define-library (srfi 145) (export assume) - (import (scheme base)) + (import (scheme base) (chibi assert)) (cond-expand ((or elide-assumptions (and (not assumptions) @@ -17,7 +17,6 @@ (define-syntax assume (syntax-rules () ((assume expression objs ...) - (or expression - (error "invalid assumption" 'expression objs ...))) + (assert expression objs ...)) ((assume) (syntax-error "assume requires an expression")))))))) diff --git a/lib/srfi/179.scm b/lib/srfi/179.scm index 17985fe9..65382c31 100644 --- a/lib/srfi/179.scm +++ b/lib/srfi/179.scm @@ -30,12 +30,12 @@ (ub interval-ub)) (define (%make-interval lo hi) - (assert (translation? lo) - (translation? hi) - (not (vector-empty? lo)) - (not (vector-empty? hi)) - (= (vector-length lo) (vector-length hi)) - (vector-every < lo hi)) + (assert (and (translation? lo) + (translation? hi) + (not (vector-empty? lo)) + (not (vector-empty? hi)) + (= (vector-length lo) (vector-length hi)) + (vector-every < lo hi))) (%%make-interval lo hi)) (define (make-interval x . o) @@ -54,7 +54,7 @@ (define (interval-upper-bounds->vector iv) (vector-copy (interval-ub iv))) (define (interval= iv1 iv2) - (assert (interval? iv1) (interval? iv2)) + (assert (and (interval? iv1) (interval? iv2))) (equal? iv1 iv2)) (define (interval-volume iv) @@ -63,16 +63,16 @@ (interval-lb iv) (interval-ub iv))) (define (interval-subset? iv1 iv2) - (assert (interval? iv1) (interval? iv2) - (= (interval-dimension iv1) (interval-dimension iv2))) + (assert (and (interval? iv1) (interval? iv2) + (= (interval-dimension iv1) (interval-dimension iv2)))) (and (vector-every >= (interval-lb iv1) (interval-lb iv2)) (vector-every <= (interval-ub iv1) (interval-ub iv2)))) (define (interval-contains-multi-index? iv i0 . o) (assert (interval? iv)) (let ((i (list->vector (cons i0 o)))) - (assert (= (interval-dimension iv) (vector-length i)) - (vector-every integer? i)) + (assert (and (= (interval-dimension iv) (vector-length i)) + (vector-every integer? i))) (and (vector-every >= i (interval-lb iv)) (vector-every < i (interval-ub iv))))) @@ -136,8 +136,8 @@ (define (interval-intersect iv0 . o) (let ((ls (cons iv0 o))) - (assert (every interval? ls) - (or (null? o) (apply = (map interval-dimension ls)))) + (assert (and (every interval? ls) + (or (null? o) (apply = (map interval-dimension ls))))) (let ((lower (apply vector-map max (map interval-lb ls))) (upper (apply vector-map min (map interval-ub ls)))) (and (vector-every < lower upper) @@ -148,7 +148,7 @@ (interval-dilate iv translation translation)) (define (interval-permute iv perm) - (assert (interval? iv) (permutation? perm)) + (assert (and (interval? iv) (permutation? perm))) (let* ((len (interval-dimension iv)) (lower (make-vector len)) (upper (make-vector len))) @@ -167,11 +167,11 @@ (vector-copy upper 0 dim))))) (define (interval-scale iv scales) - (assert (interval? iv) - (vector? scales) - (= (interval-dimension iv) (vector-length scales)) - (vector-every exact-integer? scales) - (vector-every positive? scales)) + (assert (and (interval? iv) + (vector? scales) + (= (interval-dimension iv) (vector-length scales)) + (vector-every exact-integer? scales) + (vector-every positive? scales))) (make-interval (vector-map (lambda (u s) (exact (ceiling (/ u s)))) (interval-ub iv) @@ -273,14 +273,14 @@ (safe? array-safe?)) (define (%make-array domain getter setter storage body coeffs indexer safe?) - (assert (interval? domain) - (procedure? getter) - (or (not setter) (procedure? setter)) - (or (not storage) (storage-class? storage))) + (assert (and (interval? domain) + (procedure? getter) + (or (not setter) (procedure? setter)) + (or (not storage) (storage-class? storage)))) (%%make-array domain getter setter storage body coeffs indexer safe?)) (define (make-array domain getter . o) - (assert (interval? domain) (procedure? getter)) + (assert (and (interval? domain) (procedure? getter))) (%make-array domain getter (and (pair? o) (car o)) #f #f #f #f #f)) (define (array-dimension a) @@ -483,7 +483,7 @@ #t)))) (define (specialized-array-share array new-domain project) - (assert (specialized-array? array) (interval? new-domain)) + (assert (and (specialized-array? array) (interval? new-domain))) (let* ((body (array-body array)) (coeffs (indexer->coeffs @@ -509,8 +509,8 @@ (mutable? (if (pair? o) (car o) (specialized-array-default-mutable?))) (o (if (pair? o) (cdr o) '())) (safe? (if (pair? o) (car o) (specialized-array-default-safe?)))) - (assert (storage-class? storage) (interval? new-domain) - (boolean? mutable?) (boolean? safe?)) + (assert (and (storage-class? storage) (interval? new-domain) + (boolean? mutable?) (boolean? safe?))) (let* ((body ((storage-class-maker storage) (interval-volume new-domain) (storage-class-default storage))) @@ -556,9 +556,9 @@ )))))))))) (define (array-extract array new-domain) - (assert (array? array) - (interval? new-domain) - (interval-subset? new-domain (array-domain array))) + (assert (and (array? array) + (interval? new-domain) + (interval-subset? new-domain (array-domain array)))) (if (specialized-array? array) (specialized-array-share array new-domain values) (make-array new-domain @@ -566,14 +566,14 @@ (array-setter array)))) (define (array-tile array sizes) - (assert (array? array) - (vector? sizes) - (= (array-dimension array) (vector-length sizes)) - (vector-every exact-integer? sizes) - (vector-every >= sizes (interval-lower-bounds->vector - (array-domain array))) - (vector-every < sizes (interval-upper-bounds->vector - (array-domain array)))) + (assert (and (array? array) + (vector? sizes) + (= (array-dimension array) (vector-length sizes)) + (vector-every exact-integer? sizes) + (vector-every >= sizes (interval-lower-bounds->vector + (array-domain array))) + (vector-every < sizes (interval-upper-bounds->vector + (array-domain array))))) (let ((domain (make-interval (vector-map (lambda (lo hi s) (exact (ceiling (/ (- hi lo) s)))) @@ -662,9 +662,9 @@ (define (array-reverse array . o) (assert (array? array)) (let ((flip? (if (pair? o) (car o) (make-vector (array-dimension array) #t)))) - (assert (vector? flip?) - (= (array-dimension array) (vector-length flip?)) - (vector-every boolean? flip?)) + (assert (and (vector? flip?) + (= (array-dimension array) (vector-length flip?)) + (vector-every boolean? flip?))) (let* ((flips (vector->list flip?)) (domain (array-domain array)) (lowers (interval-lower-bounds->list domain)) @@ -715,7 +715,7 @@ (apply array-set! array val (map * multi-index scales-ls)))))))) (define (array-outer-product op array1 array2) - (assert (procedure? op) (array? array1) (array? array2)) + (assert (and (procedure? op) (array? array1) (array? array2))) (make-array (interval-cartesian-product (array-domain array1) (array-domain array2)) (let ((getter1 (array-getter array1)) @@ -800,8 +800,8 @@ (car (cddr o)) (specialized-array-default-safe?))) (res (make-specialized-array domain storage safe?))) - (assert (interval? domain) (storage-class? storage) - (boolean? mutable?) (boolean? safe?)) + (assert (and (interval? domain) (storage-class? storage) + (boolean? mutable?) (boolean? safe?))) (interval-fold (lambda (ls . multi-index) (apply array-set! res (car ls) multi-index) @@ -811,13 +811,14 @@ res)) (define (array-assign! destination source) - (assert (array? destination) - (mutable-array? destination) - (array? source) - (or (equal? (array-domain destination) (array-domain source)) - (and (array-elements-in-order? destination) - (equal? (interval-volume (array-domain destination)) - (interval-volume (array-domain source)))))) + (assert + (and (array? destination) + (mutable-array? destination) + (array? source) + (or (equal? (array-domain destination) (array-domain source)) + (and (array-elements-in-order? destination) + (equal? (interval-volume (array-domain destination)) + (interval-volume (array-domain source))))))) (let ((getter (array-getter source)) (setter (array-setter destination))) (if (equal? (array-domain destination) (array-domain source)) @@ -878,9 +879,9 @@ (lp (+ i 1) (cdr ls))))))))) (define (specialized-array-reshape array new-domain . o) - (assert (specialized-array? array) - (= (interval-volume (array-domain array)) - (interval-volume new-domain))) + (assert (and (specialized-array? array) + (= (interval-volume (array-domain array)) + (interval-volume new-domain)))) (let ((copy-on-failure? (and (pair? o) (car o)))) (cond ((reshape-without-copy array new-domain)) diff --git a/tests/lib-tests.scm b/tests/lib-tests.scm index 197a5467..e66a6480 100644 --- a/tests/lib-tests.scm +++ b/tests/lib-tests.scm @@ -36,6 +36,7 @@ (rename (srfi 166 test) (run-tests run-srfi-166-tests)) (rename (srfi 219 test) (run-tests run-srfi-219-tests)) (rename (scheme bytevector-test) (run-tests run-scheme-bytevector-tests)) + (rename (chibi assert-test) (run-tests run-assert-tests)) (rename (chibi base64-test) (run-tests run-base64-tests)) (rename (chibi bytevector-test) (run-tests run-bytevector-tests)) (rename (chibi crypto md5-test) (run-tests run-md5-tests)) @@ -106,6 +107,7 @@ (run-srfi-166-tests) (run-srfi-219-tests) (run-scheme-bytevector-tests) +(run-assert-tests) (run-base64-tests) (run-bytevector-tests) (run-doc-tests)