aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-01-24 12:24:21 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2022-01-24 12:24:27 +0000
commit0c171ac4de49c895c3aa7da2e394f9572ee49888 (patch)
treed90a068dc1fa14e36adf9af31c828c9fc7ec8889
parent66bad80a98307246f94f8b69d2a62f4649e71455 (diff)
downloadreference_model-0c171ac4de49c895c3aa7da2e394f9572ee49888.tar.gz
Fix COND_IF binary INT8/16 test generation
Limit input values to allowed for logical shift operations. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I78110c449274ab96a3f824890c3f03a0eeb345eb
-rw-r--r--verif/generator/tosa_test_gen.py18
1 files changed, 13 insertions, 5 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index cb97acb..c04b585 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -5792,15 +5792,23 @@ class TosaTestGen:
)
tens.extend(placeholders)
- elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[
- 0
- ] == DType.INT32:
+ elif (op["op"] == Op.COND_IF or op["op"] == Op.WHILE_LOOP) and dtypeList[0] in (
+ DType.INT32,
+ DType.INT16,
+ DType.INT8,
+ ):
# Limit input tensors with cond_if_binary or while_loop to stop
- # saturation of add/sub ops
+ # saturation of add/sub ops with int32 and keep all logical shift
+ # values between 0 to 31 for int16 or int8
pRemain = pCount
placeholders = []
for idx, shape in enumerate(shapeList[:]):
- arr = self.getRandTensor(shapeList[idx], DType.INT16)
+ if dtypeList[0] == DType.INT32:
+ arr = self.getRandTensor(shapeList[idx], DType.INT16)
+ else:
+ arr = np.int32(
+ self.rng.integers(low=0, high=32, size=shapeList[idx])
+ )
if pRemain > 0:
placeholders.append(
self.ser.addPlaceholder(shape, dtypeList[idx], arr)