diff options
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 670a3e4..8d6c8d7 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -259,11 +259,15 @@ class TosaTensorGen: @staticmethod def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None): + if rank == 0: + # No broadcasting possible for rank 0 + return [[]] * num_shapes + shape = testGen.makeShape(rng, rank) shape_list = [] - # Choose one of the inputs to broadcast - # Note: Simplifies OutputShaper code if we don't change first shape for errors + # Choose any one of the inputs to broadcast + # Note for ERRORS: Simplifies OutputShaper code if we don't change first shape bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes) fuzz_idx = rng.randInt(0, rank) @@ -1304,10 +1308,14 @@ class TosaTensorValuesGen: else: # MUL with 3 inputs (3rd is shift) tens_ser_list.append( - testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) + testGen.ser.addPlaceholder( + shapeList[0], dtypeList[0], a_arr.astype(np.int32) + ) ) tens_ser_list.append( - testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr) + testGen.ser.addPlaceholder( + shapeList[1], dtypeList[1], b_arr.astype(np.int32) + ) ) tens_ser_list.append( testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift])) @@ -3021,7 +3029,9 @@ class TosaArgGen: for double_round in [False, True]: if error_name == ErrorIf.ScaleNotTrue and not double_round: continue - for per_channel in [False, True]: + # Per_channel is only valid with rank > 0 + pc_options = (False, True) if len(shapeList[0]) > 0 else (False,) + for per_channel in pc_options: if ( inDtype == DType.INT48 |