aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r--reference_model/src/ops/op_factory.cc82
1 files changed, 41 insertions, 41 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 76cf666..b1a405a 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -63,48 +63,48 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
break;
case Op_CONV2D:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
break;
case Op_CONV3D:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
break;
case Op_DEPTHWISE_CONV2D:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
break;
case Op_FULLY_CONNECTED:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
break;
case Op_MATMUL:
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, BF16, FP32);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
@@ -117,13 +117,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32);
break;
case Op_TRANSPOSE_CONV2D:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
break;
// activation_funcs