aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-03-11 09:58:24 +0000
committerEric Kunze <eric.kunze@arm.com>2024-03-13 17:11:15 +0000
commite52c0a3e952a2376e8b537517e30f43fc4f496fe (patch)
tree9a90e97dca3ee371716ceb6e562f6b218053c36e
parentaf09018205f476ab12e3ccfc25523f3f939a2aa3 (diff)
downloadreference_model-e52c0a3e952a2376e8b537517e30f43fc4f496fe.tar.gz
Fix REDUCE_SUM compliance test creation
Make sure output shape is big enough to perform statistical compliance error checking. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ia7ed7dd19a6c9cb888363f6cbdf0c6943235e0be
-rw-r--r--verif/generator/tosa_arg_gen.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 253e8ee..20572e8 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1937,7 +1937,12 @@ class TosaArgGen:
for a in axes:
args_dict = {"axis": int(a)}
if opid == Op.REDUCE_SUM:
- args_dict["dot_products"] = gtu.product(shape)
+ output_shape = shape.copy()
+ if error_name is None:
+ # It only matters that we calculate the dot_products correctly
+ # for non error_if tests as they should never be run
+ output_shape[a] = 1
+ args_dict["dot_products"] = gtu.product(output_shape)
args_dict["shape"] = shape
args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32