/*
 * Program Name: sbtm.cc
 * Programmer:   Michael Schulte, mschulte@eecs.lehigh.edu
 * Date:         November 27, 1996
 *
 * Description:
 *   This program produces tables for the bipartite table method for
 *   approximating a function f(x).  This method takes as inputs n0, n1,
 *   and n2, which correspond to the number of bits in x0, x1, and x2,
 *   where x = x0 + x1 + x2.  It also takes as input the number of guard
 *   digits, ng, and a variable, f, which indicates the function to be
 *   implemented.  The program computes the coefficients produced by this
 *   approximation method and reports the maximum and average error of
 *   the approximations.
 *
 *   Note: only the sum n = n0 + n1 + n2 of the input bit counts is used.
 *   The program searches for the split of n into n0, n1, n2 that
 *   minimizes table memory and uses that split for the tables it builds.
 *
 * Compiling:
 *   gcc -g -o sbtm sbtm.cc -lm
 *
 * Usage:
 *   sbtm in_file out_file [data_file] [err_file]
 *
 * Input File Format:
 *   n0          = Number of bits - n0
 *   n1          = Number of bits - n1
 *   n2          = Number of bits - n2
 *   ng          = Number of guard bits
 *   p           = Precision of output
 *   f           = Function (see the switch in choose_func):
 *                   0 recip, 1 sqrt, 2 rsqrt, 3 sin, 4 cos, 5 atan,
 *                   6 log2, 7 exp2, 8 log, 9 exp, 10 square
 *   rnd         = Rounding method (e.g. Round-to-Nearest):
 *                   0     table entries and sum are left unrounded
 *                   1     table entries are rounded, sum is not
 *                   other table entries are rounded and the sum is
 *                         rounded to p
 *   start_value = Start value
 *   stop_value  = Stop value
 *
 * Disclaimer:
 *   This program is available as is, and the author is in no way
 *   responsible for any consequences of its actions, or its use in any
 *   other manners not prescribed by this code.
 */

#include "sbtm.h"

#define ARR_SIZE 100000

/*
 * Reads the approximation parameters from in_file, picks the table
 * partition that minimizes memory, builds the a0 and a1 tables, and
 * writes a summary (table sizes, maximum and average error) to out_file.
 * Per-point details go to data_file and err_file when those are given.
 *
 * Returns 0 on success.  Exits with status 1 on bad usage, unreadable
 * input, an infeasible partition, or a table that does not fit in
 * ARR_SIZE entries.
 */
