From 7b9abced233128f4128d84294a0f9d6b432a24cf Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 10 Jan 2024 11:07:29 +0000 Subject: Main Compliance testing for SELECT Signed-off-by: Jeremy Johnson Change-Id: I7276f2db39e67314c950e972cc1a97b7796dcd18 --- reference_model/src/generate/generate_utils.cc | 1 + verif/conformance/tosa_main_profile_ops_info.json | 7 +++--- verif/generator/tosa_arg_gen.py | 6 ++--- verif/generator/tosa_test_gen.py | 28 +++++++++++++++++------ 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index b94b888..23ee1dc 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -72,6 +72,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op, { Op::Op_REDUCE_PRODUCT, "REDUCE_PRODUCT" }, { Op::Op_REDUCE_SUM, "REDUCE_SUM" }, { Op::Op_SCATTER, "SCATTER" }, + { Op::Op_SELECT, "SELECT" }, { Op::Op_SIGMOID, "SIGMOID" }, { Op::Op_SUB, "SUB" }, { Op::Op_TANH, "TANH" }, diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index 6cf98ed..fa9b26b 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -2730,6 +2730,7 @@ "profile": [ "tosa-mi" ], + "support_for": [ "lazy_data_gen" ], "generation": { "standard": { "generator_args": [ @@ -2741,7 +2742,7 @@ "--target-dtype", "bf16", "--fp-values-range", - "-2.0,2.0", + "-max,max", "--tensor-dim-range", "16,64", "--target-rank", @@ -2759,7 +2760,7 @@ "--target-dtype", "bf16", "--fp-values-range", - "-2.0,2.0", + "-max,max", "--tensor-dim-range", "1,16", "--target-rank", @@ -2771,7 +2772,7 @@ "--target-dtype", "fp32", "--fp-values-range", - "-2.0,2.0", + "-max,max", "--target-shape", "1,2,65534,2,1", "--target-shape", diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 8641499..bfe7f0d 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1006,12 +1006,12 @@ class TosaTensorValuesGen: return placeholders @staticmethod - def tvgSelect(testGen, op, dtypeList, shapeList, testArgs, error_name=None): + def tvgSelect(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): # Set datatype of condition tensor to boolean dtypeList[0] = DType.BOOL - return TosaTensorValuesGen.tvgDefault( - testGen, op, dtypeList, shapeList, testArgs, error_name + return TosaTensorValuesGen.tvgLazyGenDefault( + testGen, opName, dtypeList, shapeList, argsDict, error_name ) @staticmethod diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 0d072ac..7a759e9 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -601,12 +601,19 @@ class TosaTestGen: return result_tens - def build_select(self, op, cond, a, b, validator_fcns=None, error_name=None): - result_tens = OutputShaper.selectOp(self.ser, self.rng, cond, a, b, error_name) + def build_select( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 3 + cond, a, b = inputs + + result_tensor = OutputShaper.selectOp( + self.ser, self.rng, cond, a, b, error_name + ) # Invalidate Input/Output list for error if checks. input_list = [cond.name, a.name, b.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -623,8 +630,8 @@ class TosaTestGen: input3=b, input_shape=a.shape, input_dtype=a.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -636,7 +643,11 @@ class TosaTestGen: input_list, output_list, ) - return result_tens + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_comparison( self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None @@ -3882,7 +3893,7 @@ class TosaTestGen: build_select, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgSelect, - None, + TosaArgGen.agNone, ), "types": TYPE_FIB, "error_if_validators": ( @@ -3894,6 +3905,9 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, # Comparison operators "equal": { -- cgit v1.2.1