aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2021-10-20 15:51:11 +0100
committerEric Kunze <eric.kunze@arm.com>2021-10-25 18:23:53 +0000
commit8c06a6547a132f0a22fa34d467026f12fabb4e1f (patch)
treef6b74aa39d18200ccbdfd17f618ff77c3a23f11b
parent6e5286674c204ebf829d8dc1ddc6606ce8c73aff (diff)
downloadreference_model-8c06a6547a132f0a22fa34d467026f12fabb4e1f.tar.gz
Limit tensor values for COND_IF/WHILE_LOOP tests to stop saturation
Change-Id: Idb36b1f1c0d78ec101c168865a9c8d03221b4c84 Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
-rw-r--r--verif/tosa_test_gen.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/verif/tosa_test_gen.py b/verif/tosa_test_gen.py
index 04fce90..105f016 100644
--- a/verif/tosa_test_gen.py
+++ b/verif/tosa_test_gen.py
@@ -4031,6 +4031,20 @@ class TosaTestGen:
)
tens.extend(placeholders)
+ elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] == DType.INT32:
+ # Limit input tensors with cond_if_binary or while_loop to stop
+ # saturation of add/sub ops
+ pRemain = pCount
+ placeholders = []
+ for idx, shape in enumerate(shapeList[:]):
+ arr = self.getRandTensor(shapeList[idx], DType.INT16)
+ if pRemain > 0:
+ placeholders.append(self.ser.addPlaceholder(shape, dtypeList[idx], arr))
+ pRemain -= 1
+ else:
+ placeholders.append(self.ser.addConst(shape, dtypeList[idx], arr))
+
+ tens.extend(placeholders)
elif op["op"] == Op.ARITHMETIC_RIGHT_SHIFT:
# Force value of operand[1] to be within [0, num_bits]
assert (