diff options
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r-- | verif/generator/tosa_test_gen.py | 130 |
1 files changed, 68 insertions, 62 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 17cbd8f..54b624e 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -56,11 +56,9 @@ class TosaTestGen: self.random_fp_high = max(args.tensor_fp_value_range) # JSON schema validation self.descSchemaValidator = TestDescSchemaValidator() - # Data generator library when not generating the data later - if not args.lazy_data_gen: - self.dgl = GenerateLibrary(args.generate_lib_path) - else: - self.dgl = None + # Data generator library is sometimes needed for compliance set up + # even if we are generating the data later (lazy_data_generation) + self.dgl = GenerateLibrary(args.generate_lib_path) def createSerializer(self, opName, testPath): self.testPath = os.path.join(opName, testPath) @@ -108,11 +106,6 @@ class TosaTestGen: fd.write(f'const char* json_tdg_config_{path.stem} = R"(') json.dump(metaData["data_gen"], fd) fd.write(')";\n\n') - else: - # Generate the data - self.dgl.set_config(desc) - self.dgl.write_numpy_files(path) - if "compliance" in metaData: # Output datagen meta data as CPP data path_md = path / f"{testName}_meta_compliance.cpp" @@ -293,9 +286,15 @@ class TosaTestGen: low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1] ) - def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName): - if errorName or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype): - # No compliance for error tests or other data types currently + def tensorComplianceMetaData( + self, op, inputType, argsDict, outputTensor, errorName + ): + if ( + errorName + or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype) + or not gtu.dtypeIsSupportedByCompliance(inputType) + ): + # No compliance for error tests or unsupported types currently return None # Create compliance meta data for expected output tensor @@ -308,7 +307,9 @@ class TosaTestGen: mode = gtu.ComplianceMode.DOT_PRODUCT compliance_tens["dot_product_info"] = { "s": argsDict["s"], - "ks": argsDict["ks"], + "ks": int(argsDict["ksb"]) + if "ksb" in argsDict + else int(argsDict["ks"]), } elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL: mode = gtu.ComplianceMode.FP_SPECIAL @@ -741,31 +742,30 @@ class TosaTestGen: error_name, qinfo, ) - if gtu.dtypeIsSupportedByCompliance(inputs[0].dtype): - compliance = self.tensorComplianceMetaData( - op, args_dict, result_tensor, error_name - ) - else: - compliance = None + compliance = self.tensorComplianceMetaData( + op, inputs[0].dtype, args_dict, result_tensor, error_name + ) return TosaTestGen.BuildInfo(result_tensor, compliance) def build_conv2d( self, op, - ifm, - filter, - bias, - accum_dtype, - strides, - padding, - dilations, + inputs, + args_dict, validator_fcns=None, error_name=None, qinfo=None, ): + assert len(inputs) == 3 + ifm, filter, bias = inputs + accum_dtype = args_dict["acc_type"] + strides = args_dict["stride"] + padding = args_dict["pad"] + dilations = args_dict["dilation"] + assert len(padding) == 4 - result_tens = OutputShaper.conv2dOp( + result_tensor = OutputShaper.conv2dOp( self.ser, self.rng, ifm, @@ -784,12 +784,12 @@ class TosaTestGen: ): qinfo = [ TosaQuantGen.getZeroPoint(self, ifm.dtype), - TosaQuantGen.getZeroPoint(self, result_tens.dtype), + TosaQuantGen.getZeroPoint(self, result_tensor.dtype), ] # Invalidate Input/Output list for error_if checks. input_list = [ifm.name, filter.name, bias.name] - output_list = [result_tens.name] + output_list = [result_tensor.name] num_operands = sum(op["operands"]) input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList( self, error_name, input_list, output_list @@ -802,7 +802,7 @@ class TosaTestGen: op=op, input_dtype=ifm.dtype, weight_dtype=filter.dtype, - output_dtype=result_tens.dtype, + output_dtype=result_tensor.dtype, qinfo=qinfo, input_list=input_list, num_operands=num_operands, @@ -812,7 +812,7 @@ class TosaTestGen: dilation=dilations, input_shape=ifm.shape, weight_shape=filter.shape, - output_shape=result_tens.shape, + output_shape=result_tensor.shape, ): return None @@ -820,22 +820,29 @@ class TosaTestGen: attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) - return result_tens + + compliance = self.tensorComplianceMetaData( + op, ifm.dtype, args_dict, result_tensor, error_name + ) + + return TosaTestGen.BuildInfo(result_tensor, compliance) def build_conv3d( self, op, - ifm, - filter, - bias, - accum_dtype, - strides, - padding, - dilations, + inputs, + args_dict, validator_fcns=None, error_name=None, qinfo=None, ): + assert len(inputs) == 3 + ifm, filter, bias = inputs + accum_dtype = args_dict["acc_type"] + strides = args_dict["stride"] + padding = args_dict["pad"] + dilations = args_dict["dilation"] + assert len(padding) == 6 result_tens = OutputShaper.conv3dOp( self.ser, @@ -960,17 +967,19 @@ class TosaTestGen: def build_depthwise_conv2d( self, op, - ifm, - filter, - bias, - accum_dtype, - strides, - padding, - dilations, + inputs, + args_dict, validator_fcns=None, error_name=None, qinfo=None, ): + assert len(inputs) == 3 + ifm, filter, bias = inputs + accum_dtype = args_dict["acc_type"] + strides = args_dict["stride"] + padding = args_dict["pad"] + dilations = args_dict["dilation"] + result_tens = OutputShaper.depthwiseConv2dOp( self.ser, self.rng, @@ -1121,12 +1130,9 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) - if gtu.dtypeIsSupportedByCompliance(a.dtype): - compliance = self.tensorComplianceMetaData( - op, args_dict, result_tensor, error_name - ) - else: - compliance = None + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) return TosaTestGen.BuildInfo(result_tensor, compliance) @@ -1431,12 +1437,9 @@ class TosaTestGen: self.ser.addOperator(op["op"], input_list, output_list, attr) - if gtu.dtypeIsSupportedByCompliance(a.dtype): - compliance = self.tensorComplianceMetaData( - op, args_dict, result_tensor, error_name - ) - else: - compliance = None + compliance = self.tensorComplianceMetaData( + op, a.dtype, args_dict, result_tensor, error_name + ) return TosaTestGen.BuildInfo(result_tensor, compliance) @@ -2911,7 +2914,7 @@ class TosaTestGen: "build_fcn": ( build_conv2d, TosaTensorGen.tgConv2D, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agConv, ), "qgen": TosaQuantGen.qgConv, @@ -2931,6 +2934,9 @@ class TosaTestGen: TosaErrorValidator.evConvOutputShapeMismatch, TosaErrorValidator.evConvOutputShapeNonInteger, ), + "data_gen": { + "fp": (gtu.DataGenType.DOT_PRODUCT,), + }, "template": True, }, # Templated operator. Filled in by createDynamicOpLists @@ -2941,7 +2947,7 @@ class TosaTestGen: "build_fcn": ( build_conv3d, TosaTensorGen.tgConv3D, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agConv, ), "qgen": TosaQuantGen.qgConv, @@ -2972,7 +2978,7 @@ class TosaTestGen: "build_fcn": ( build_depthwise_conv2d, TosaTensorGen.tgDepthwiseConv2D, - TosaTensorValuesGen.tvgDefault, + TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agConv, ), "qgen": TosaQuantGen.qgConv, |