aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-02-28 13:20:05 +0000
committerTatWai Chong <tatwai.chong@arm.com>2024-03-01 13:16:56 -0800
commit0a042997ac24fee1a338e806caf18bd8dfba28f3 (patch)
tree1cfe325d7d775b778873a3940407e68d39c80a48 /verif/generator/tosa_test_gen.py
parent3195a665e3f96809a67b4cb04a57330d2bfeb0de (diff)
downloadreference_model-0a042997ac24fee1a338e806caf18bd8dfba28f3.tar.gz
Testing support for MUL with shift as input
Always create the shift as a tensor for all types in testing. In the reference model, set the shift operand to be available for all types, but only read in the shift tensor for i32. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Signed-off-by: TatWai Chong <tatwai.chong@arm.com> Change-Id: Ia267cbf8b63ca0a9c97b38e8fb4db83eeb8c0538
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py17
1 files changed, 7 insertions, 10 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index ee45f0e..b472087 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -587,9 +587,9 @@ class TosaTestGen:
def build_mul(
self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
- assert len(inputs) == 2
- a, b = inputs
- shift = args_dict["shift"]
+ # Note that mul is binary operator but it has a shift value tensor
+ assert len(inputs) == 3
+ a, b, s = inputs
result_tensor = OutputShaper.binaryBroadcastOp(
self.ser, self.rng, a, b, error_name
@@ -605,7 +605,7 @@ class TosaTestGen:
result_tensor.setDtype(outputDType)
# Invalidate Input/Output list for error if checks.
- input_list = [a.name, b.name]
+ input_list = [a.name, b.name, s.name]
output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
@@ -629,10 +629,7 @@ class TosaTestGen:
):
return None
- attr = ts.TosaSerializerAttribute()
- attr.MulAttribute(shift)
-
- self.ser.addOperator(op["op"], input_list, output_list, attr)
+ self.ser.addOperator(op["op"], input_list, output_list)
compliance = self.tensorComplianceMetaData(
op, a.dtype, args_dict, result_tensor, error_name
@@ -3874,10 +3871,10 @@ class TosaTestGen:
},
"mul": {
"op": Op.MUL,
- "operands": (2, 0),
+ "operands": (3, 0),
"build_fcn": (
build_mul,
- TosaTensorGen.tgBroadcastFuzz,
+ TosaTensorGen.tgMul,
TosaTensorValuesGen.tvgMul,
TosaArgGen.agMul,
),