diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-14 16:21:29 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-03-20 00:02:15 +0000 |
commit | f36f25619cc3a34c75e78637ed244a2ca54ab3f4 (patch) | |
tree | b1aa6a7314ef598561f0259c4d614a4169451031 /reference_model/src/ops/op_factory.cc | |
parent | 0a6d1deef02f2bd76b3068d615565f20c46075a5 (diff) | |
download | reference_model-f36f25619cc3a34c75e78637ed244a2ca54ab3f4.tar.gz |
[ref model] Add acc_type to Conv Ops
This patch implements the changes required by the new acc_type field in
ConvAttribute and TransposeConvAttribute.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 85 |
1 files changed, 44 insertions, 41 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 0f0013c..74315d7 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -70,41 +70,43 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16); break; case Op_CONV2D: - DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16); break; case Op_CONV3D: - DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32); - 
DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E5M2, FP8E5M2, FP16, FP16); break; case Op_DEPTHWISE_CONV2D: - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, 
FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16); break; case Op_FFT2D: DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); @@ -148,16 +150,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64); break; case Op_TRANSPOSE_CONV2D: - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32, BF16); + 
DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E5M2, FP8E5M2, FP16, FP16); break; // activation_funcs |