From fa6213b9072682e3ac161639b31ad1b4c5551b31 Mon Sep 17 00:00:00 2001 From: Justin Ethier Date: Mon, 11 Mar 2024 19:19:12 -0700 Subject: [PATCH] Issue #530 - First cut at improving sqrt Improving sqrt to properly handle negative parameter values --- scheme/inexact.sld | 31 ++++++++++++++++++++++++++++++- tests/base.scm | 10 ++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/scheme/inexact.sld b/scheme/inexact.sld index e2d1a1a3..ab8c3376 100644 --- a/scheme/inexact.sld +++ b/scheme/inexact.sld @@ -69,7 +69,6 @@ (/ (c-log z1) (c-log z2*))))) (define-inexact-op c-log "log" "clog") (define-inexact-op exp "exp" "cexp") - (define-inexact-op sqrt "sqrt" "csqrt") (define-inexact-op sin "sin" "csin") (define-inexact-op cos "cos" "ccos") (define-inexact-op tan "tan" "ctan") @@ -93,4 +92,34 @@ (* (if (eqv? y -0.0) -1 1) (if (eqv? x -0.0) 3.141592653589793 x)) (atan1 (/ y x)))))))) + + (define-c + sqrt + "(void *data, int argc, closure _, object k, object z)" + " double complex result; + Cyc_check_num(data, z); + if (obj_is_int(z)) { + result = csqrt(obj_obj2int(z)); + } else if (type_of(z) == integer_tag) { + result = csqrt(((integer_type *)z)->value); + } else if (type_of(z) == bignum_tag) { + result = csqrt(mp_get_double(&bignum_value(z))); + } else if (type_of(z) == complex_num_tag) { + result = csqrt(complex_num_value(z)); + } else { + result = csqrt(((double_type *)z)->value); + } + + if (cimag(result) == 0.0) { + make_double(d, creal(result)); + return_closcall1(data, k, &d); + } else { + complex_num_type cn; + assign_complex_num((&cn), result); + return_closcall1(data, k, &cn); + } " +; "(void *data, object ptr, object z)" +; " return_inexact_double_or_cplx_op_no_cps(data, ptr, sqrt, csqrt, z);" +) + )) diff --git a/tests/base.scm b/tests/base.scm index ee704b87..980b0ece 100644 --- a/tests/base.scm +++ b/tests/base.scm @@ -9,6 +9,7 @@ (import (scheme base) + (scheme inexact) (cyclone test)) @@ -102,6 +103,15 @@ (test 2.0 (denominator (inexact (/ 6 4)))) ) +(test-group + "sqrt" + (test #t (sqrt -1)) + (test #t (sqrt -1.0)) + ; TODO: (test 2 (sqrt 4)) + (test 2.0 (sqrt 4.0)) + (test 2i (sqrt -4.0)) +) + (test-group "exact" (test -1 (exact -1))