aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-03-14 16:21:29 +0000
committerEric Kunze <eric.kunze@arm.com>2024-03-20 00:02:15 +0000
commitf36f25619cc3a34c75e78637ed244a2ca54ab3f4 (patch)
treeb1aa6a7314ef598561f0259c4d614a4169451031 /reference_model/src/ops/op_factory.cc
parent0a6d1deef02f2bd76b3068d615565f20c46075a5 (diff)
downloadreference_model-f36f25619cc3a34c75e78637ed244a2ca54ab3f4.tar.gz
[ref model] Add acc_type to Conv Ops
This patch implements changes required by the new acc_type field in ConvAttribute and TransposeConvAttribute Signed-off-by: Tai Ly <tai.ly@arm.com> Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r--reference_model/src/ops/op_factory.cc85
1 files changed, 44 insertions, 41 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 0f0013c..74315d7 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -70,41 +70,43 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16);
break;
case Op_CONV2D:
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
case Op_CONV3D:
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
case Op_DEPTHWISE_CONV2D:
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
case Op_FFT2D:
DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
@@ -148,16 +150,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64);
break;
case Op_TRANSPOSE_CONV2D:
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
// activation_funcs