aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-03-28 15:53:21 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-04-11 15:02:57 +0100
commit18a379d99ad10002b3cf6eda086457179221cc22 (patch)
tree9b90a31f846035236cbecb9cde379dee66b6f0c3 /verif/generator/tosa_arg_gen.py
parent3f3de01fa87246161e47c15fd6c44f710b86f3e7 (diff)
downloadreference_model-18a379d99ad10002b3cf6eda086457179221cc22.tar.gz
Add rank 0 testing support
Default test range is now rank 0 to 3 instead of 1 to 4 Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ibde66b60b58de9f4a3852a3807c01f8dae61206f
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py20
1 files changed, 15 insertions, 5 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 670a3e4..8d6c8d7 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -259,11 +259,15 @@ class TosaTensorGen:
@staticmethod
def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
+ if rank == 0:
+ # No broadcasting possible for rank 0
+ return [[]] * num_shapes
+
shape = testGen.makeShape(rng, rank)
shape_list = []
- # Choose one of the inputs to broadcast
- # Note: Simplifies OutputShaper code if we don't change first shape for errors
+ # Choose any one of the inputs to broadcast
+ # Note for ERRORS: Simplifies OutputShaper code if we don't change first shape
bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
fuzz_idx = rng.randInt(0, rank)
@@ -1304,10 +1308,14 @@ class TosaTensorValuesGen:
else:
# MUL with 3 inputs (3rd is shift)
tens_ser_list.append(
- testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ testGen.ser.addPlaceholder(
+ shapeList[0], dtypeList[0], a_arr.astype(np.int32)
+ )
)
tens_ser_list.append(
- testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+ testGen.ser.addPlaceholder(
+ shapeList[1], dtypeList[1], b_arr.astype(np.int32)
+ )
)
tens_ser_list.append(
testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
@@ -3021,7 +3029,9 @@ class TosaArgGen:
for double_round in [False, True]:
if error_name == ErrorIf.ScaleNotTrue and not double_round:
continue
- for per_channel in [False, True]:
+ # Per_channel is only valid with rank > 0
+ pc_options = (False, True) if len(shapeList[0]) > 0 else (False,)
+ for per_channel in pc_options:
if (
inDtype == DType.INT48