aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--verif/generator/tosa_test_gen.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index cb97acb..c04b585 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -5792,15 +5792,23 @@ class TosaTestGen:
)
tens.extend(placeholders)
- elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[
- 0
- ] == DType.INT32:
+ elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] in (
+ DType.INT32,
+ DType.INT16,
+ DType.INT8,
+ ):
# Limit input tensors with cond_if_binary or while_loop to stop
- # saturation of add/sub ops
+ # saturation of add/sub ops with int32 and keep all logical shift
+ # values between 0 to 31 for int16 or int8
pRemain = pCount
placeholders = []
for idx, shape in enumerate(shapeList[:]):
- arr = self.getRandTensor(shapeList[idx], DType.INT16)
+ if dtypeList[0] == DType.INT32:
+ arr = self.getRandTensor(shapeList[idx], DType.INT16)
+ else:
+ arr = np.int32(
+ self.rng.integers(low=0, high=32, size=shapeList[idx])
+ )
if pRemain > 0:
placeholders.append(
self.ser.addPlaceholder(shape, dtypeList[idx], arr)