aboutsummaryrefslogtreecommitdiff
path: root/verif
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2024-01-04 17:05:24 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2024-01-30 11:49:56 +0000
commit4f931307a6319d9d99b3afce4ca6e1cd30d77f01 (patch)
tree5661b63bd087b210403e3b50dbc0ce0a9f8a41b4 /verif
parent2d7e4b13d2c3022ae8176d59e2a11d5584ea1d0b (diff)
downloadreference_model-4f931307a6319d9d99b3afce4ca6e1cd30d77f01.tar.gz
Main Compliance: DEPTHWISE_CONV2D support
Added DEPTHWISE_CONV2D data generation. Updated test generation for FP16 and FP32. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: I0471d0a1e4e279a27233f4d285082906ceea1bff
Diffstat (limited to 'verif')
-rw-r--r--verif/conformance/tosa_main_profile_ops_info.json7
-rw-r--r--verif/generator/tosa_arg_gen.py7
-rw-r--r--verif/generator/tosa_test_gen.py29
3 files changed, 30 insertions, 13 deletions
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index ced1d9e..c77f0be 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -747,6 +747,7 @@
"profile": [
"tosa-mi"
],
+ "support_for": [ "lazy_data_gen" ],
"generation": {
"standard": {
"negative_dim_range": "1,10",
@@ -759,20 +760,20 @@
"--target-dtype",
"bf16",
"--fp-values-range",
- "-2.0,2.0",
+ "-max,max",
"--target-shape",
"1,17,31,4",
"--target-shape",
"1,37,11,5",
"--tensor-dim-range",
- "1,16",
+ "1,32",
"--allow-pooling-and-conv-oversizes"
],
[
"--target-dtype",
"fp32",
"--fp-values-range",
- "-2.0,2.0",
+ "-max,max",
"--target-shape",
"1,1,65531,2",
"--target-shape",
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 4863956..8501caa 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -2038,9 +2038,12 @@ class TosaArgGen:
# Compliance - number of dot product calculations
if depthwise:
- # TODO - add support
- dots = 0
+ # N*OH*OW*C*M
+ dots = gtu.product(
+ (ifm_shape[0], *outputs, *filter_shape[2:])
+ )
else:
+ # N*OH*OW*OC or N*OD*OH*OW*OC
dots = gtu.product(
(ifm_shape[0], *outputs, filter_shape[0])
)
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 49d9f1b..6867979 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -318,8 +318,13 @@ class TosaTestGen:
def tensorComplianceMetaData(
self, op, inputType, argsDict, outputTensor, errorName
):
- # TODO - Dot product Ops with FP16 or BF16 inputs that produce FP32 outputs are not supported yet
- UNSUPPORTED_NON_FP32_INPUT_OPS = (Op.MATMUL, Op.CONV2D, Op.FULLY_CONNECTED)
+ # TODO - Dot product Ops with BF16 inputs that produce FP32 outputs are not supported yet
+ UNSUPPORTED_NON_FP32_INPUT_OPS = (
+ Op.MATMUL,
+ Op.CONV2D,
+ Op.FULLY_CONNECTED,
+ Op.DEPTHWISE_CONV2D,
+ )
if (
errorName
or not gtu.dtypeIsSupportedByCompliance(outputTensor.dtype)
@@ -1063,7 +1068,7 @@ class TosaTestGen:
padding = args_dict["pad"]
dilations = args_dict["dilation"]
- result_tens = OutputShaper.depthwiseConv2dOp(
+ result_tensor = OutputShaper.depthwiseConv2dOp(
self.ser,
self.rng,
ifm,
@@ -1082,12 +1087,12 @@ class TosaTestGen:
):
qinfo = [
TosaQuantGen.getZeroPoint(self, ifm.dtype),
- TosaQuantGen.getZeroPoint(self, result_tens.dtype),
+ TosaQuantGen.getZeroPoint(self, result_tensor.dtype),
]
# Invalidate Input/Output list for error_if checks.
input_list = [ifm.name, filter.name, bias.name]
- output_list = [result_tens.name]
+ output_list = [result_tensor.name]
num_operands = sum(op["operands"])
input_list, output_list = TosaErrorIfArgGen.eiInvalidateInputOutputList(
self, error_name, input_list, output_list
@@ -1100,7 +1105,7 @@ class TosaTestGen:
op=op,
input_dtype=ifm.dtype,
weight_dtype=filter.dtype,
- output_dtype=result_tens.dtype,
+ output_dtype=result_tensor.dtype,
qinfo=qinfo,
input_list=input_list,
num_operands=num_operands,
@@ -1110,7 +1115,7 @@ class TosaTestGen:
dilation=dilations,
input_shape=ifm.shape,
weight_shape=filter.shape,
- output_shape=result_tens.shape,
+ output_shape=result_tensor.shape,
):
return None
@@ -1121,7 +1126,12 @@ class TosaTestGen:
attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
self.ser.addOperator(op["op"], input_list, output_list, attr)
- return result_tens
+
+ compliance = self.tensorComplianceMetaData(
+ op, ifm.dtype, args_dict, result_tensor, error_name
+ )
+
+ return TosaTestGen.BuildInfo(result_tensor, compliance)
def build_fully_connected(
self,
@@ -3206,6 +3216,9 @@ class TosaTestGen:
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
),
+ "data_gen": {
+ "fp": (gtu.DataGenType.DOT_PRODUCT,),
+ },
"template": True,
},
"fully_connected": {