aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 64f0c5e..80b2981 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -6063,6 +6063,21 @@ class TosaTestGen:
self.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
)
tens.extend(placeholders)
+ elif op["op"] == Op.REDUCE_SUM and dtypeList[0] == DType.INT32:
+ assert (
+ pCount == 1 and cCount == 0
+ ), "Op.REDUCE_SUM must have 1 placeholders, 0 consts"
+ # Limit values so that the sum cannot exceed the range of an int32 during
+ # summation of any axis
+ range_val = int((1 << 31) / max(shapeList[0]))
+ values_arr = np.int32(
+ self.rng.integers(low=-range_val, high=range_val, size=shapeList[0])
+ )
+ placeholders = []
+ placeholders.append(
+ self.ser.addPlaceholder(shapeList[0], dtypeList[0], values_arr)
+ )
+ tens.extend(placeholders)
else:
tens.extend(
self.buildPlaceholderTensors(shapeList[0:pCount], dtypeList[0:pCount])