From 42c9bae449af7ee395fc0e52d4ca7cc9ad55edeb Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Tue, 1 Feb 2022 11:37:58 +0000 Subject: Update refmodel apply_scale_32: adjust range checking Fix up generated values for rescale tests Signed-off-by: Jeremy Johnson Change-Id: I28fc3b8f189d25a7ad8e5172d4d8a43b86820fcf --- reference_model/src/quant_util.h | 10 ++++++++++ verif/generator/tosa_test_gen.py | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index 4f6a525..8c1b391 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -55,6 +55,16 @@ public: "apply_scale_32(): shift value should stay within [2, 62] but is " + std::to_string(shift); throw desc; } + int64_t low_val = -1L << (shift-2); + int64_t high_val = 1L << (shift-2); + if (value < low_val || value >= high_val) + { + std::string desc = + "apply_scale_32(): value should stay within [" + + std::to_string(low_val) + ", " + std::to_string(high_val-1) + + "] but is " + std::to_string(value) + " using shift of " + std::to_string(shift); + throw desc; + } int64_t round = 1L << (shift - 1); if (double_round) { diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 7e4cb1d..e60e643 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -5135,13 +5135,35 @@ class TosaTestGen: multiplier_arr = np.int32(np.zeros(shape=[nc])) shift_arr = np.int32(np.zeros(shape=[nc])) + min_shift_value_arr = np.int64(np.zeros(shape=[nc])) + max_shift_value_arr = np.int64(np.zeros(shape=[nc])) for i in range(nc): multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift( scale_arr[i], scale32 ) + min_shift_value_arr[i] = -1 << (shift_arr[i] - 2) + max_shift_value_arr[i] = (1 << (shift_arr[i] - 2)) - 1 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp)) + if error_name is None: + # Make sure random values are within speicification + # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2)) + assert val.placeholderFilename + values = np.load( + os.path.join(self.basePath, self.testPath, val.placeholderFilename) + ) + val_adj = np.subtract(values, input_zp) + val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=values.dtype) + val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=values.dtype) + val_adj = np.add(val_adj, input_zp) + if not np.all(np.array_equal(values, val_adj)): + # Values changed so overwrite file with new values + np.save( + os.path.join(self.basePath, self.testPath, val.placeholderFilename), + val_adj, + False, + ) # Invalidate Input/Output list for error if checks. input_list = [val.name] -- cgit v1.2.1