stdio: start simplifying scanf limit tracking logic

Basically removing it from the __scanf_input structure and specializing
it at format sites. The reason is that pretending it's the end of the
stream after the limit is reached does not work because we have to
return EOF at end of stream but not when the limit is hit. So we have to
handle it explicitly, and since we do, no need to have it in the
structure too.
This commit is contained in:
Lephenixnoir 2024-01-14 19:28:36 +01:00
parent 2215b3c267
commit b11c059c0f
No known key found for this signature in database
GPG key ID: 1BBA026E13FC0495
12 changed files with 88 additions and 74 deletions

View file

@ -269,12 +269,11 @@ int __scanf(
// we will have to manage a given format // we will have to manage a given format
else if( format[pos] == '%' ) { else if( format[pos] == '%' ) {
in->readmaxlength = -1; in->readmaxlength = INT_MAX;
// main loop // main loop
loopagain: loopagain:
pos++; pos++;
in->currentlength = 0;
switch( format[pos] ) { switch( format[pos] ) {
// we need to decrypt the corresponding scanf set of character // we need to decrypt the corresponding scanf set of character
@ -286,7 +285,7 @@ int __scanf(
// we need to assign the read char to the corresponding pointer // we need to assign the read char to the corresponding pointer
if (!skip) { if (!skip) {
char *c = (char *) va_arg( *args, char* ); char *c = (char *) va_arg( *args, char* );
if (in->readmaxlength==(unsigned int)-1) { if (in->readmaxlength==INT_MAX) {
for(;;) { for(;;) {
temp = __scanf_peek( in ); temp = __scanf_peek( in );
if (temp==EOF) return EOF; if (temp==EOF) return EOF;
@ -332,7 +331,7 @@ int __scanf(
else else
{ {
if (in->readmaxlength==(unsigned int)-1) { if (in->readmaxlength==INT_MAX) {
for(;;) { for(;;) {
temp = __scanf_peek( in ); temp = __scanf_peek( in );
if (temp==EOF) return EOF; if (temp==EOF) return EOF;
@ -433,16 +432,7 @@ int __scanf(
break; break;
} }
case '0': case '0' ... '9': {
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9': {
user_length = user_length * 10 + (int) ( format[pos] - '0' ); user_length = user_length * 10 + (int) ( format[pos] - '0' );
in->readmaxlength = user_length; in->readmaxlength = user_length;
goto loopagain; goto loopagain;
@ -462,7 +452,8 @@ int __scanf(
bool use_unsigned = (f == 'o' || f == 'x' || f == 'X'); bool use_unsigned = (f == 'o' || f == 'x' || f == 'X');
long long int temp; long long int temp;
err = __strto_int(in, base, NULL, &temp, use_unsigned); err = __strto_int(in, base, NULL, &temp, use_unsigned,
in->readmaxlength);
if (err == EOF && validrets == 0) return EOF; if (err == EOF && validrets == 0) return EOF;
if (err != 0) return validrets; if (err != 0) return validrets;
if (skip) __scanf_store_i( temp, MODSKIP, args ); if (skip) __scanf_store_i( temp, MODSKIP, args );
@ -482,7 +473,8 @@ int __scanf(
// read a double from the current input stream // read a double from the current input stream
// and store in the corresponding arg as a char by reference // and store in the corresponding arg as a char by reference
long double temp; long double temp;
err = __strto_fp( in, NULL, NULL, &temp ); err = __strto_fp( in, NULL, NULL, &temp,
in->readmaxlength);
if (err == EOF && validrets == 0) return EOF; if (err == EOF && validrets == 0) return EOF;
if (err != 0) return validrets; if (err != 0) return validrets;
if (skip) __scanf_store_d( temp, MODSKIP, args ); if (skip) __scanf_store_d( temp, MODSKIP, args );
@ -495,9 +487,11 @@ int __scanf(
long int temp; long int temp;
if (!skip) { if (!skip) {
void *p = (void *) va_arg( *args, void** ); // get the adress of the target pointer (void**) void *p = (void *) va_arg( *args, void** ); // get the adress of the target pointer (void**)
err = __strto_int( in, 0, p, NULL, true ); err = __strto_int( in, 0, p, NULL, true,
in->readmaxlength);
} }
else err = __strto_int( in, 0, &temp, NULL, true ); else err = __strto_int( in, 0, &temp, NULL, true,
in->readmaxlength);
if (err == 0) validrets++; if (err == 0) validrets++;
else return validrets; else return validrets;
skip = false; skip = false;
@ -508,7 +502,7 @@ int __scanf(
int temp; int temp;
if (!skip) { if (!skip) {
char *c = (char *) va_arg( *args, char* ); char *c = (char *) va_arg( *args, char* );
if (in->readmaxlength==(unsigned int)-1) { if (in->readmaxlength==INT_MAX) {
temp = __scanf_peek( in ); temp = __scanf_peek( in );
if (temp==EOF) return EOF; if (temp==EOF) return EOF;
else *c = __scanf_in( in ); else *c = __scanf_in( in );
@ -522,7 +516,7 @@ int __scanf(
} }
} }
else { else {
if (in->readmaxlength==(unsigned int)-1) { if (in->readmaxlength==INT_MAX) {
temp = __scanf_peek( in ); temp = __scanf_peek( in );
if (temp==EOF) return EOF; if (temp==EOF) return EOF;
else { else {
@ -552,7 +546,7 @@ int __scanf(
__purge_space( in ); __purge_space( in );
if (!skip) { if (!skip) {
char *c = (char *) va_arg( *args, char* ); char *c = (char *) va_arg( *args, char* );
if (in->readmaxlength==(unsigned int)-1) { if (in->readmaxlength==INT_MAX) {
loopstring: loopstring:
temp = __scanf_peek( in ); temp = __scanf_peek( in );
if (temp==EOF && curstrlength==0) return validrets; if (temp==EOF && curstrlength==0) return validrets;
@ -582,7 +576,7 @@ int __scanf(
} }
} }
else { else {
if (in->readmaxlength==(unsigned int)-1) { if (in->readmaxlength==INT_MAX) {
loopstringskip: loopstringskip:
temp = __scanf_peek( in ); temp = __scanf_peek( in );
if (temp==EOF && curstrlength==0) return validrets; if (temp==EOF && curstrlength==0) return validrets;

View file

@ -20,7 +20,6 @@ struct __scanf_input {
// max char to read from the input stream as per user length modifier // max char to read from the input stream as per user length modifier
unsigned int readmaxlength; unsigned int readmaxlength;
int currentlength;
// total number of char read so far in the current call of a XYscanf() function (to return a %n when required) // total number of char read so far in the current call of a XYscanf() function (to return a %n when required)
int readsofar; int readsofar;
@ -48,15 +47,20 @@ static inline int __scanf_in(struct __scanf_input *__in)
int c = __in->buffer; int c = __in->buffer;
__in->buffer = __scanf_fetch(__in); __in->buffer = __scanf_fetch(__in);
__in->readsofar++; __in->readsofar++;
__in->currentlength++;
return c; return c;
} }
/* Read the next byte and also decrease a total count of available reads. */
static inline int __scanf_in_limit(struct __scanf_input *__in, int *__N)
{
(*__N)--;
return __scanf_in(__in);
}
/* Peek the next byte without advancing. */ /* Peek the next byte without advancing. */
static inline int __scanf_peek(struct __scanf_input *__in) static inline int __scanf_peek(struct __scanf_input *__in)
{ {
return ((unsigned)__in->currentlength < __in->readmaxlength) return __in->buffer;
? __in->buffer : EOF;
} }
/* Close the input by unsending the buffer once finished. */ /* Close the input by unsending the buffer once finished. */

View file

@ -3,6 +3,7 @@
#include <stdlib.h> #include <stdlib.h>
#include <stdbool.h> #include <stdbool.h>
#include <limits.h>
#include "../stdio/stdio_p.h" #include "../stdio/stdio_p.h"
/* /*
@ -22,13 +23,17 @@
** On platforms where long is 32-bit, 64-bit operations are performed only if ** On platforms where long is 32-bit, 64-bit operations are performed only if
** outll is non-NULL. This is because multiplications with overflow can be ** outll is non-NULL. This is because multiplications with overflow can be
** expensive. ** expensive.
**
** N is the bound on the number of characters to read. To disable the bound,
** specify INT_MAX.
*/ */
int __strto_int( int __strto_int(
struct __scanf_input *__input, struct __scanf_input *__input,
int __base, int __base,
long *__outl, long *__outl,
long long *__outll, long long *__outll,
bool __use_unsigned); bool __use_unsigned,
int __N);
/* /*
** Parse a floating-point value from a string. This is the base function for ** Parse a floating-point value from a string. This is the base function for
@ -42,6 +47,7 @@ int __strto_fp(
struct __scanf_input *__input, struct __scanf_input *__input,
double *__out, double *__out,
float *__outf, float *__outf,
long double *__outl); long double *__outl,
int __N);
#endif /*__STDLIB_P_H__*/ #endif /*__STDLIB_P_H__*/

View file

@ -38,8 +38,8 @@
** -> In hexadecimal notation, we read as many bits as the mantissa of a long ** -> In hexadecimal notation, we read as many bits as the mantissa of a long
** double, then later multiply by a power of 2. There are no approximations. ** double, then later multiply by a power of 2. There are no approximations.
*/ */
static bool parse_digits(struct __scanf_input *input, static int parse_digits(struct __scanf_input *input,
SIGNIFICAND_TYPE *digits, long *exponent, bool hexadecimal) SIGNIFICAND_TYPE *digits, long *exponent, bool hexadecimal, int *N)
{ {
bool dot_found = false; bool dot_found = false;
int digits_found=0, c=0; int digits_found=0, c=0;
@ -53,12 +53,14 @@ static bool parse_digits(struct __scanf_input *input,
int dot_character = '.'; int dot_character = '.';
int exp_character = (hexadecimal ? 'p' : 'e'); int exp_character = (hexadecimal ? 'p' : 'e');
for(int i = 0; true; i++) { for(int i = 0; *N >= 0; i++) {
c = __scanf_peek(input); c = __scanf_peek(input);
if(i == 0 && c == EOF)
return EOF;
if(!(isdigit(c) || if(!(isdigit(c) ||
(hexadecimal && isxdigit(c)) || (hexadecimal && isxdigit(c)) ||
(c == dot_character && !dot_found))) break; (c == dot_character && !dot_found))) break;
__scanf_in(input); __scanf_in_limit(input, N);
if(c == dot_character) { if(c == dot_character) {
dot_found = true; dot_found = true;
@ -102,9 +104,10 @@ static bool parse_digits(struct __scanf_input *input,
set correctly */ set correctly */
struct __scanf_input backup = *input; struct __scanf_input backup = *input;
__scanf_in(input); __scanf_in_limit(input, N);
long e = 0; long e = 0;
if(__strto_int(input, 10, &e, NULL, false) == 0) // TODO: strto_fp: Pass limit to __strto_int
if(__strto_int(input, 10, &e, NULL, false, *N) == 0)
*exponent += e; *exponent += e;
else else
*input = backup; *input = backup;
@ -124,18 +127,21 @@ static bool expect(struct __scanf_input *input, char const *sequence)
} }
int __strto_fp(struct __scanf_input *input, double *out, float *outf, int __strto_fp(struct __scanf_input *input, double *out, float *outf,
long double *outl) long double *outl, int N)
{ {
input->currentlength = 0;
/* Skip initial whitespace */ /* Skip initial whitespace */
while(isspace(__scanf_peek(input))) __scanf_in(input); while(isspace(__scanf_peek(input))) __scanf_in(input);
// TODO: strto_fp() doesn't support size limits well, affecting %5f etc.
if(N <= 0)
return EOF;
/* Read optional sign */ /* Read optional sign */
bool negative = false; bool negative = false;
int sign = __scanf_peek(input); int sign = __scanf_peek(input);
if(sign == '-') negative = true; if(sign == '-') negative = true;
if(sign == '-' || sign == '+') __scanf_in(input); if(sign == '-' || sign == '+') __scanf_in_limit(input, &N);
int errno_value = 0; int errno_value = 0;
bool valid = false; bool valid = false;
@ -156,8 +162,10 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf,
if(__scanf_peek(input) == '(') { if(__scanf_peek(input) == '(') {
while(i < 31) { while(i < 31) {
int c = __scanf_in(input); int c = __scanf_in_limit(input, &N);
if(c == ')') break; if(c == ')') break;
if(c == EOF || N <= 0)
return EOF;
arg[i++] = c; arg[i++] = c;
} }
arg[i] = 0; arg[i] = 0;
@ -179,6 +187,9 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf,
if(outl) *outl = __builtin_infl(); if(outl) *outl = __builtin_infl();
valid = true; valid = true;
} }
else if(__scanf_peek(input) == EOF) {
return EOF;
}
else { else {
SIGNIFICAND_TYPE digits = 0; SIGNIFICAND_TYPE digits = 0;
long e = 0; long e = 0;
@ -187,9 +198,9 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf,
not 0x isn't a problem. */ not 0x isn't a problem. */
bool hexa = false; bool hexa = false;
if(__scanf_peek(input) == '0') { if(__scanf_peek(input) == '0') {
__scanf_in(input); __scanf_in_limit(input, &N);
if(tolower(__scanf_peek(input)) == 'x') { if(tolower(__scanf_peek(input)) == 'x') {
__scanf_in(input); __scanf_in_limit(input, &N);
hexa = true; hexa = true;
} }
/* Count the 0 as a digit */ /* Count the 0 as a digit */
@ -197,13 +208,19 @@ int __strto_fp(struct __scanf_input *input, double *out, float *outf,
} }
if(hexa) { if(hexa) {
valid |= parse_digits(input, &digits, &e, true); int rc = parse_digits(input, &digits, &e, true, &N);
if(!valid && rc == EOF)
return EOF;
valid |= rc;
if(out) *out = (double)digits * exp2(e); if(out) *out = (double)digits * exp2(e);
if(outf) *outf = (float)digits * exp2f(e); if(outf) *outf = (float)digits * exp2f(e);
if(outl) *outl = (long double)digits * exp2l(e); if(outl) *outl = (long double)digits * exp2l(e);
} }
else { else {
valid |= parse_digits(input, &digits, &e, false); int rc = parse_digits(input, &digits, &e, false, &N);
if(!valid && rc == EOF)
return EOF;
valid |= rc;
if(out) *out = (double)digits * pow(10, e); if(out) *out = (double)digits * pow(10, e);
if(outf) *outf = (float)digits * powf(10, e); if(outf) *outf = (float)digits * powf(10, e);
if(outl) *outl = (long double)digits * powl(10, e); if(outl) *outl = (long double)digits * powl(10, e);

View file

@ -5,20 +5,19 @@
#include <limits.h> #include <limits.h>
int __strto_int(struct __scanf_input *input, int base, long *outl, int __strto_int(struct __scanf_input *input, int base, long *outl,
long long *outll, bool use_unsigned) long long *outll, bool use_unsigned, int N)
{ {
input->currentlength = 0;
/* Skip initial whitespace */ /* Skip initial whitespace */
while(isspace(__scanf_peek(input))) __scanf_in(input); while(isspace(__scanf_peek(input))) __scanf_in(input);
if(N <= 0)
return EOF;
/* Accept a sign character */ /* Accept a sign character */
bool negative = false; bool negative = false;
int sign = __scanf_peek(input); int sign = __scanf_peek(input);
if(sign == EOF) return EOF;
if(sign == '-') negative = true; if(sign == '-') negative = true;
if(sign == '-' || sign == '+') __scanf_in(input); if(sign == '-' || sign == '+') __scanf_in_limit(input, &N);
/* Use unsigned variables as only these have defined overflow */ /* Use unsigned variables as only these have defined overflow */
unsigned long xl = 0; unsigned long xl = 0;
@ -29,10 +28,10 @@ int __strto_int(struct __scanf_input *input, int base, long *outl,
/* Read prefixes and determine base */ /* Read prefixes and determine base */
if(__scanf_peek(input) == '0') { if(__scanf_peek(input) == '0') {
__scanf_in(input); __scanf_in_limit(input, &N);
if((base == 0 || base == 16) && if((base == 0 || base == 16) &&
tolower(__scanf_peek(input)) == 'x') { tolower(__scanf_peek(input)) == 'x') {
__scanf_in(input); __scanf_in_limit(input, &N);
base = 16; base = 16;
} }
/* If we don't consume the x then count the 0 as a digit */ /* If we don't consume the x then count the 0 as a digit */
@ -40,13 +39,13 @@ int __strto_int(struct __scanf_input *input, int base, long *outl,
if(base == 0) if(base == 0)
base = 8; base = 8;
} }
else if(__scanf_peek(input) == EOF) if(!valid && (N <= 0 || __scanf_peek(input) == EOF))
return EOF; return EOF;
if(base == 0) if(base == 0)
base = 10; base = 10;
/* Read digits */ /* Read digits */
while(1) { while(N > 0) {
int v = -1; int v = -1;
int c = __scanf_peek(input); int c = __scanf_peek(input);
if(isdigit(c)) v = c - '0'; if(isdigit(c)) v = c - '0';
@ -71,7 +70,7 @@ int __strto_int(struct __scanf_input *input, int base, long *outl,
errno_value = ERANGE; errno_value = ERANGE;
} }
__scanf_in(input); __scanf_in_limit(input, &N);
} }
/* Handle sign and range */ /* Handle sign and range */

View file

@ -7,9 +7,9 @@ double strtod(char const * restrict ptr, char ** restrict endptr)
if(endptr) if(endptr)
*endptr = (char *)ptr; *endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in); __scanf_start(&in);
int err = __strto_fp(&in, &d, NULL, NULL); int err = __strto_fp(&in, &d, NULL, NULL, INT_MAX);
__scanf_end(&in); __scanf_end(&in);
if(err != 0) if(err != 0)

View file

@ -7,9 +7,9 @@ float strtof(char const * restrict ptr, char ** restrict endptr)
if(endptr) if(endptr)
*endptr = (char *)ptr; *endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in); __scanf_start(&in);
int err = __strto_fp(&in, NULL, &f, NULL); int err = __strto_fp(&in, NULL, &f, NULL, INT_MAX);
__scanf_end(&in); __scanf_end(&in);
if(err != 0) if(err != 0)

View file

@ -7,15 +7,9 @@ long int strtol(char const * restrict ptr, char ** restrict endptr, int base)
if(endptr) if(endptr)
*endptr = (char *)ptr; *endptr = (char *)ptr;
struct __scanf_input in = { struct __scanf_input in = { .str = ptr, .fp = NULL };
.str = ptr,
.fp = NULL,
.readmaxlength = -1,
.currentlength = 0,
.readsofar = 0,
};
__scanf_start(&in); __scanf_start(&in);
int err = __strto_int(&in, base, &n, NULL, false); int err = __strto_int(&in, base, &n, NULL, false, INT_MAX);
__scanf_end(&in); __scanf_end(&in);
if(err != 0) if(err != 0)

View file

@ -7,9 +7,9 @@ long double strtold(char const * restrict ptr, char ** restrict endptr)
if(endptr) if(endptr)
*endptr = (char *)ptr; *endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in); __scanf_start(&in);
int err = __strto_fp(&in, NULL, NULL, &ld); int err = __strto_fp(&in, NULL, NULL, &ld, INT_MAX);
__scanf_end(&in); __scanf_end(&in);
if(err != 0) if(err != 0)

View file

@ -8,9 +8,9 @@ long long int strtoll(char const * restrict ptr, char ** restrict endptr,
if(endptr) if(endptr)
*endptr = (char *)ptr; *endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in); __scanf_start(&in);
int err = __strto_int(&in, base, NULL, &n, false); int err = __strto_int(&in, base, NULL, &n, false, INT_MAX);
__scanf_end(&in); __scanf_end(&in);
if(err != 0) if(err != 0)

View file

@ -8,9 +8,9 @@ unsigned long int strtoul(char const * restrict ptr, char ** restrict endptr,
if(endptr) if(endptr)
*endptr = (char *)ptr; *endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in); __scanf_start(&in);
int err = __strto_int(&in, base, (long *)&n, NULL, true); int err = __strto_int(&in, base, (long *)&n, NULL, true, INT_MAX);
__scanf_end(&in); __scanf_end(&in);
if(err != 0) if(err != 0)

View file

@ -8,9 +8,9 @@ unsigned long long int strtoull(char const * restrict ptr,
if(endptr) if(endptr)
*endptr = (char *)ptr; *endptr = (char *)ptr;
struct __scanf_input in = { .str = ptr, .fp = NULL, .readmaxlength = -1 }; struct __scanf_input in = { .str = ptr, .fp = NULL };
__scanf_start(&in); __scanf_start(&in);
int err = __strto_int(&in, base, NULL, (long long *)&n, true); int err = __strto_int(&in, base, NULL, (long long *)&n, true, INT_MAX);
__scanf_end(&in); __scanf_end(&in);
if(err != 0) if(err != 0)