diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-03-11 09:58:24 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-03-13 17:11:15 +0000 |
commit | e52c0a3e952a2376e8b537517e30f43fc4f496fe (patch) | |
tree | 9a90e97dca3ee371716ceb6e562f6b218053c36e | |
parent | af09018205f476ab12e3ccfc25523f3f939a2aa3 (diff) | |
download | reference_model-e52c0a3e952a2376e8b537517e30f43fc4f496fe.tar.gz |
Fix REDUCE_SUM compliance test creation
Make sure output shape is big enough to perform statistical
compliance error checking.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ia7ed7dd19a6c9cb888363f6cbdf0c6943235e0be
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 253e8ee..20572e8 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1937,7 +1937,12 @@ class TosaArgGen: for a in axes: args_dict = {"axis": int(a)} if opid == Op.REDUCE_SUM: - args_dict["dot_products"] = gtu.product(shape) + output_shape = shape.copy() + if error_name is None: + # It only matters that we calculate the dot_products correctly + # for non error_if tests as they should never be run + output_shape[a] = 1 + args_dict["dot_products"] = gtu.product(output_shape) args_dict["shape"] = shape args_dict["ks"] = int(shape[a]) if a >= 0 and a < len(shape) else 1 args_dict["acc_type"] = dtype if dtype != DType.BF16 else DType.FP32 |