aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reference_model/src/quant_util.h11
-rw-r--r--verif/generator/tosa_test_gen.py6
2 files changed, 8 insertions, 9 deletions
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
index 3b7674d..2e5c2e5 100644
--- a/reference_model/src/quant_util.h
+++ b/reference_model/src/quant_util.h
@@ -55,14 +55,13 @@ public:
"apply_scale_32(): shift value should stay within [2, 62] but is " + std::to_string(shift);
throw desc;
}
- int64_t low_val = -1L << (shift-2);
- int64_t high_val = 1L << (shift-2);
+ int64_t low_val = -1L << (shift - 1);
+ int64_t high_val = 1L << (shift - 1);
if (value < low_val || value >= high_val)
{
- std::string desc =
- "apply_scale_32(): value should stay within [" +
- std::to_string(low_val) + ", " + std::to_string(high_val-1) +
- "] but is " + std::to_string(value) + " using shift of " + std::to_string(shift);
+ std::string desc = "apply_scale_32(): value should stay within [" + std::to_string(low_val) + ", " +
+ std::to_string(high_val - 1) + "] but is " + std::to_string(value) + " using shift of " +
+ std::to_string(shift);
throw desc;
}
int64_t round = 1L << (shift - 1);
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index b0e7c8c..583e1ed 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -1653,13 +1653,13 @@ class TosaTestGen:
multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
scale_arr[i], scale32
)
- min_shift_value_arr[i] = -1 << (shift_arr[i] - 2)
- max_shift_value_arr[i] = (1 << (shift_arr[i] - 2)) - 1
+ min_shift_value_arr[i] = -1 << (shift_arr[i] - 1)
+ max_shift_value_arr[i] = (1 << (shift_arr[i] - 1)) - 1
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
if scale32 and error_name is None:
# Make sure random values are within apply_scale_32 specification
- # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2))
+ # REQUIRES(value >= (-1<<(shift-1)) && value < (1<<(shift-1))
assert val.placeholderFilename
values = np.load(
os.path.join(self.basePath, self.testPath, val.placeholderFilename)