aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWon Jeon <won.jeon@arm.com>2023-06-29 23:20:00 +0000
committerEric Kunze <eric.kunze@arm.com>2023-07-06 21:36:53 +0000
commitdd14c1b2ba22f4b37be9ec1a9a7d61741d36506e (patch)
tree74d772c12768e253d3d79c619cdaf7cb7922af14
parentd69e2838b2e2f46401c5da19f662c9e0d5a5df06 (diff)
downloadreference_model-dd14c1b2ba22f4b37be9ec1a9a7d61741d36506e.tar.gz
Add reference model framework test for INT8 squared difference
Signed-off-by: Won Jeon <won.jeon@arm.com> Change-Id: I6fba6907cef0616c18dc461dbb92d2aceb582f6c
-rwxr-xr-xverif/frameworks/tosa_verif_framework_generator.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/verif/frameworks/tosa_verif_framework_generator.py b/verif/frameworks/tosa_verif_framework_generator.py
index 94a1a15..09a06b4 100755
--- a/verif/frameworks/tosa_verif_framework_generator.py
+++ b/verif/frameworks/tosa_verif_framework_generator.py
@@ -369,7 +369,10 @@ TF_OP_LIST = {
"squared_difference": {
"operands": (2, 0),
"build_fcn": (TBuilder.SquaredDifference, TGen.tgBFuzz, ArgGen.agNone),
- "types": TYPE_F,
+ "types": {
+ "tf": TYPE_F,
+ "tflite": list(TYPE_FI + [QuantType.ALL_I8]),
+ },
},
"equal": {
"operands": (2, 0),
@@ -1119,7 +1122,12 @@ def run_unit_test(
max_val = float(qmax - qzero[idx]) * scale
else:
scale = (max_val - min_val) / float(qmax - qmin)
- zeropoint = int(round((-min_val) / scale)) + qmin
+ zeropoint = -int(round((-min_val) / scale)) + qmin
+
+ # Exit if min_val <= 0.0, in order to avoid assertion error
+ # from tf.quantization.fake_quant_with_min_max_args
+ if min_val > 0.0:
+ return True
# run through tf.fakequant first to assure quantization error aligned
fakequant_val = tf.quantization.fake_quant_with_min_max_args(