about | summary | refs | log | tree | commit | diff
path: root/reference_model/src/ops/op_factory.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/op_factory.cc')
-rw-r--r-- reference_model/src/ops/op_factory.cc 48
1 files changed, 48 insertions, 0 deletions
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 1ff8229..0121ccf 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -49,6 +49,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// tensor_ops
case Op_ARGMAX:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
@@ -56,6 +57,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, BF16, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
@@ -63,6 +65,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_CONV2D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32);
@@ -71,6 +74,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_CONV3D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32);
@@ -79,6 +83,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32);
@@ -87,6 +92,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_FULLY_CONNECTED:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32);
@@ -95,12 +101,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_MATMUL:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, BF16);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
@@ -108,6 +116,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_TRANSPOSE_CONV2D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32);
@@ -117,22 +126,26 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// activation_funcs
case Op_CLAMP:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
break;
case Op_SIGMOID:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
break;
case Op_TANH:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
break;
// ewise_binary
case Op_ADD:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
break;
@@ -180,16 +193,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_MAXIMUM:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
break;
case Op_MINIMUM:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
break;
case Op_MUL:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
@@ -197,10 +213,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_POW:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
break;
case Op_SUB:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
break;
@@ -212,6 +230,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// ewise_unary
case Op_ABS:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
break;
@@ -222,6 +241,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_CEIL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
break;
case Op_CLZ:
@@ -229,14 +249,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_EXP:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
break;
case Op_FLOOR:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
break;
case Op_LOG:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
break;
case Op_LOGICAL_NOT:
@@ -244,6 +267,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_NEGATE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
@@ -251,16 +275,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_RECIPROCAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
break;
case Op_RSQRT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
break;
// ewise_ternary
case Op_SELECT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
@@ -271,16 +298,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// comparison
case Op_EQUAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
break;
case Op_GREATER:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
break;
case Op_GREATER_EQUAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
break;
@@ -294,6 +324,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REDUCE_MAX:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
@@ -301,6 +332,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REDUCE_MIN:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
@@ -308,10 +340,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REDUCE_PRODUCT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
break;
case Op_REDUCE_SUM:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
break;
@@ -319,6 +353,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// data layout
case Op_CONCAT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
@@ -327,6 +362,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
@@ -335,6 +371,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FP16);
+ DEF_FACTORY_RESHAPE(OpReshape, BF16);
DEF_FACTORY_RESHAPE(OpReshape, FP32);
DEF_FACTORY_RESHAPE(OpReshape, INT8);
DEF_FACTORY_RESHAPE(OpReshape, INT16);
@@ -343,6 +380,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
@@ -351,6 +389,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_SLICE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
@@ -359,6 +398,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_TILE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
@@ -368,6 +408,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_TRANSPOSE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
@@ -380,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, INT16);
DEF_FACTORY_ONE_TYPE(OpGather, INT32);
DEF_FACTORY_ONE_TYPE(OpGather, FP16);
+ DEF_FACTORY_ONE_TYPE(OpGather, BF16);
DEF_FACTORY_ONE_TYPE(OpGather, FP32);
break;
case Op_SCATTER:
@@ -387,6 +429,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
DEF_FACTORY_ONE_TYPE(OpScatter, FP16);
+ DEF_FACTORY_ONE_TYPE(OpScatter, BF16);
DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
break;
@@ -397,6 +440,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48);
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16);
DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16);
DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32);
break;
@@ -405,6 +449,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
return new OpConst(sgt, id);
case Op_IDENTITY:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
@@ -435,6 +480,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);