aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py29
1 files changed, 21 insertions, 8 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 49d9f1b..6867979 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -318,8 +318,13 @@ class TosaTestGen:
def tensorComplianceMetaData(
self, op, inputType, argsDict, outputTensor, errorName
):
- # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
- UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)
+ # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
+ UNSUPPORTED_NON_FP32_INPUT_OPS = (
+ Op.MATMUL,
+ Op.CONV2D,
+ Op.FULLY_CONNECTED,
+ Op.DEPTHWISE_CONV2D,
+ )
if (
errorName
or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
@@ -1063,7 +1068,7 @@ class TosaTestGen:
padding = args_dict["pad"]
dilations = args_dict["dilation"]
- result_tens = OutputShaper.depthwiseConv2dOp(
+ result_tensor = OutputShaper.depthwiseConv2dOp(
self.ser,
self.rng,
ifm,
@@ -1082,12 +1087,12 @@ class TosaTestGen:
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
- TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
@@ -1100,7 +1105,7 @@ class TosaTestGen:
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
- output_dtype=result_tens.dtype,
+ output_dtype=result_tensor.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
@@ -1110,7 +1115,7 @@ class TosaTestGen:
dilation=dilations,
input_shape=ifm.shape,
weight_shape=filter.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
):
return None
@@ -1121,7 +1126,12 @@ class TosaTestGen:
attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, ifm.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_fully_connected(
self,
@@ -3206,6 +3216,9 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
"template": True,
},
"fully_connected": {