aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-14 16:21:29 +0000
committerEric Kunze <eric.kunze@arm.com>2024-03-20 00:02:15 +0000
commitf36f25619cc3a34c75e78637ed244a2ca54ab3f4 (patch)
treeb1aa6a7314ef598561f0259c4d614a4169451031
parent0a6d1deef02f2bd76b3068d615565f20c46075a5 (diff)
downloadreference_model-f36f25619cc3a34c75e78637ed244a2ca54ab3f4.tar.gz
[ref model] Add acc_type to Conv Ops
This patch implements changes required by the new acc_type field in ConvAttribute and TransposeConvAttribute Signed-off-by: Tai Ly <tai.ly@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
-rw-r--r--.gitignore3
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosabin1676 -> 1708 bytes
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosabin1380 -> 1412 bytes
-rw-r--r--examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosabin1484 -> 1516 bytes
-rw-r--r--examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosabin11920 -> 11956 bytes
-rw-r--r--reference_model/src/graph_node.h4
-rw-r--r--reference_model/src/ops/op_factory.cc85
-rw-r--r--reference_model/src/ops/op_factory.h8
-rw-r--r--reference_model/src/ops/tensor_ops.cc163
-rw-r--r--reference_model/src/ops/tensor_ops.h16
m---------thirdparty/serialization_lib0
-rw-r--r--verif/generator/tosa_arg_gen.py240
-rw-r--r--verif/generator/tosa_error_if.py34
-rw-r--r--verif/generator/tosa_test_gen.py42
-rw-r--r--verif/generator/tosa_utils.py14
15 files changed, 355 insertions, 254 deletions
diff --git a/.gitignore b/.gitignore
index 941cf20..dddbcc0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,9 @@
__pycache__/
build/
debug-build/
+conformance/
+conformance_build/
+conformance_large_files/
.cache
compile_commands.json
dist/
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index 20e1333..dc8413a 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index d55d5d6..521c1a3 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
index 01d8375..73c1839 100644
--- a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
+++ b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
Binary files differ
diff --git a/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa b/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa
index deaca6e..fff67ef 100644
--- a/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa
+++ b/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa
Binary files differ
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index e10f132..c0dceda 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -41,6 +41,10 @@
#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>;
+#define DEF_INSTANTIATE_FOUR_TYPE(OP, DTYPE1, DTYPE2, DTYPE3, DTYPE4) \
+ template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3, \
+ TOSA_REF_TYPE_##DTYPE4>;
+
#define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE) \
template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, OP_TYPE>;
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 0f0013c..74315d7 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -70,41 +70,43 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16);
break;
case Op_CONV2D:
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
case Op_CONV3D:
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
case Op_DEPTHWISE_CONV2D:
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
case Op_FFT2D:
DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
@@ -148,16 +150,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64);
break;
case Op_TRANSPOSE_CONV2D:
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
// activation_funcs
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 1d20066..f1d1680 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -94,6 +94,14 @@
return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>(sgt, attribute, id); \
}
+#define DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OP, ATTR_NAME, IN_DTYPE, W_DTYPE, ACC_DTYPE, OUT_DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##IN_DTYPE && weightDTYPE == TOSA_REF_TYPE_##W_DTYPE && \
+ outputDTYPE == TOSA_REF_TYPE_##OUT_DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACC_DTYPE) \
+ { \
+ return new OP<TOSA_REF_TYPE_##IN_DTYPE, TOSA_REF_TYPE_##W_DTYPE, TOSA_REF_TYPE_##ACC_DTYPE, \
+ TOSA_REF_TYPE_##OUT_DTYPE>(sgt, attribute, id); \
+ }
+
// Statement-expression to evaluate accumulate attribute in-place
#define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \
({ \
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 7bd249b..afd20e9 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -586,8 +586,10 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
return GraphNode::eval();
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -596,15 +598,15 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, Tosa
INIT_ATTRIBUTE(Conv);
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv2d()
{
if (attribute)
delete attribute;
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -640,8 +642,8 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -793,8 +795,10 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_CONV3D, id_)
{
setRequiredOperands(3, 1);
@@ -803,15 +807,15 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, Tosa
INIT_ATTRIBUTE(Conv);
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv3d()
{
if (attribute)
delete attribute;
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -847,8 +851,8 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_depth = this->input->getShape()[1];
@@ -1008,10 +1012,10 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1020,15 +1024,15 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra
INIT_ATTRIBUTE(Conv);
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpDepthwiseConv2d()
{
if (attribute)
delete attribute;
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1064,8 +1068,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -1903,10 +1907,10 @@ int OpRFFT2d<Dtype>::eval()
return GraphNode::eval();
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1915,15 +1919,15 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra
INIT_ATTRIBUTE(TransposeConv);
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -2017,8 +2021,8 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -2168,39 +2172,39 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16);
-// [in_t, weight_t, out_t]
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
-
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
-
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
+// [in_t, weight_t, acc_t, out_t]
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16, FP16);
+
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16, FP16);
+
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
@@ -2238,13 +2242,14 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
+// [in_t, weight_t, acc_t, out_t]
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16, FP16);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index e2bb811..2e65548 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -75,7 +75,7 @@ protected:
int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
};
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
class OpConv2d : public GraphNode
{
public:
@@ -87,7 +87,7 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
@@ -105,7 +105,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
class OpConv3d : public GraphNode
{
public:
@@ -117,7 +117,7 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 5>;
using TWeight = Eigen::Tensor<WeightEigenType, 5>;
@@ -135,7 +135,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
@@ -147,7 +147,7 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
@@ -294,7 +294,7 @@ protected:
tosa::TosaRFFTAttribute* attribute;
};
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
class OpTransposeConv2d : public GraphNode
{
public:
@@ -306,7 +306,7 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 0b6d7c271af1e6593e6a2cf14b32acea765f4b6
+Subproject ad78daaf0fa1e41742cbed314459c3dbbb483c2
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 83487a1..ffa3683 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1990,7 +1990,12 @@ class TosaArgGen:
# Shape: (OFM channels), (KD), KH, KW, IFM channels
filter_shape = shapeList[1]
- accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+ accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
+
+ if error_name == ErrorIf.WrongAccumulatorType:
+ accum_dtypes = (
+ [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
+ )
# Op type checks
conv3d = opName.startswith("conv3d")
@@ -2110,88 +2115,91 @@ class TosaArgGen:
sparsity = 1
n = 0
- for s in sorted(list(strides)):
- for p in sorted(list(paddings)):
- for d in sorted(list(dilations)):
- if (
- n % sparsity == 0
- # the padded shape must exceed the dilation * kernel to get a positive
- # sized output shape
- and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
- and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
- and (
- k_rank < 3
- or (
- (ifm_shape[3] - 1 + p[4] + p[5])
- > d[2] * (k_shape[2] - 1)
- )
- )
- ):
- remainders = []
- outputs = []
- for index in range(k_rank):
- pad_offset = index * 2
- partial = (
- ifm_shape[index + 1]
- - 1
- + p[pad_offset]
- + p[pad_offset + 1]
- - (k_shape[index] - 1) * d[index]
- )
- remainders.append(partial % s[index])
- outputs.append((partial // s[index]) + 1)
-
+ for a in accum_dtypes:
+ for s in sorted(list(strides)):
+ for p in sorted(list(paddings)):
+ for d in sorted(list(dilations)):
if (
- # the parameters must produce integer exact output
- error_name != ErrorIf.ConvOutputShapeNonInteger
- and max(remainders) == 0
- ) or (
- error_name == ErrorIf.ConvOutputShapeNonInteger
- and max(remainders) > 0
+ n % sparsity == 0
+ # the padded shape must exceed the dilation * kernel to get a positive
+ # sized output shape
+ and (ifm_shape[1] - 1 + p[0] + p[1])
+ > d[0] * (k_shape[0] - 1)
+ and (ifm_shape[2] - 1 + p[2] + p[3])
+ > d[1] * (k_shape[1] - 1)
+ and (
+ k_rank < 3
+ or (
+ (ifm_shape[3] - 1 + p[4] + p[5])
+ > d[2] * (k_shape[2] - 1)
+ )
+ )
):
+ remainders = []
+ outputs = []
+ for index in range(k_rank):
+ pad_offset = index * 2
+ partial = (
+ ifm_shape[index + 1]
+ - 1
+ + p[pad_offset]
+ + p[pad_offset + 1]
+ - (k_shape[index] - 1) * d[index]
+ )
+ remainders.append(partial % s[index])
+ outputs.append((partial // s[index]) + 1)
+
if (
- max_dim_size is not None
- and max(outputs) >= max_dim_size
+ # the parameters must produce integer exact output
+ error_name != ErrorIf.ConvOutputShapeNonInteger
+ and max(remainders) == 0
+ ) or (
+ error_name == ErrorIf.ConvOutputShapeNonInteger
+ and max(remainders) > 0
):
- # Test will consume too much memory - skip it
- continue
-
- # Compliance - number of dot product calculations
- if depthwise:
- # N*OH*OW*C*M
- dots = gtu.product(
- (ifm_shape[0], *outputs, *filter_shape[2:])
- )
- else:
- # N*OH*OW*OC or N*OD*OH*OW*OC
- dots = gtu.product(
- (ifm_shape[0], *outputs, filter_shape[0])
- )
- args_dict = {
- "acc_type": accum_dtype,
- "stride": s,
- "pad": p,
- "dilation": d,
- "kernel": k_shape,
- "ks": k_size,
- "dot_products": dots,
- "shape": ifm_shape,
- }
-
- # Support for larger values than 9 needs different delimiter
- delim = "" if max(s + p + d) <= 9 else "x"
- arg_list.append(
- (
- "acc{}_st{}_pad{}_dilat{}".format(
- testGen.typeStr(accum_dtype),
- delim.join([str(x) for x in s]),
- delim.join([str(x) for x in p]),
- delim.join([str(x) for x in d]),
- ),
- args_dict,
+ if (
+ max_dim_size is not None
+ and max(outputs) >= max_dim_size
+ ):
+ # Test will consume too much memory - skip it
+ continue
+
+ # Compliance - number of dot product calculations
+ if depthwise:
+ # N*OH*OW*C*M
+ dots = gtu.product(
+ (ifm_shape[0], *outputs, *filter_shape[2:])
+ )
+ else:
+ # N*OH*OW*OC or N*OD*OH*OW*OC
+ dots = gtu.product(
+ (ifm_shape[0], *outputs, filter_shape[0])
+ )
+ args_dict = {
+ "acc_type": a,
+ "stride": s,
+ "pad": p,
+ "dilation": d,
+ "kernel": k_shape,
+ "ks": k_size,
+ "dot_products": dots,
+ "shape": ifm_shape,
+ }
+
+ # Support for larger values than 9 needs different delimiter
+ delim = "" if max(s + p + d) <= 9 else "x"
+ arg_list.append(
+ (
+ "acc{}_st{}_pad{}_dilat{}".format(
+ testGen.typeStr(a),
+ delim.join([str(x) for x in s]),
+ delim.join([str(x) for x in p]),
+ delim.join([str(x) for x in d]),
+ ),
+ args_dict,
+ )
)
- )
- n += 1
+ n += 1
arg_list = TosaArgGen._add_data_generators(
testGen,
@@ -2216,7 +2224,7 @@ class TosaArgGen:
# Pick some potentially correct output dtype if input type is incorrect
accum_dtype = DType.INT32
else:
- accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+ accum_dtype = dtypes[-1] # use output dtype as accum_dtype
# Set up compliance info
args_dict = {
@@ -2303,7 +2311,12 @@ class TosaArgGen:
ifm_shape = shapeList[0]
filter_shape = shapeList[1]
- accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+ accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
+
+ if error_name == ErrorIf.WrongAccumulatorType:
+ accum_dtypes = (
+ [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
+ )
# Must be rank 4
if error_name != ErrorIf.WrongRank:
@@ -2400,41 +2413,42 @@ class TosaArgGen:
sparsity = 1
n = 0
- for s in sorted(list(strides)):
- for p in sorted(list(paddings)):
- if n % sparsity == 0:
- # Determine the output shape
- oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
- ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
- os = [ifm_shape[0], oh, ow, filter_shape[0]]
-
- # N*OH*OW*OC
- dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
- args_dict = {
- "acc_type": accum_dtype,
- "stride": s,
- "pad": p,
- "kernel": k_shape,
- "ks": k_size,
- "dot_products": dots,
- "shape": ifm_shape,
- "out_shape": os,
- }
+ for a in accum_dtypes:
+ for s in sorted(list(strides)):
+ for p in sorted(list(paddings)):
+ if n % sparsity == 0:
+ # Determine the output shape
+ oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0]
+ ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1]
+ os = [ifm_shape[0], oh, ow, filter_shape[0]]
+
+ # N*OH*OW*OC
+ dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0]))
+ args_dict = {
+ "acc_type": a,
+ "stride": s,
+ "pad": p,
+ "kernel": k_shape,
+ "ks": k_size,
+ "dot_products": dots,
+ "shape": ifm_shape,
+ "out_shape": os,
+ }
- # Support for larger values than 9 needs different delimiter
- delim = "" if max(s + p) <= 9 else "x"
- arg_list.append(
- (
- "acc{}_st{}_pad{}_os{}".format(
- testGen.typeStr(accum_dtype),
- delim.join([str(x) for x in s]),
- delim.join([str(x) for x in p]),
- "x".join([str(x) for x in os]),
- ),
- args_dict,
+ # Support for larger values than 9 needs different delimiter
+ delim = "" if max(s + p) <= 9 else "x"
+ arg_list.append(
+ (
+ "acc{}_st{}_pad{}_os{}".format(
+ testGen.typeStr(a),
+ delim.join([str(x) for x in s]),
+ delim.join([str(x) for x in p]),
+ "x".join([str(x) for x in os]),
+ ),
+ args_dict,
+ )
)
- )
- n += 1
+ n += 1
arg_list = TosaArgGen._add_data_generators(
testGen,
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index e557f06..916b4f9 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -649,9 +649,9 @@ class TosaErrorValidator:
or input_dtype == DType.INT16
and output_dtype != DType.INT48
or input_dtype == DType.FP16
- and output_dtype not in (DType.FP16, DType.FP32)
+ and output_dtype != DType.FP16
or input_dtype == DType.BF16
- and output_dtype != DType.FP32
+ and output_dtype != DType.BF16
or input_dtype == DType.FP32
and output_dtype != DType.FP32
or input_dtype == DType.FP8E4M3
@@ -2682,6 +2682,36 @@ class TosaErrorValidator:
):
error_result = True
+ elif op["op"] in {
+ Op.CONV2D,
+ Op.CONV3D,
+ Op.DEPTHWISE_CONV2D,
+ Op.TRANSPOSE_CONV2D,
+ }:
+ if input_dtype == DType.INT8 and accum_dtype != DType.INT32:
+ error_result = True
+ elif input_dtype == DType.INT16 and accum_dtype != DType.INT48:
+ error_result = True
+ elif (
+ input_dtype
+ in (
+ DType.FP32,
+ DType.BF16,
+ )
+ and accum_dtype != DType.FP32
+ ):
+ error_result = True
+ elif input_dtype == DType.FP16 and accum_dtype not in (
+ DType.FP16,
+ DType.FP32,
+ ):
+ error_result = True
+ elif (
+ input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+ and accum_dtype != DType.FP16
+ ):
+ error_result = True
+
info_dict = {
"error_name": error_name,
"error_result": error_result,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7702753..c867070 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -896,6 +896,7 @@ class TosaTestGen:
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tensor.shape,
+ accum_dtype=accum_dtype,
):
return None
@@ -903,7 +904,9 @@ class TosaTestGen:
local_bound = False
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+ attr.ConvAttribute(
+ padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+ )
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -981,6 +984,7 @@ class TosaTestGen:
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tensor.shape,
+ accum_dtype=accum_dtype,
):
return None
@@ -988,7 +992,9 @@ class TosaTestGen:
local_bound = False
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+ attr.ConvAttribute(
+ padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+ )
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -1057,6 +1063,7 @@ class TosaTestGen:
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tensor.shape,
+ accum_dtype=accum_dtype,
):
return None
@@ -1065,7 +1072,7 @@ class TosaTestGen:
attr = ts.TosaSerializerAttribute()
attr.TransposeConvAttribute(
- out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
+ out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -1143,6 +1150,7 @@ class TosaTestGen:
input_shape=ifm.shape,
weight_shape=filter.shape,
output_shape=result_tensor.shape,
+ accum_dtype=accum_dtype,
):
return None
@@ -1150,7 +1158,9 @@ class TosaTestGen:
local_bound = False
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+ attr.ConvAttribute(
+ padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+ )
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -3385,6 +3395,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
+ TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3418,6 +3429,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
+ TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3452,6 +3464,7 @@ class TosaTestGen:
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
TosaErrorValidator.evConvOutputShapeNonInteger,
+ TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3564,6 +3577,7 @@ class TosaTestGen:
TosaErrorValidator.evStrideSmallerOne,
TosaErrorValidator.evWrongRank,
TosaErrorValidator.evConvOutputShapeMismatch,
+ TosaErrorValidator.evWrongAccumulatorType,
),
"data_gen": {
"fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -5290,6 +5304,18 @@ class OutputShaper:
return ser.addOutput(shape, outputDType)
@staticmethod
+ def _get_conv_output_type(input_dtype):
+ if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
+ return input_dtype
+ elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
+ return DType.FP16
+ elif input_dtype in (DType.INT8, DType.INT4):
+ return DType.INT32
+ elif input_dtype in (DType.INT16,):
+ return DType.INT48
+ assert False, f"Unsupported convolution data type {input_dtype}"
+
+ @staticmethod
def conv2dOp(
ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
):
@@ -5329,7 +5355,7 @@ class OutputShaper:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- out_dtype = accum_dtype
+ out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
@@ -5393,7 +5419,7 @@ class OutputShaper:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- out_dtype = accum_dtype
+ out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
@@ -5444,7 +5470,7 @@ class OutputShaper:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- out_dtype = accum_dtype
+ out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
@@ -5958,7 +5984,7 @@ class OutputShaper:
# Pick some potentially correct output dtype if input type is incorrect
out_dtype = DType.INT32
else:
- out_dtype = accum_dtype
+ out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index cfe7cc6..4a4f6bb 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -164,10 +164,18 @@ def product(shape):
return value
-def get_accum_dtype_from_tgTypes(dtypes):
- # Get accumulate data-type from the test generator's defined types
+def get_accum_dtypes_from_tgTypes(dtypes):
+ # Get accumulate data-types from the test generator's defined types
assert isinstance(dtypes, list) or isinstance(dtypes, tuple)
- return dtypes[-1]
+ input_dtype = dtypes[0]
+ output_dtype = dtypes[-1]
+ # by default, accum_dtypes contains only output_dtype
+ accum_dtypes = [output_dtype]
+ if input_dtype == DType.FP16 and output_dtype == DType.FP16:
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif output_dtype == DType.BF16:
+ accum_dtypes = [DType.FP32]
+ return accum_dtypes
def get_wrong_output_type(op_name, rng, input_dtype):