aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/tosa_test_gen.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 04fce90..105f016 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -4031,6 +4031,20 @@ class TosaTestGen:
)
tens.extend(placeholders)
+ elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] == DType.INT32:
+ # Limit input tensors with cond_if_binary or while_loop to stop
+ # saturation of add/sub ops
+ pRemain = pCount
+ placeholders = []
+ for idx, shape in enumerate(shapeList[:]):
+ arr = self.getRandTensor(shapeList[idx], DType.INT16)
+ if pRemain > 0:
+ placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
+ pRemain -= 1
+ else:
+ placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))
+
+ tens.extend(placeholders)
elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
# Force value of operand[1] to be within [0, num_bits]
assert (