diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-01-03 17:07:44 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-01-08 21:40:41 +0000 |
commit | bd801960c958db85ae4092d1350ffbd383c3f77c (patch) | |
tree | e3fa9e3d2a817b75a4c13b663b46e776a3c766e0 /verif/generator/tosa_test_gen.py | |
parent | d80ea5e11e5f92e0f7c08afeba74cb7d1719987b (diff) | |
download | reference_model-bd801960c958db85ae4092d1350ffbd383c3f77c.tar.gz |
Main Compliance: REDUCE_PRODUCT support
Update and fix REDUCE_PRODUCT compliance verify lib support.
Added compliance test generation with data range to not cause infs.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I3b3004c6caa80d97e330a6393f435f5270b56e21
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 5129e24..0d072ac 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -347,6 +347,7 @@ class TosaTestGen: compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]} elif op["op"] == Op.REDUCE_PRODUCT: mode = gtu.ComplianceMode.REDUCE_PRODUCT + compliance_tens["reduce_product_info"] = {"n": argsDict["n"]} elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID): mode = gtu.ComplianceMode.ABS_ERROR if "compliance" in op and "abs_error_lower_bound" in op["compliance"]: @@ -1251,13 +1252,13 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) - if op["op"] == Op.REDUCE_PRODUCT: - # TODO: Add compliance support! - compliance = None - else: - compliance = self.tensorComplianceMetaData( - op, a.dtype, args_dict, result_tensor, error_name - ) + if error_name is None and op["op"] == Op.REDUCE_PRODUCT: + # Number of products - needed for compliance + args_dict["n"] = a.shape[axis] + + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) return TosaTestGen.BuildInfo(result_tensor, compliance) @@ -4066,7 +4067,7 @@ class TosaTestGen: "build_fcn": ( build_reduce, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgLazyGenDefault, + TosaTensorValuesGen.tvgReduceProduct, TosaArgGen.agAxis, ), "types": TYPE_FP, @@ -4080,6 +4081,9 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, "reduce_sum": { "op": Op.REDUCE_SUM, |