aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-04-15 11:21:32 +0100
committerEric Kunze <eric.kunze@arm.com>2024-04-16 16:02:22 +0000
commite0ded5922496db4f65b8e3be397fa3b7603493c7 (patch)
tree47c3ba4e042d74b15a22b7dc96093478f36ba495
parent4a2051146f498cb9ec35d7213720540c5c3e81e2 (diff)
downloadreference_model-e0ded5922496db4f65b8e3be397fa3b7603493c7.tar.gz
Fix ARITHMETIC_RIGHT_SHIFT shift tensor type for int 8 & 16
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I91f7bc956c3b141e1518098781bbf29577c3fbbc
-rw-r--r--verif/generator/tosa_arg_gen.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 5957a33..26dd6f9 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1080,9 +1080,9 @@ class TosaTensorValuesGen:
for idx, shape in enumerate(shapeList[:]):
if idx == 1:
if dtypeList[idx] == DType.INT8:
- arr = np.int32(rng.integers(low=0, high=8, size=shape))
+ arr = np.int8(rng.integers(low=0, high=8, size=shape))
elif dtypeList[idx] == DType.INT16:
- arr = np.int32(rng.integers(low=0, high=16, size=shape))
+ arr = np.int16(rng.integers(low=0, high=16, size=shape))
elif dtypeList[idx] == DType.INT32:
arr = np.int32(rng.integers(low=0, high=32, size=shape))
elif error_name == ErrorIf.WrongInputType: