diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 42 |
1 files changed, 30 insertions, 12 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 04093b8..7b44ced 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -622,14 +622,19 @@ class TosaTestGen: ) return result_tens - def build_comparison(self, op, a, b, validator_fcns=None, error_name=None): - result_tens = OutputShaper.binaryComparisonOp( + def build_comparison( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 2 + a, b = inputs + + result_tensor = OutputShaper.binaryComparisonOp( self.ser, self.rng, a, b, error_name ) # Invalidate Input/Output list for error if checks. input_list = [a.name, b.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -645,9 +650,9 @@ class TosaTestGen: input2=b, input_shape=a.shape, input_dtype=a.dtype, - output_shape=result_tens.shape, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], + output_shape=result_tensor.shape, + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -659,7 +664,11 @@ class TosaTestGen: input_list, output_list, ) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_argmax( self, op, inputs, args_dict, validator_fcns, error_name, qinfo=None @@ -3863,7 +3872,7 @@ class TosaTestGen: build_comparison, TosaTensorGen.tgBroadcastFuzz, TosaTensorValuesGen.tvgEqual, - None, + TosaArgGen.agNone, ), "types": TYPE_FI32, "error_if_validators": ( @@ -3875,6 +3884,9 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, "greater_equal": { "op": Op.GREATER_EQUAL, @@ -3882,8 +3894,8 @@ class TosaTestGen: "build_fcn": ( build_comparison, TosaTensorGen.tgBroadcastFuzz, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FI32, "error_if_validators": ( @@ -3895,6 +3907,9 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, "greater": { "op": Op.GREATER, @@ -3902,8 +3917,8 @@ class TosaTestGen: "build_fcn": ( build_comparison, TosaTensorGen.tgBroadcastFuzz, - TosaTensorValuesGen.tvgDefault, - None, + TosaTensorValuesGen.tvgLazyGenDefault, + TosaArgGen.agNone, ), "types": TYPE_FI32, "error_if_validators": ( @@ -3915,6 +3930,9 @@ class TosaTestGen: TosaErrorValidator.evDimensionMismatch, TosaErrorValidator.evBroadcastShapesMismatch, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, }, # Reduction operators "reduce_all": { |