int main(int argc, char* argv[])
{
  double step, x, x0, x01, x2;
  double err, y_true, y_approx, tot_err;
  double max_err, x_max_err, bits_max_err, eps, delta1, delta2, ppg;
  /* static: 2 * ARR_SIZE doubles is too large to keep comfortably on the stack */
  static double a0[ARR_SIZE], a1[ARR_SIZE];
  double bits_a1, max_bits_a1;
  double start_value, stop_value;
  int c0, c1, c2, c01, c02, max_c01, max_c02;
  int n0, n1, n2, n, ng, num_vals, rnd;
  int f, p;
  PDF func = NULL;
  PDF deriv_func = NULL;
  PDF deriv2_func = NULL;
  FILE *in_file = NULL;
  FILE *out_file = NULL;
  FILE *data_file = NULL;
  FILE *err_file = NULL;

  open_files(argc, argv, in_file, out_file, data_file, err_file);

  /* Refuse to continue with partially-read (uninitialized) parameters */
  if (fscanf(in_file,"%d%d%d%d%d%d%d\n", &n0, &n1, &n2, &ng, &p, &f, &rnd) != 7) {
    fprintf(stderr,"%s: expected seven integers (n0 n1 n2 ng p f rnd)\n", argv[1]);
    exit(1);
  }
  if (fscanf(in_file,"%lg%lg\n", &start_value, &stop_value) != 2) {
    fprintf(stderr,"%s: expected start_value and stop_value\n", argv[1]);
    exit(1);
  }

  fprintf(out_file,"n0 = %d, n1 = %d, n2 = %d, ng = %d, p = %d, f = %d, rnd = %d\n",
          n0, n1, n2, ng, p, f, rnd);
  fprintf(out_file,"start_value = %lg, stop_value = %lg\n", start_value, stop_value);

  /* Initialization */
  max_err = 0.0;
  tot_err = 0.0;
  max_bits_a1 = 0.0;
  /* -1 so that an empty range reports zero table words below */
  c01 = -1;
  c02 = -1;

  /* Choose function from Case Statement */
  choose_func(f, func, deriv_func, deriv2_func);

  /* Calculate Number of bits from Input Operand */
  n = n0 + n1 + n2;
  /* 1 << n is evaluated below; assumes int has at least 32 bits */
  if (n < 1 || n > 30) {
    fprintf(stderr,"n0 + n1 + n2 = %d is out of range (must be 1..30)\n", n);
    exit(1);
  }

  /* Calculate Total Number of Bits p + total number of guard bits */
  ppg = double(p+ng);

  /* Place start and stop values into Accurate Bit Notation */
  start_value = flr(start_value,double(n));
  stop_value = flr(stop_value,double(n));

  /* Reported only if every error is exactly zero: no worst point, and
     -log2(0) bits of accuracy */
  x_max_err = start_value;
  bits_max_err = HUGE_VAL;

  /* Calculate total number of values */
  step = pow(2.0, double(-n));
  num_vals = (int) ((stop_value - start_value)/step);

  /* Calculate Values for n0, n1, and n2 that minimize memory requirements */
  int n0_opt = 0, n1_opt = 0, n2_opt = 0;
  int found_partition = 0;
  int p0, p1, k, lead_deriv_bits, lead_deriv2_bits;
  unsigned mem_size, min_mem_size, standard_mem_size;

  lead_deriv_bits = int(-ceil(log2(fmax(fabs(deriv_func(stop_value)),
                                        fabs(deriv_func(start_value))))));
  lead_deriv2_bits = int(-ceil(log2(fmax(fabs(deriv2_func(stop_value)),
                                         fabs(deriv2_func(start_value))))));

  /* Do a special case for atan, since its max function does not occur
     at an endpoint */
  if (f == 5)
    lead_deriv2_bits = 0;

  if (f == 0)
    min_mem_size = standard_mem_size = (1 << n) * (p - 1);
  else
    min_mem_size = standard_mem_size = (1 << n) * p;

  p0 = (int) ppg;
  if (f == 0)
    p0--;

  k = (int) floor(double(n)/3.0);
  for (n0 = k-2; n0 <= k+2; n0++) {
    for (n1 = k-2; n1 <= k+2; n1++) {
      n2 = n - n0 - n1;
      /* Negative widths would give negative shift counts below and in
         the table loops, which is undefined behavior */
      if (n0 < 0 || n1 < 0 || n2 < 0 || n0 + n2 < 1)
        continue;
      p1 = int(ppg) - n0 - n1 - 1 - lead_deriv_bits;
      /* Test if the error criterion is met */
      /* Alternative criterion: 2*n0 + n1 >= p - 1 - lead_deriv2_bits */
      if (2*n0 + n1 >= p - lead_deriv2_bits) {
        mem_size = (1 << (n0 + n1))*p0 + (1 << (n0 + n2 - 1))*p1;
        if (mem_size < min_mem_size) {
          min_mem_size = mem_size;
          n0_opt = n0;
          n1_opt = n1;
          n2_opt = n2;
          found_partition = 1;
        }
      }
    }
  }

  if (!found_partition) {
    fprintf(stderr,"No partition of n = %d bits meets the error criterion "
                   "with less memory than a direct table\n", n);
    exit(1);
  }

  n0 = n0_opt;
  n1 = n1_opt;
  n2 = n2_opt;
  p1 = int(ppg) - n0 - n1 - 1 - lead_deriv_bits;

  int w0, w1, mem0, mem1, mem_tot;
  double comp;

  w0 = 1 << (n0 + n1);
  w1 = 1 << (n0 + n2 - 1);
  mem0 = w0*p0;
  mem1 = w1*p1;
  mem_tot = mem0 + mem1;
  comp = double(standard_mem_size)/double(mem_tot);

  fprintf(out_file,"n0 = %d, n1 = %d, n2 = %d\n", n0, n1, n2);
  fprintf(out_file,"a0 table is %d words x %d bits = %d bits\n", w0, p0, mem0);
  fprintf(out_file,"a1 table is %d words x %d bits = %d bits\n", w1, p1, mem1);
  /* standard_mem_size is unsigned, so it is printed with %u */
  fprintf(out_file,"total_memory = %d, standard_memory = %u, compression = %lg\n",
          mem_tot, standard_mem_size, comp);

  /* Calculate delta1 and delta2 -- Page 3 */
  delta1 = pow(2.0, double(-n0 - 1)) - pow(2.0, double(-n0 - n1 - 1));
  delta2 = pow(2.0, double(-n0 - n1 - 1)) - pow(2.0, double(-n0 - n1 - n2 - 1));
  eps = pow(2.0, -ppg - 1);

  /* Obtain the coefficients */
  for (c0 = 0, x = start_value; x < stop_value; c0++) {
    x0 = flr(x, double(n0));
    for (c1 = 0; c1 < (1 << n1) && x < stop_value; c1++) {
      /* a0 is indexed by the (x0, x1) fields: (c0 << n1) + c1 */
      c01 = (c0 << n1) + c1;
      if (c01 < 0 || c01 >= ARR_SIZE) {
        fprintf(stderr,"a0 table index %d exceeds ARR_SIZE (%d)\n", c01, ARR_SIZE);
        exit(1);
      }

      /* Calculate x0 + x1 */
      x01 = flr(x, double(n0+n1));

      /* Calculate a0 - Equation 4 */
      a0[c01] = func(x01+delta2);

      for (c2 = 0; c2 < (1 << n2) && x < stop_value; c2++, x+=step) {
        /* a1 is indexed by the (x0, x2) fields: (c0 << n2) + c2 */
        c02 = (c0 << n2) + c2;
        if (c02 < 0 || c02 >= ARR_SIZE) {
          fprintf(stderr,"a1 table index %d exceeds ARR_SIZE (%d)\n", c02, ARR_SIZE);
          exit(1);
        }

        /* Calculate x2 */
        x2 = x - x01;

        /* Calculate a1 - Equation 5 */
        a1[c02] = deriv_func(x0 + delta1 + delta2)*(x2 - delta2);

        /* Calculate true value - used for error */
        y_true = func(x);

        /* Determine Rounding Method */
        if (rnd != 0) {
          /* Both rounding modes store rounded table entries */
          a0[c01] = round(a0[c01], ppg);
          a1[c02] = flr(a1[c02], ppg) + eps;
        }
        if (rnd == 0 || rnd == 1)
          y_approx = a0[c01] + a1[c02];
        else
          y_approx = round(a0[c01] + a1[c02], p);

        /* Calculate Error */
        err = y_true - y_approx;
        if (err_file)
          fprintf(err_file, "%lg, %lg\n", x, err);

        /* Sum Total Error (signed, so errors of opposite sign cancel) */
        tot_err += err;

        /* If Data File argument is present, output to file */
        if (data_file) {
          fprintf(data_file, "x = %lg, err = %lg, y_true = %lg, y_approx = %lg, a0[%d] = %lg, a1[%d] = %lg\n",
                  x, err, y_true, y_approx, c01, a0[c01], c02, a1[c02]);
          fprintf(data_file, "x = ");
          disp_bin(x, 1, n0 + n1 + n2, data_file);
          fprintf(data_file, " a0 = ");
          disp_bin(a0[c01], 1, int (ppg+1), data_file);
          fprintf(data_file, " a1 = ");
          disp_bin(a1[c02], 1, int (ppg+1), data_file);
          fprintf(data_file, "\n");
        }

        err = fabs(err);

        /* Calculate total number of bits for error */
        if (err > max_err) {
          max_err = err;
          x_max_err = x;
          bits_max_err = -log(max_err)/log(2.0);
        }

        /* Calculate total number of bits for a1 */
        bits_a1 = ppg + log(fabs(a1[c02]))/log(2.0);
        if (bits_a1 > max_bits_a1)
          max_bits_a1 = bits_a1;
      }
    }
  }

  max_c02 = c02;
  max_c01 = c01;

  if (err_file)
    fprintf(err_file, "\n");

  fprintf(out_file,"num_vals = %d, a0 table = %d words, a1 table = %d words\n",
          num_vals, max_c01+1, (max_c02+1)/2);
  fprintf(out_file, "\n iteration = 0\n\n");
  fprintf(out_file, "max_err = %lg, x_max_err = %lg, bits_max_err = %lg\n\n",
          max_err, x_max_err, bits_max_err);
  fprintf(out_file, "avg_err = %lg, max_bits_a1 = %lg\n\n",
          tot_err/double(num_vals), max_bits_a1);

  fclose(in_file);
  if (data_file)
    fclose(data_file);
  if (err_file)
    fclose(err_file);
  /* fclose flushes, so this is where a failed write to the report shows up */
  if (fclose(out_file) != 0) {
    fprintf(stderr,"Error writing %s\n", argv[2]);
    return 1;
  }
  return 0;
}

