aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/generator/tosa_arg_gen.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 253e8ee..20572e8 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1937,7 +1937,12 @@ class TosaArgGen:
for a in axes:
args_dict = {"axis": int(a)}
if opid == Op.REDUCE_SUM:
- args_dict["dot_products"] = gtu.product(shape)
+ output_shape = shape.copy()
+ if error_name is None:
+ # It only matters that we calculate the dot_products correctly
+ # for non error_if tests as they should never be run
+ output_shape[a] = 1
+ args_dict["dot_products"] = gtu.product(output_shape)
args_dict["shape"] = shape
args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1
args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32