diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index cb97acb..c04b585 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -5792,15 +5792,23 @@ class TosaTestGen: ) tens.extend(placeholders) - elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[ - 0 - ] == DType.INT32: + elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] in ( + DType.INT32, + DType.INT16, + DType.INT8, + ): # Limit input tensors with cond_if_binary or while_loop to stop - # saturation of add/sub ops + # saturation of add/sub ops with int32 and keep all logical shift + # values between 0 to 31 for int16 or int8 pRemain = pCount placeholders = [] for idx, shape in enumerate(shapeList[:]): - arr = self.getRandTensor(shapeList[idx], DType.INT16) + if dtypeList[0] == DType.INT32: + arr = self.getRandTensor(shapeList[idx], DType.INT16) + else: + arr = np.int32( + self.rng.integers(low=0, high=32, size=shapeList[idx]) + ) if pRemain > 0: placeholders.append( self.ser.addPlaceholder(shape, dtypeList[idx], arr) |