diff options
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 112 |
1 files changed, 85 insertions, 27 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index b6a2e15..fd73eb5 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2021, ARM Limited. +// Copyright (c) 2020-2022, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -48,71 +48,91 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, { // tensor_ops case Op_ARGMAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); break; case Op_AVG_POOL2D: - DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT); - DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT8); - DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FLOAT, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); break; case Op_CONV2D: - DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48); break; case Op_CONV3D: - DEF_FACTORY_TWO_TYPE(OpConv3d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpConv3d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48); break; case Op_DEPTHWISE_CONV2D: - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48); break; case Op_FULLY_CONNECTED: - DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT16, INT8, INT48); break; case Op_MATMUL: - DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT); - DEF_FACTORY_ONE_TYPE(OpMatMul, INT8); - DEF_FACTORY_ONE_TYPE(OpMatMul, INT16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FLOAT, FLOAT); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48); break; case Op_MAX_POOL2D: + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); break; case Op_TRANSPOSE_CONV2D: - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT); - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT4); - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT8); - DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FLOAT, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48); break; // activation_funcs case Op_CLAMP: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); break; case Op_SIGMOID: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); break; case Op_TANH: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); break; // ewise_binary case Op_ADD: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); break; @@ -159,23 +179,28 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); break; case Op_MAXIMUM: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); break; case Op_MINIMUM: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); break; case Op_MUL: + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); break; case Op_POW: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); break; case Op_SUB: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); break; @@ -186,6 +211,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // ewise_unary case Op_ABS: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); break; @@ -195,38 +221,46 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); break; case Op_CEIL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); break; case Op_CLZ: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); break; case Op_EXP: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); break; case Op_FLOOR: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); break; case Op_LOG: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); break; case Op_LOGICAL_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); break; case Op_NEGATE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); break; case Op_RECIPROCAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); break; case Op_RSQRT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); break; // ewise_ternary case Op_SELECT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); @@ -236,14 +270,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // comparison case Op_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); break; case Op_GREATER: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); break; case Op_GREATER_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); break; @@ -256,27 +293,32 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); break; case Op_REDUCE_MAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); break; case Op_REDUCE_MIN: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); break; case Op_REDUCE_PRODUCT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); break; case Op_REDUCE_SUM: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; // data layout case Op_CONCAT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); @@ -284,6 +326,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); break; case Op_PAD: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); @@ -291,6 +334,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); break; case Op_RESHAPE: + DEF_FACTORY_RESHAPE(OpReshape, FP16); DEF_FACTORY_RESHAPE(OpReshape, FLOAT); DEF_FACTORY_RESHAPE(OpReshape, INT8); DEF_FACTORY_RESHAPE(OpReshape, INT16); @@ -298,6 +342,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RESHAPE(OpReshape, BOOL); break; case Op_REVERSE: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); @@ -305,6 +350,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); break; case Op_SLICE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); @@ -312,6 +358,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); break; case Op_TILE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); @@ -320,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_TRANSPOSE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); @@ -331,12 +379,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, INT8); DEF_FACTORY_ONE_TYPE(OpGather, INT16); DEF_FACTORY_ONE_TYPE(OpGather, INT32); + DEF_FACTORY_ONE_TYPE(OpGather, FP16); DEF_FACTORY_ONE_TYPE(OpGather, FLOAT); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); DEF_FACTORY_ONE_TYPE(OpScatter, INT16); DEF_FACTORY_ONE_TYPE(OpScatter, INT32); + DEF_FACTORY_ONE_TYPE(OpScatter, FP16); DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT); break; @@ -346,6 +396,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT8, INT8); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16); + DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OpResize, FLOAT, FLOAT); break; @@ -353,6 +404,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CONST: return new OpConst(sgt, id); case Op_IDENTITY: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); @@ -368,15 +420,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); |