diff options
Diffstat (limited to 'verif')
-rw-r--r-- | verif/conformance/tosa_main_profile_ops_info.json | 5 | ||||
-rw-r--r-- | verif/generator/tosa_arg_gen.py | 26 | ||||
-rw-r--r-- | verif/generator/tosa_test_gen.py | 20 |
3 files changed, 41 insertions, 10 deletions
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index fb25622..2b99ed9 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -2507,6 +2507,7 @@ "profile": [ "tosa-mi" ], + "support_for": [ "lazy_data_gen" ], "generation": { "standard": { "generator_args": [ @@ -2518,13 +2519,15 @@ "--target-dtype", "bf16", "--fp-values-range", - "-2.0,2.0", + "-max,max", "--tensor-dim-range", "1,34" ], [ "--target-dtype", "fp16", + "--fp-values-range", + "-max,max", "--target-shape", "2,65527,3,1", "--target-shape", diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 1e23822..8641499 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2023, ARM Limited. +# Copyright (c) 2021-2024, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import itertools import math @@ -1274,6 +1274,30 @@ class TosaTensorValuesGen: testGen, opName, dtypeList, shapeList, argsDict, error_name ) + @staticmethod + def tvgReduceProduct( + testGen, opName, dtypeList, shapeList, argsDict, error_name=None + ): + dtype = dtypeList[0] + if error_name is None: + # Limit ranges for (non error) tests by using + # values that can be multiplied on any axis to not hit infinity + highval_lookup = { + dtype: math.pow( + TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], + 1 / max(shapeList[0]), + ) + } + data_range = TosaTensorValuesGen._get_data_range( + testGen, dtype, highval_lookup + ) + assert data_range is not None + argsDict["data_range"] = data_range + + return TosaTensorValuesGen.tvgLazyGenDefault( + testGen, opName, dtypeList, shapeList, argsDict, error_name + ) + # Set the POW exponent high data range TVG_FLOAT_HIGH_VALUE_POW_EXP = { DType.FP32: 10.0, diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 5129e24..0d072ac 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -347,6 +347,7 @@ class TosaTestGen: compliance_tens["ulp_info"] = {"ulp": op["compliance"]["ulp"]} elif op["op"] == Op.REDUCE_PRODUCT: mode = gtu.ComplianceMode.REDUCE_PRODUCT + compliance_tens["reduce_product_info"] = {"n": argsDict["n"]} elif op["op"] in (Op.EXP, Op.POW, Op.TANH, Op.SIGMOID): mode = gtu.ComplianceMode.ABS_ERROR if "compliance" in op and "abs_error_lower_bound" in op["compliance"]: @@ -1251,13 +1252,13 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) - if op["op"] == Op.REDUCE_PRODUCT: - # TODO: Add compliance support! - compliance = None - else: - compliance = self.tensorComplianceMetaData( - op, a.dtype, args_dict, result_tensor, error_name - ) + if error_name is None and op["op"] == Op.REDUCE_PRODUCT: + # Number of products - needed for compliance + args_dict["n"] = a.shape[axis] + + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) return TosaTestGen.BuildInfo(result_tensor, compliance) @@ -4066,7 +4067,7 @@ class TosaTestGen: "build_fcn": ( build_reduce, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgLazyGenDefault, + TosaTensorValuesGen.tvgReduceProduct, TosaArgGen.agAxis, ), "types": TYPE_FP, @@ -4080,6 +4081,9 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, "reduce_sum": { "op": Op.REDUCE_SUM, |