aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py163
1 files changed, 101 insertions, 62 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 8fcea29..17cbd8f 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -658,15 +658,22 @@ class TosaTestGen:
def build_pool2d(
self,
op,
- input,
- accum_dtype,
- stride,
- pad,
- kernel,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
+ assert len(inputs) == 1
+ input = inputs[0]
+ # max_pool has no accum_dtype
+ accum_dtype = (
+ args_dict["acc_type"] if "acc_type" in args_dict else DType.UNKNOWN
+ )
+ stride = args_dict["stride"]
+ pad = args_dict["pad"]
+ kernel = args_dict["kernel"]
+
result_tens = OutputShaper.pool2dOp(
self.ser, self.rng, input, kernel, stride, pad, error_name
)
@@ -720,27 +727,28 @@ class TosaTestGen:
def build_maxpool2d(
self,
op,
- input,
- stride,
- pad,
- kernel,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
- # Same as build_pool2d but manually sets accum_dtype value
- # (maxpool has no accum_dtype)
- return self.build_pool2d(
+ result_tensor = self.build_pool2d(
op,
- input,
- DType.UNKNOWN,
- stride,
- pad,
- kernel,
+ inputs,
+ args_dict,
validator_fcns,
error_name,
qinfo,
)
+ if gtu.dtypeIsSupportedByCompliance(inputs[0].dtype):
+ compliance = self.tensorComplianceMetaData(
+ op, args_dict, result_tensor, error_name
+ )
+ else:
+ compliance = None
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_conv2d(
self,
@@ -1070,8 +1078,10 @@ class TosaTestGen:
return result_tens
def build_matmul(
- self, op, a, b, args_dict, validator_fcns=None, error_name=None, qinfo=None
+ self, op, inputs, args_dict, validator_fcns=None, error_name=None, qinfo=None
):
+ assert len(inputs) == 2
+ a, b = inputs
accum_dtype = args_dict["acc_type"]
result_tensor = OutputShaper.matmulOp(
self.ser, self.rng, a, b, accum_dtype, error_name
@@ -1372,15 +1382,19 @@ class TosaTestGen:
def build_pad(
self,
op,
- a,
- padding,
- pad_const_int,
- pad_const_float,
+ inputs,
+ args_dict,
validator_fcns=None,
error_name=None,
qinfo=None,
):
- result_tens = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
+ assert len(inputs) == 1
+ a = inputs[0]
+ padding = args_dict["pad"]
+ pad_const_int = args_dict["pad_const_int"]
+ pad_const_float = args_dict["pad_const_fp"]
+
+ result_tensor = OutputShaper.padOp(self.ser, self.rng, a, padding, error_name)
attr = ts.TosaSerializerAttribute()
attr.PadAttribute(
@@ -1389,7 +1403,7 @@ class TosaTestGen:
# Invalidate Input/Output list for error if checks.
input_list = [a.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
pCount, cCount = op["operands"]
num_operands = pCount + cCount
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
@@ -1402,12 +1416,12 @@ class TosaTestGen:
error_name,
op=op,
input_shape=a.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
input_dtype=a.dtype,
- output_dtype=result_tens.dtype,
+ output_dtype=result_tensor.dtype,
pad=padding,
qinfo=qinfo,
- result_tensors=[result_tens],
+ result_tensors=[result_tensor],
input_list=input_list,
output_list=output_list,
num_operands=num_operands,
@@ -1416,7 +1430,15 @@ class TosaTestGen:
return None
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ if gtu.dtypeIsSupportedByCompliance(a.dtype):
+ compliance = self.tensorComplianceMetaData(
+ op, args_dict, result_tensor, error_name
+ )
+ else:
+ compliance = None
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_dim(
self,
@@ -2609,8 +2631,9 @@ class TosaTestGen:
tensMeta = {}
# Check we are using the new testArgs interface with an argsDict dictionary
- if len(testArgs) == 1 and isinstance(testArgs[0], dict):
- argsDict = testArgs[0]
+ if isinstance(testArgs, dict):
+ # New interface with args info in dictionary
+ argsDict = testArgs
assert "dg_type" in argsDict
tvgInfo = tvgen_fcn(
self, opName, dtypeList, shapeList, argsDict, error_name
@@ -2618,38 +2641,49 @@ class TosaTestGen:
if tvgInfo.dataGenDict:
tensMeta["data_gen"] = tvgInfo.dataGenDict
tens = tvgInfo.tensorList
+
+ result = build_fcn(
+ self,
+ op,
+ tens,
+ argsDict,
+ validator_fcns=error_if_validators,
+ error_name=error_name,
+ qinfo=qinfo,
+ )
else:
+ # Old interface with args info in a list
tens = tvgen_fcn(self, op, dtypeList, shapeList, testArgs, error_name)
- try:
- if error_if_validators is None:
- if qinfo is not None:
- result = build_fcn(self, op, *tens, *testArgs, qinfo)
- else:
- result = build_fcn(self, op, *tens, *testArgs)
- else:
- if qinfo is not None:
- result = build_fcn(
- self,
- op,
- *tens,
- *testArgs,
- validator_fcns=error_if_validators,
- error_name=error_name,
- qinfo=qinfo,
- )
+ try:
+ if error_if_validators is None:
+ if qinfo is not None:
+ result = build_fcn(self, op, *tens, *testArgs, qinfo)
+ else:
+ result = build_fcn(self, op, *tens, *testArgs)
else:
- result = build_fcn(
- self,
- op,
- *tens,
- *testArgs,
- validator_fcns=error_if_validators,
- error_name=error_name,
- )
- except TypeError as e:
- print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
- raise e
+ if qinfo is not None:
+ result = build_fcn(
+ self,
+ op,
+ *tens,
+ *testArgs,
+ validator_fcns=error_if_validators,
+ error_name=error_name,
+ qinfo=qinfo,
+ )
+ else:
+ result = build_fcn(
+ self,
+ op,
+ *tens,
+ *testArgs,
+ validator_fcns=error_if_validators,
+ error_name=error_name,
+ )
+ except TypeError as e:
+ print(f"build_fcn: {build_fcn}\nTensors: {tens}\nArgs: {testArgs}\n")
+ raise e
if result:
# The test is valid, serialize it
@@ -2847,7 +2881,7 @@ class TosaTestGen:
"build_fcn": (
build_pool2d,
TosaTensorGen.tgNHWC,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agPooling,
),
"qgen": TosaQuantGen.qgUnary,
@@ -3004,7 +3038,6 @@ class TosaTestGen:
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),
- "int": (gtu.DataGenType.PSEUDO_RANDOM,),
},
},
"max_pool2d": {
@@ -3014,7 +3047,7 @@ class TosaTestGen:
"build_fcn": (
build_maxpool2d,
TosaTensorGen.tgNHWC,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agPooling,
),
"types": TYPE_NARROW_INT_FP,
@@ -3032,6 +3065,9 @@ class TosaTestGen:
TosaErrorValidator.evPoolingOutputShapeMismatch,
TosaErrorValidator.evPoolingOutputShapeNonInteger,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
},
# Templated operator. Filled in by createDynamicOpLists
"transpose_conv2d_TEMPLATE": {
@@ -3909,7 +3945,7 @@ class TosaTestGen:
"build_fcn": (
build_pad,
TosaTensorGen.tgBasic,
- TosaTensorValuesGen.tvgDefault,
+ TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agPad,
),
"types": TYPE_FIB,
@@ -3923,6 +3959,9 @@ class TosaTestGen:
TosaErrorValidator.evRankMismatch,
TosaErrorValidator.evWrongRank,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.PSEUDO_RANDOM,),
+ },
},
"dim": {
"op": Op.DIM,