aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/op_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r--reference_model/src/ops/op_factory.cc112
1 files changed, 85 insertions, 27 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index b6a2e15..fd73eb5 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2021, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -48,71 +48,91 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
{
// tensor_ops
case Op_ARGMAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
break;
case Op_AVG_POOL2D:
- DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT);
- DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT8);
- DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FLOAT, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
break;
case Op_CONV2D:
- DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpConv2d, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48);
break;
case Op_CONV3D:
- DEF_FACTORY_TWO_TYPE(OpConv3d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpConv3d, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpConv3d, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48);
break;
case Op_DEPTHWISE_CONV2D:
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48);
break;
case Op_FULLY_CONNECTED:
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT16, INT8, INT48);
break;
case Op_MATMUL:
- DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT);
- DEF_FACTORY_ONE_TYPE(OpMatMul, INT8);
- DEF_FACTORY_ONE_TYPE(OpMatMul, INT16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FLOAT, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48);
break;
case Op_MAX_POOL2D:
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
break;
case Op_TRANSPOSE_CONV2D:
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT4);
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT8, INT8);
- DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48);
break;
// activation_funcs
case Op_CLAMP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
break;
case Op_SIGMOID:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
break;
case Op_TANH:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
break;
// ewise_binary
case Op_ADD:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
break;
@@ -159,23 +179,28 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
break;
case Op_MAXIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
break;
case Op_MINIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
break;
case Op_MUL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
break;
case Op_POW:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
break;
case Op_SUB:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
break;
@@ -186,6 +211,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// ewise_unary
case Op_ABS:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
break;
@@ -195,38 +221,46 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
break;
case Op_CEIL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
break;
case Op_CLZ:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
break;
case Op_EXP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
break;
case Op_FLOOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
break;
case Op_LOG:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
break;
case Op_LOGICAL_NOT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
break;
case Op_NEGATE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
break;
case Op_RECIPROCAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
break;
case Op_RSQRT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
break;
// ewise_ternary
case Op_SELECT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
@@ -236,14 +270,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// comparison
case Op_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
break;
case Op_GREATER:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
break;
case Op_GREATER_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
break;
@@ -256,27 +293,32 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
break;
case Op_REDUCE_MAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
break;
case Op_REDUCE_MIN:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
break;
case Op_REDUCE_PRODUCT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
break;
case Op_REDUCE_SUM:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
break;
// data layout
case Op_CONCAT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
@@ -284,6 +326,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
break;
case Op_PAD:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
@@ -291,6 +334,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
break;
case Op_RESHAPE:
+ DEF_FACTORY_RESHAPE(OpReshape, FP16);
DEF_FACTORY_RESHAPE(OpReshape, FLOAT);
DEF_FACTORY_RESHAPE(OpReshape, INT8);
DEF_FACTORY_RESHAPE(OpReshape, INT16);
@@ -298,6 +342,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RESHAPE(OpReshape, BOOL);
break;
case Op_REVERSE:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
@@ -305,6 +350,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
break;
case Op_SLICE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
@@ -312,6 +358,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
break;
case Op_TILE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
@@ -320,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_TRANSPOSE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
@@ -331,12 +379,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, INT8);
DEF_FACTORY_ONE_TYPE(OpGather, INT16);
DEF_FACTORY_ONE_TYPE(OpGather, INT32);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP16);
DEF_FACTORY_ONE_TYPE(OpGather, FLOAT);
break;
case Op_SCATTER:
DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP16);
DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT);
break;
@@ -346,6 +396,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT8, INT8);
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48);
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16);
+ DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16);
DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OpResize, FLOAT, FLOAT);
break;
@@ -353,6 +404,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_CONST:
return new OpConst(sgt, id);
case Op_IDENTITY:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
@@ -368,15 +420,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);