aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r--reference_model/src/ops/op_factory.cc85
1 files changed, 44 insertions, 41 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 0f0013c..74315d7 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -70,41 +70,43 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16);
break;
case Op_CONV2D:
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
case Op_CONV3D:
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
case Op_DEPTHWISE_CONV2D:
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
case Op_FFT2D:
DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
@@ -148,16 +150,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64);
break;
case Op_TRANSPOSE_CONV2D:
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
- DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
+ // OP, attr_name, in_t, w_t, acc_t, out_t
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32, BF16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32, INT32);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48, INT48);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP64, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E4M3, FP8E4M3, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E5M2, FP8E5M2, FP16, FP16);
break;
// activation_funcs