diff options
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 85 |
1 files changed, 44 insertions, 41 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 0f0013c..74315d7 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -70,41 +70,43 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16); break; case Op_CONV2D: - DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16); break; case Op_CONV3D: - DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E5M2, FP8E5M2, FP16, FP16); break; case Op_DEPTHWISE_CONV2D: - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16); break; case Op_FFT2D: DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); @@ -148,16 +150,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64); break; case Op_TRANSPOSE_CONV2D: - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E5M2, FP8E5M2, FP16, FP16); break; // activation_funcs |