aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_arg_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r--verif/generator/tosa_arg_gen.py20
1 files changed, 15 insertions, 5 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 670a3e4..8d6c8d7 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -259,11 +259,15 @@ class TosaTensorGen:
@staticmethod
def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None):
+ if rank == 0:
+ # No broadcasting possible for rank 0
+ return [[]] * num_shapes
+
shape = testGen.makeShape(rng, rank)
shape_list = []
- # Choose one of the inputs to broadcast
- # Note: Simplifies OutputShaper code if we don't change first shape for errors
+ # Choose any one of the inputs to broadcast
+ # Note for ERRORS: Simplifies OutputShaper code if we don't change first shape
bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes)
fuzz_idx = rng.randInt(0, rank)
@@ -1304,10 +1308,14 @@ class TosaTensorValuesGen:
else:
# MUL with 3 inputs (3rd is shift)
tens_ser_list.append(
- testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr)
+ testGen.ser.addPlaceholder(
+ shapeList[0], dtypeList[0], a_arr.astype(np.int32)
+ )
)
tens_ser_list.append(
- testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr)
+ testGen.ser.addPlaceholder(
+ shapeList[1], dtypeList[1], b_arr.astype(np.int32)
+ )
)
tens_ser_list.append(
testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift]))
@@ -3021,7 +3029,9 @@ class TosaArgGen:
for double_round in [False, True]:
if error_name == ErrorIf.ScaleNotTrue and not double_round:
continue
- for per_channel in [False, True]:
+ # Per_channel is only valid with rank > 0
+ pc_options = (False, True) if len(shapeList[0]) > 0 else (False,)
+ for per_channel in pc_options:
if (
inDtype == DType.INT48