aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7e4cb1d..e60e643 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -5135,13 +5135,35 @@ class TosaTestGen:
multiplier_arr = np.int32(np.zeros(shape=[nc]))
shift_arr = np.int32(np.zeros(shape=[nc]))
+ min_shift_value_arr = np.int64(np.zeros(shape=[nc]))
+ max_shift_value_arr = np.int64(np.zeros(shape=[nc]))
for i in range(nc):
multiplier_arr[i], shift_arr[i] = TosaQuantGen.computeMultiplierAndShift(
scale_arr[i], scale32
)
+ min_shift_value_arr[i] = -1 << (shift_arr[i] - 2)
+ max_shift_value_arr[i] = (1 << (shift_arr[i] - 2)) - 1
# print('multiplier {} shift {} inzp {} outzp {}'.format(multiplier_arr, shift_arr, input_zp, output_zp))
+ if error_name is None:
+ # Make sure random values are within speicification
+ # REQUIRES(value >= (-1<<(shift-2)) && value < (1<<(shift-2))
+ assert val.placeholderFilename
+ values = np.load(
+ os.path.join(self.basePath, self.testPath, val.placeholderFilename)
+ )
+ val_adj = np.subtract(values, input_zp)
+ val_adj = np.maximum(val_adj, min_shift_value_arr, dtype=values.dtype)
+ val_adj = np.minimum(val_adj, max_shift_value_arr, dtype=values.dtype)
+ val_adj = np.add(val_adj, input_zp)
+ if not np.all(np.array_equal(values, val_adj)):
+ # Values changed so overwrite file with new values
+ np.save(
+ os.path.join(self.basePath, self.testPath, val.placeholderFilename),
+ val_adj,
+ False,
+ )
# Invalidate Input/Output list for error if checks.
input_list = [val.name]