aboutsummaryrefslogtreecommitdiff
path: root/verif/frameworks/tosa_verif_framework_generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/frameworks/tosa_verif_framework_generator.py')
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py17
1 files changed, 6 insertions, 11 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 09a06b4..124bf6e 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -960,13 +960,10 @@ def run_unit_test(
# Get and seed a random number generator for this test
rng = np.random.default_rng(seed)
- # For broadcast fuzzing, record the fuzzed index if fuzzing is already done.
- fuzzed = []
-
# return placeholders=(str: name, np.array: value)
# consts=(str: name, np.array: value)
placeholders, consts = (
- tensor_gen_fcn(op, curr_shape, dtype, rng, fuzzed)
+ tensor_gen_fcn(op, curr_shape, dtype, rng, False)
if tensor_gen_fcn.__name__ == "tgBFuzz"
else tensor_gen_fcn(op, curr_shape, dtype, rng)
)
@@ -1122,12 +1119,10 @@ def run_unit_test(
max_val = float(qmax - qzero[idx]) * scale
else:
scale = (max_val - min_val) / float(qmax - qmin)
- zeropoint = -int(round((-min_val) / scale)) + qmin
-
- # Exit if min_val <= 0.0, in order to avoid assertion error
- # from tf.quantization.fake_quant_with_min_max_args
- if min_val > 0.0:
- return True
+ if op_name == "squared_difference":
+ zeropoint = -int(round((-min_val) / scale)) + qmin
+ else:
+ zeropoint = int(round((-min_val) / scale)) + qmin
# run through tf.fakequant first to assure quantization error aligned
fakequant_val = tf.quantization.fake_quant_with_min_max_args(
@@ -1177,7 +1172,7 @@ def run_unit_test(
def input_stats():
for i in range(0, args.num_samples):
placeholders, _ = (
- tensor_gen_fcn(op, placeholder_shapes[0], dtype, rng, fuzzed)
+ tensor_gen_fcn(op, placeholder_shapes[0], dtype, rng, True)
if tensor_gen_fcn == "tgBFuzz"
else tensor_gen_fcn(op, placeholder_shapes[0], dtype, rng)
)