/*
 * Selects the function to approximate and its first and second
 * derivatives.
 *
 *   func_spec    function code (the "f" input parameter)
 *   func         set to the function itself
 *   deriv_func   set to its first derivative
 *   deriv2_func  set to its second derivative
 *
 * Prints a message to stderr and exits with status 1 when func_spec is
 * not one of the codes handled below.
 */
void choose_func(int func_spec, PDF &func, PDF &deriv_func, PDF &deriv2_func)
{
  switch (func_spec) {
    case 0:
      func = recip;
      deriv_func = deriv_recip;
      deriv2_func = deriv2_recip;
      break;
    case 1:
      func = sqrt;
      deriv_func = deriv_sqrt;
      deriv2_func = deriv2_sqrt;
      break;
    case 2:
      func = rsqrt;
      deriv_func = deriv_rsqrt;
      deriv2_func = deriv2_rsqrt;
      break;
    case 3:
      func = sin;
      deriv_func = cos;
      deriv2_func = deriv2_sin;
      break;
    case 4:
      func = cos;
      deriv_func = deriv_cos;
      deriv2_func = deriv2_cos;
      break;
    case 5:
      func = atan;
      deriv_func = deriv_atan;
      deriv2_func = deriv2_atan;
      break;
    case 6:
      func = log2;
      deriv_func = deriv_log2;
      deriv2_func = deriv2_log2;
      break;
    case 7:
      func = exp2;
      deriv_func = deriv_exp2;
      deriv2_func = deriv2_exp2;
      break;
    case 8:
      /* The reciprocal helpers double as the derivatives of log */
      func = log;
      deriv_func = recip;
      deriv2_func = deriv_recip;
      break;
    case 9:
      /* exp is its own first and second derivative */
      func = exp;
      deriv_func = exp;
      deriv2_func = exp;
      break;
    case 10:
      func = square;
      deriv_func = deriv_square;
      deriv2_func = deriv2_square;
      break;
    default:
      fprintf(stderr,"choose_func: An invalid function has been specified (f = %d)\n", func_spec);
      exit(1);
  }
}

