diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-02-28 13:20:05 +0000 |
---|---|---|
committer | TatWai Chong <tatwai.chong@arm.com> | 2024-03-01 13:16:56 -0800 |
commit | 0a042997ac24fee1a338e806caf18bd8dfba28f3 (patch) | |
tree | 1cfe325d7d775b778873a3940407e68d39c80a48 /verif/generator/tosa_test_gen.py | |
parent | 3195a665e3f96809a67b4cb04a57330d2bfeb0de (diff) | |
download | reference_model-0a042997ac24fee1a338e806caf18bd8dfba28f3.tar.gz |
Testing support for MUL with shift as input
Always create the shift as a tensor for all types in testing.
In the reference model, set the shift operand to be available for
all types, but only read in the shift tensor for i32.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Signed-off-by: TatWai Chong <tatwai.chong@arm.com>
Change-Id: Ia267cbf8b63ca0a9c97b38e8fb4db83eeb8c0538
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 17 |
1 files changed, 7 insertions, 10 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index ee45f0e..b472087 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -587,9 +587,9 @@ class TosaTestGen: def build_mul( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None ): - assert len(inputs) == 2 - a, b = inputs - shift = args_dict["shift"] + # Note that mul is binary operator but it has a shift value tensor + assert len(inputs) == 3 + a, b, s = inputs result_tensor = OutputShaper.binaryBroadcastOp( self.ser, self.rng, a, b, error_name @@ -605,7 +605,7 @@ class TosaTestGen: result_tensor.setDtype(outputDType) # Invalidate Input/Output list for error if checks. - input_list = [a.name, b.name] + input_list = [a.name, b.name, s.name] output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount @@ -629,10 +629,7 @@ class TosaTestGen: ): return None - attr = ts.TosaSerializerAttribute() - attr.MulAttribute(shift) - - self.ser.addOperator(op["op"], input_list, output_list, attr) + self.ser.addOperator(op["op"], input_list, output_list) compliance = self.tensorComplianceMetaData( op, a.dtype, args_dict, result_tensor, error_name @@ -3874,10 +3871,10 @@ class TosaTestGen: }, "mul": { "op": Op.MUL, - "operands": (2, 0), + "operands": (3, 0), "build_fcn": ( build_mul, - TosaTensorGen.tgBroadcastFuzz, + TosaTensorGen.tgMul, TosaTensorValuesGen.tvgMul, TosaArgGen.agMul, ), |