From 708da823504b9a7f4e2ffc10e00f06bb092ce637 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Wed, 15 Nov 2023 16:25:45 +0000 Subject: Main Compliance testing support for CAST Limit CAST input tensor to maximums of output type to avoid saturation and infinity. Signed-off-by: Jeremy Johnson Change-Id: I33350a4ce0ec828da7d2e7aa8cd3183a89a97431 --- reference_model/src/generate/generate_utils.cc | 1 + verif/conformance/tosa_main_profile_ops_info.json | 7 +++-- verif/generator/tosa_arg_gen.py | 34 +++++++++++++++++++- verif/generator/tosa_test_gen.py | 38 +++++++++++++++++------ 4 files changed, 68 insertions(+), 12 deletions(-) diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index 58a3d33..d1758fc 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -42,6 +42,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op, { Op::Op_ADD, "ADD" }, { Op::Op_ARGMAX, "ARGMAX" }, { Op::Op_AVG_POOL2D, "AVG_POOL2D" }, + { Op::Op_CAST, "CAST" }, { Op::Op_CEIL, "CEIL" }, { Op::Op_CLAMP, "CLAMP" }, { Op::Op_CONV2D, "CONV2D" }, diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index 0b20e4f..960ad27 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -253,6 +253,7 @@ "profile": [ "tosa-mi" ], + "support_for": [ "lazy_data_gen" ], "generation": { "standard": { "negative_dim_range": "1,10", @@ -271,7 +272,7 @@ "--target-dtype", "int32", "--fp-values-range", - "-2.0,2.0", + "-max,max", "--tensor-dim-range", "16,64", "--target-rank", @@ -295,7 +296,7 @@ "--target-dtype", "int32", "--fp-values-range", - "-2.0,2.0", + "-max,max", "--tensor-dim-range", "1,16", "--target-rank", @@ -306,6 +307,8 @@ [ "--target-dtype", "fp16", + "--fp-values-range", + "-max,max", "--target-shape", "1,1,1,65533,1", "--target-shape", diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 3057963..c557207 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -1434,6 +1434,27 @@ class TosaTensorValuesGen: testGen, opName, dtypeList, shapeList, argsDict, error_name ) + @staticmethod + def tvgCast(testGen, opName, dtypeList, shapeList, argsDict, error_name=None): + in_dtype = dtypeList[0] + out_dtype = argsDict["out_type"] + # Create look up to limit input tensor to output type maximums to avoid + # FP infinities and saturation of integers + out_range = testGen.getDTypeRange(out_dtype, high_inclusive=True) + highval_lookup = {in_dtype: out_range[1]} + data_range = TosaTensorValuesGen._get_data_range( + testGen, + in_dtype, + highval_lookup, + ) + + assert data_range is not None + argsDict["data_range"] = data_range + + return TosaTensorValuesGen.tvgLazyGenDefault( + testGen, opName, dtypeList, shapeList, argsDict, error_name + ) + class TosaArgGen: """Argument generators create exhaustive or random lists of attributes for @@ -2350,7 +2371,18 @@ class TosaArgGen: raise Exception("Unexpected input dtype: {}".format(inDtype)) for dtype in dtypeList: - arg_list.append(("out{}".format(testGen.typeStr(dtype)), [dtype])) + arg_list.append( + ("out{}".format(testGen.typeStr(dtype)), {"out_type": dtype}) + ) + + # Now add data generator types + arg_list = TosaArgGen._add_data_generators( + testGen, + opName, + dtype, + arg_list, + error_name, + ) return arg_list diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 63958a9..1602109 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -307,10 +307,15 @@ class TosaTestGen: def tensorComplianceMetaData( self, op, inputType, argsDict, outputTensor, errorName ): + # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet + UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED) if ( errorName or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype) - or not gtu.dtypeIsSupportedByCompliance(inputType) + or ( + not gtu.dtypeIsSupportedByCompliance(inputType) + and op["op"] in UNSUPPORTED_NON_FP32_INPUT_OPS + ) ): # No compliance for error tests or unsupported types currently return None @@ -1874,14 +1879,20 @@ class TosaTestGen: return val # Type Conversion - def build_cast(self, op, val, out_dtype, validator_fcns=None, error_name=None): - result_tens = OutputShaper.typeConversionOp( + def build_cast( + self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None + ): + assert len(inputs) == 1 + val = inputs[0] + out_dtype = args_dict["out_type"] + + result_tensor = OutputShaper.typeConversionOp( self.ser, self.rng, val, out_dtype, error_name ) # Invalidate Input/Output list for error if checks. input_list = [val.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] pCount, cCount = op["operands"] num_operands = pCount + cCount input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( @@ -1894,10 +1905,10 @@ class TosaTestGen: error_name, op=op, input_shape=val.shape, - output_shape=result_tens.shape, + output_shape=result_tensor.shape, input_dtype=val.dtype, - output_dtype=result_tens.dtype, - result_tensors=[result_tens], + output_dtype=result_tensor.dtype, + result_tensors=[result_tensor], input_list=input_list, output_list=output_list, num_operands=num_operands, @@ -1905,7 +1916,12 @@ class TosaTestGen: return None self.ser.addOperator(op["op"], input_list, output_list) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, val.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_rescale( self, @@ -4365,7 +4381,7 @@ class TosaTestGen: "build_fcn": ( build_cast, TosaTensorGen.tgBasic, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgCast, TosaArgGen.agCast, ), "types": ( @@ -4383,6 +4399,10 @@ class TosaTestGen: TosaErrorValidator.evWrongInputList, TosaErrorValidator.evWrongOutputList, ), + "data_gen": { + "fp": (gtu.DataGenType.PSEUDO_RANDOM,), + }, + "compliance": {"ulp": 0.5}, }, "rescale": { "op": Op.RESCALE, -- cgit v1.2.1