aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-02-01 11:37:58 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2022-02-16 12:06:56 +0000
commit42c9bae449af7ee395fc0e52d4ca7cc9ad55edeb (patch)
treec57416c1a7d0e0f25b88ba1ed728ff202c82290b
parentae0c1c646da49096c9a33a4839b138fbde2b36b8 (diff)
downloadreference_model-42c9bae449af7ee395fc0e52d4ca7cc9ad55edeb.tar.gz
Update refmodel apply_scale_32: adjust range checking
Fix up generated values for rescale tests Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I28fc3b8f189d25a7ad8e5172d4d8a43b86820fcf
-rw-r--r--reference_model/src/quant_util.h10
-rw-r--r--verif/generator/tosa_test_gen.py22
2 files changed, 32 insertions, 0 deletions
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 4f6a525..8c1b391 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -55,6 +55,16 @@ public:
"apply_scale_32(): shift value should stay within [2, 62] but is " + std::to_string(shift);
throw desc;
}
+ int64_t low_val = -1L << (shift-2);
+ int64_t high_val = 1L << (shift-2);
+ if (value < low_val || value >= high_val)
+ {
+ std::string desc =
+ "apply_scale_32(): value should stay within [" +
+ std::to_string(low_val) + ", " + std::to_string(high_val-1) +
+ "] but is " + std::to_string(value) + " using shift of " + std::to_string(shift);
+ throw desc;
+ }
int64_t round = 1L << (shift - 1);
if (double_round)
{
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7e4cb1d..e60e643 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -5135,13 +5135,35 @@ class TosaTestGen:
multiplier_arr = np.int32(np.zeros(shape=[nc]))
shift_arr = np.int32(np.zeros(shape=[nc]))
+ min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
+ max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
for i in range(nc):
multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
scale_arr[i], scale32
)
+ min_shift_value_arr[i] = -1 << (shift_arr[i] - 2)
+ max_shift_value_arr[i] = (1 << (shift_arr[i] - 2)) - 1
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
+ if error_name is None:
+ # Make sure random values are within speicification
+ # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2))
+ assert val.placeholderFilename
+ values = np.load(
+ os.path.join(self.basePath, self.testPath, val.placeholderFilename)
+ )
+ val_adj = np.subtract(values, input_zp)
+ val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=values.dtype)
+ val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=values.dtype)
+ val_adj = np.add(val_adj, input_zp)
+ if not np.all(np.array_equal(values, val_adj)):
+ # Values changed so overwrite file with new values
+ np.save(
+ os.path.join(self.basePath, self.testPath, val.placeholderFilename),
+ val_adj,
+ False,
+ )
# Invalidate Input/Output list for error if checks.
input_list = [val.name]