/*
 * Opens the files named on the command line.
 *
 *   argv[1]  in_file    opened for reading (required)
 *   argv[2]  out_file   opened for writing (required); if it already has
 *                       content the user is asked on stderr whether to
 *                       overwrite it, and any answer that does not start
 *                       with 'y' exits
 *   argv[3]  data_file  opened for writing (optional, else left as is)
 *   argv[4]  err_file   opened for writing (optional, else left as is)
 *
 * Prints a message to stderr and exits with status 1 on bad usage or
 * when a file cannot be opened.
 */
void open_files(int argc, char** argv, FILE *&in_file, FILE *&out_file,
                FILE *&data_file, FILE *&err_file)
{
  if (argc < 3) {
    fprintf(stderr,"Usage: %s in_file out_file [data_file] [err_file]\n", argv[0]);
    exit(1);
  }

  in_file = fopen(argv[1], "r");
  if (in_file == NULL) {
    fprintf(stderr,"%s not found\n", argv[1]);
    exit(1);
  }

  /* Append mode creates the file if needed without truncating it */
  out_file = fopen(argv[2], "a");
  if (out_file == NULL) {
    fprintf(stderr,"Unable to create %s\n", argv[2]);
    exit(1);
  }

  /* The initial position of an append-mode stream is implementation-
     defined, so seek to the end before asking whether the file already
     has content. */
  if (fseek(out_file, 0L, SEEK_END) != 0 || ftell(out_file) != 0) {
    fprintf(stderr,"overwrite %s? ", argv[2]);
    if ('y' == getchar()) {
      fclose(out_file);
      out_file = fopen(argv[2], "w");
      if (out_file == NULL) {
        fprintf(stderr,"Unable to create %s\n", argv[2]);
        exit(1);
      }
    } else {
      exit(1);
    }
  }

  if (argc > 3) {
    data_file = fopen(argv[3], "w");
    if (data_file == NULL) {
      fprintf(stderr,"Unable to create %s\n", argv[3]);
      exit(1);
    }
  }

  if (argc > 4) {
    err_file = fopen(argv[4], "w");
    if (err_file == NULL) {
      fprintf(stderr,"Unable to create %s\n", argv[4]);
      exit(1);
    }
  }
}