aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py130
1 files changed, 68 insertions, 62 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 17cbd8f..54b624e 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -56,11 +56,9 @@ class TosaTestGen:
self.random_fp_high = max(args.tensor_fp_value_range)
# JSON schema validation
self.descSchemaValidator = TestDescSchemaValidator()
- # Data generator library when not generating the data later
- if not args.lazy_data_gen:
- self.dgl = GenerateLibrary(args.generate_lib_path)
- else:
- self.dgl = None
+ # Data generator library is sometimes needed for compliance set up
+ # even if we are generating the data later (lazy_data_generation)
+ self.dgl = GenerateLibrary(args.generate_lib_path)
def createSerializer(self, opName, testPath):
self.testPath = os.path.join(opName, testPath)
@@ -108,11 +106,6 @@ class TosaTestGen:
fd.write(f'const char* json_tdg_config_{path.stem} = R"(')
json.dump(metaData["data_gen"], fd)
fd.write(')";\n\n')
- else:
- # Generate the data
- self.dgl.set_config(desc)
- self.dgl.write_numpy_files(path)
-
if "compliance" in metaData:
# Output datagen meta data as CPP data
path_md = path / f"{testName}_meta_compliance.cpp"
@@ -293,9 +286,15 @@ class TosaTestGen:
low=self.args.tensor_shape_range[0], high=self.args.tensor_shape_range[1]
)
- def tensorComplianceMetaData(self, op, argsDict, outputTensor, errorName):
- if errorName or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype):
- # No compliance for error tests or other data types currently
+ def tensorComplianceMetaData(
+ self, op, inputType, argsDict, outputTensor, errorName
+ ):
+ if (
+ errorName
+ or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
+ or not gtu.dtypeIsSupportedByCompliance(inputType)
+ ):
+ # No compliance for error tests or unsupported types currently
return None
# Create compliance meta data for expected output tensor
@@ -308,7 +307,9 @@ class TosaTestGen:
mode = gtu.ComplianceMode.DOT_PRODUCT
compliance_tens["dot_product_info"] = {
"s": argsDict["s"],
- "ks": argsDict["ks"],
+ "ks": int(argsDict["ksb"])
+ if "ksb" in argsDict
+ else int(argsDict["ks"]),
}
elif argsDict["dg_type"] == gtu.DataGenType.OP_SPECIAL:
mode = gtu.ComplianceMode.FP_SPECIAL
@@ -741,31 +742,30 @@ class TosaTestGen:
error_name,
qinfo,
)
- if gtu.dtypeIsSupportedByCompliance(inputs[0].dtype):
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
- else:
- compliance = None
+ compliance = self.tensorComplianceMetaData(
+ op, inputs[0].dtype, args_dict, result_tensor, error_name
+ )
return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_conv2d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- strides,
- padding,
- dilations,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
+ dilations = args_dict["dilation"]
+
assert len(padding) == 4
- result_tens = OutputShaper.conv2dOp(
+ result_tensor = OutputShaper.conv2dOp(
self.ser,
self.rng,
ifm,
@@ -784,12 +784,12 @@ class TosaTestGen:
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
- TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
@@ -802,7 +802,7 @@ class TosaTestGen:
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
- output_dtype=result_tens.dtype,
+ output_dtype=result_tensor.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
@@ -812,7 +812,7 @@ class TosaTestGen:
dilation=dilations,
input_shape=ifm.shape,
weight_shape=filter.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
):
return None
@@ -820,22 +820,29 @@ class TosaTestGen:
attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, ifm.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_conv3d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- strides,
- padding,
- dilations,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
+ dilations = args_dict["dilation"]
+
assert len(padding) == 6
result_tens = OutputShaper.conv3dOp(
self.ser,
@@ -960,17 +967,19 @@ class TosaTestGen:
def build_depthwise_conv2d(
self,
op,
- ifm,
- filter,
- bias,
- accum_dtype,
- strides,
- padding,
- dilations,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 3
+ ifm, filter, bias = inputs
+ accum_dtype = args_dict["acc_type"]
+ strides = args_dict["stride"]
+ padding = args_dict["pad"]
+ dilations = args_dict["dilation"]
+
result_tens = OutputShaper.depthwiseConv2dOp(
self.ser,
self.rng,
@@ -1121,12 +1130,9 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
- if gtu.dtypeIsSupportedByCompliance(a.dtype):
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
- else:
- compliance = None
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
return TosaTestGen.BuildInfo(result_tensor, compliance)
@@ -1431,12 +1437,9 @@ class TosaTestGen:
self.ser.addOperator(op["op"], input_list, output_list, attr)
- if gtu.dtypeIsSupportedByCompliance(a.dtype):
- compliance = self.tensorComplianceMetaData(
- op, args_dict, result_tensor, error_name
- )
- else:
- compliance = None
+ compliance = self.tensorComplianceMetaData(
+ op, a.dtype, args_dict, result_tensor, error_name
+ )
return TosaTestGen.BuildInfo(result_tensor, compliance)
@@ -2911,7 +2914,7 @@ class TosaTestGen:
"build_fcn": (
build_conv2d,
TosaTensorGen.tgConv2D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
@@ -2931,6 +2934,9 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists
@@ -2941,7 +2947,7 @@ class TosaTestGen:
"build_fcn": (
build_conv3d,
TosaTensorGen.tgConv3D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,
@@ -2972,7 +2978,7 @@ class TosaTestGen:
"build_fcn": (
build_depthwise_conv2d,
TosaTensorGen.tgDepthwiseConv2D,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agConv,
),
"qgen": TosaQuantGen.qgConv,