diff options
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 82 |
1 files changed, 41 insertions, 41 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 76cf666..b1a405a 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -63,48 +63,48 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); break; case Op_CONV2D: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); break; case Op_CONV3D: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); break; case Op_DEPTHWISE_CONV2D: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); break; case Op_FULLY_CONNECTED: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48); break; case Op_MATMUL: - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP32); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, BF16, FP32); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); @@ -117,13 +117,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32); break; case Op_TRANSPOSE_CONV2D: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); break; // activation_funcs |