aboutsummaryrefslogtreecommitdiff
path: root/verif/generator/tosa_test_gen.py
diff options
context:
space:
mode:
Diffstat (limited to 'verif/generator/tosa_test_gen.py')
-rw-r--r--verif/generator/tosa_test_gen.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 68a4e94..9c3cd32 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -325,6 +325,7 @@ class TosaTestGen:
Op.FULLY_CONNECTED,
Op.DEPTHWISE_CONV2D,
Op.TRANSPOSE_CONV2D,
+ Op.CONV3D,
)
if (
errorName
@@ -952,7 +953,7 @@ class TosaTestGen:
dilations = args_dict["dilation"]
assert len(padding) == 6
- result_tens = OutputShaper.conv3dOp(
+ result_tensor = OutputShaper.conv3dOp(
self.ser,
self.rng,
ifm,
@@ -971,12 +972,12 @@ class TosaTestGen:
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
- TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
@@ -989,7 +990,7 @@ class TosaTestGen:
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
- output_dtype=result_tens.dtype,
+ output_dtype=result_tensor.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
@@ -999,7 +1000,7 @@ class TosaTestGen:
dilation=dilations,
input_shape=ifm.shape,
weight_shape=filter.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
):
return None
@@ -1010,7 +1011,12 @@ class TosaTestGen:
attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, ifm.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_transpose_conv2d(
self,
@@ -3254,6 +3260,9 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
"template": True,
},
# Templated operator. Filled in by createDynamicOpLists