diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-02-01 11:37:58 +0000 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2022-02-16 12:06:56 +0000 |
commit | 42c9bae449af7ee395fc0e52d4ca7cc9ad55edeb (patch) | |
tree | c57416c1a7d0e0f25b88ba1ed728ff202c82290b | |
parent | ae0c1c646da49096c9a33a4839b138fbde2b36b8 (diff) | |
download | reference_model-42c9bae449af7ee395fc0e52d4ca7cc9ad55edeb.tar.gz |
Update refmodel apply_scale_32: adjust range checking
Fix up generated values for rescale tests
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I28fc3b8f189d25a7ad8e5172d4d8a43b86820fcf
-rw-r--r-- | reference_model/src/quant_util.h | 10 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 22 |
2 files changed, 32 insertions, 0 deletions
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h index 4f6a525..8c1b391 100644 --- a/reference_model/src/quant_util.h +++ b/reference_model/src/quant_util.h @@ -55,6 +55,16 @@ public: "apply_scale_32(): shift value should stay within [2, 62] but is " + std::to_string(shift); throw desc; } + int64_t low_val = -1L << (shift-2); + int64_t high_val = 1L << (shift-2); + if (value < low_val || value >= high_val) + { + std::string desc = + "apply_scale_32(): value should stay within [" + + std::to_string(low_val) + ", " + std::to_string(high_val-1) + + "] but is " + std::to_string(value) + " using shift of " + std::to_string(shift); + throw desc; + } int64_t round = 1L << (shift - 1); if (double_round) { diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 7e4cb1d..e60e643 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -5135,13 +5135,35 @@ class TosaTestGen: multiplier_arr = np.int32(np.zeros(shape=[nc])) shift_arr = np.int32(np.zeros(shape=[nc])) + min_shift_value_arr = np.int64(np.zeros(shape=[nc])) + max_shift_value_arr = np.int64(np.zeros(shape=[nc])) for i in range(nc): multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift( scale_arr[i], scale32 ) + min_shift_value_arr[i] = -1 << (shift_arr[i] - 2) + max_shift_value_arr[i] = (1 << (shift_arr[i] - 2)) - 1 # print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp)) + if error_name is None: + # Make sure random values are within speicification + # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2)) + assert val.placeholderFilename + values = np.load( + os.path.join(self.basePath, self.testPath, val.placeholderFilename) + ) + val_adj = np.subtract(values, input_zp) + val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=values.dtype) + val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=values.dtype) + val_adj = np.add(val_adj, input_zp) + if not np.all(np.array_equal(values, val_adj)): + # Values changed so overwrite file with new values + np.save( + os.path.join(self.basePath, self.testPath, val.placeholderFilename), + val_adj, + False, + ) # Invalidate Input/Output list for error if checks. input_list = [val.name] |