diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-03-28 15:53:21 +0000 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-04-11 15:02:57 +0100 |
commit | 18a379d99ad10002b3cf6eda086457179221cc22 (patch) | |
tree | 9b90a31f846035236cbecb9cde379dee66b6f0c3 /verif/generator/tosa_arg_gen.py | |
parent | 3f3de01fa87246161e47c15fd6c44f710b86f3e7 (diff) | |
download | reference_model-18a379d99ad10002b3cf6eda086457179221cc22.tar.gz |
Add rank 0 testing support
Default test range is now rank 0 to 3 instead of 1 to 4
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ibde66b60b58de9f4a3852a3807c01f8dae61206f
Diffstat (limited to 'verif/generator/tosa_arg_gen.py')
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 670a3e4..8d6c8d7 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -259,11 +259,15 @@ class TosaTensorGen: @staticmethod def _get_broadcast_shapes(testGen, rng, num_shapes, rank, error_name=None): + if rank == 0: + # No broadcasting possible for rank 0 + return [[]] * num_shapes + shape = testGen.makeShape(rng, rank) shape_list = [] - # Choose one of the inputs to broadcast - # Note: Simplifies OutputShaper code if we don't change first shape for errors + # Choose any one of the inputs to broadcast + # Note for ERRORS: Simplifies OutputShaper code if we don't change first shape bcast_idx = rng.randInt(0 if error_name is None else 1, num_shapes) fuzz_idx = rng.randInt(0, rank) @@ -1304,10 +1308,14 @@ class TosaTensorValuesGen: else: # MUL with 3 inputs (3rd is shift) tens_ser_list.append( - testGen.ser.addPlaceholder(shapeList[0], dtypeList[0], a_arr) + testGen.ser.addPlaceholder( + shapeList[0], dtypeList[0], a_arr.astype(np.int32) + ) ) tens_ser_list.append( - testGen.ser.addPlaceholder(shapeList[1], dtypeList[1], b_arr) + testGen.ser.addPlaceholder( + shapeList[1], dtypeList[1], b_arr.astype(np.int32) + ) ) tens_ser_list.append( testGen.ser.addPlaceholder([1], DType.INT8, np.int8([shift])) @@ -3021,7 +3029,9 @@ class TosaArgGen: for double_round in [False, True]: if error_name == ErrorIf.ScaleNotTrue and not double_round: continue - for per_channel in [False, True]: + # Per_channel is only valid with rank > 0 + pc_options = (False, True) if len(shapeList[0]) > 0 else (False,) + for per_channel in pc_options: if ( inDtype == DType.INT48 |