From 24dbc420aae556649f50e645bd94489dab2cc75a Mon Sep 17 00:00:00 2001 From: James Ward Date: Wed, 19 Oct 2022 12:20:31 +0100 Subject: Add BF16 support to reference model * Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work- arounds for reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0) * Truncation to bfloat16 now performed in eval() methods Signed-off-by: James Ward Signed-off-by: Jeremy Johnson Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe --- reference_model/src/ops/op_factory.cc | 48 +++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) (limited to 'reference_model/src/ops/op_factory.cc') diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 1ff8229..0121ccf 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -49,6 +49,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // tensor_ops case Op_ARGMAX: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); @@ -56,6 +57,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, BF16, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); @@ -63,6 +65,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32); @@ -71,6 +74,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CONV3D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32); @@ -79,6 +83,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_DEPTHWISE_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32); @@ -87,6 +92,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_FULLY_CONNECTED: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32); @@ -95,12 +101,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_MATMUL: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, BF16); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); @@ -108,6 +116,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_TRANSPOSE_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32); @@ -117,22 +126,26 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // activation_funcs case Op_CLAMP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); break; case Op_SIGMOID: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); break; case Op_TANH: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); break; // ewise_binary case Op_ADD: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); break; @@ -180,16 +193,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_MAXIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); break; case Op_MINIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); break; case Op_MUL: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); @@ -197,10 +213,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_POW: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); break; case Op_SUB: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); break; @@ -212,6 +230,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // ewise_unary case Op_ABS: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); break; @@ -222,6 +241,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_CEIL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); break; case Op_CLZ: @@ -229,14 +249,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_EXP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); break; case Op_FLOOR: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); break; case Op_LOG: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); break; case Op_LOGICAL_NOT: @@ -244,6 +267,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_NEGATE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); @@ -251,16 +275,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_RECIPROCAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); break; case Op_RSQRT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); break; // ewise_ternary case Op_SELECT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); @@ -271,16 +298,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // comparison case Op_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); break; case Op_GREATER: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); break; case Op_GREATER_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); break; @@ -294,6 +324,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_MAX: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); @@ -301,6 +332,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_MIN: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); @@ -308,10 +340,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_PRODUCT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); break; case Op_REDUCE_SUM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; @@ -319,6 +353,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // data layout case Op_CONCAT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); @@ -327,6 +362,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_PAD: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); @@ -335,6 +371,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); + DEF_FACTORY_RESHAPE(OpReshape, BF16); DEF_FACTORY_RESHAPE(OpReshape, FP32); DEF_FACTORY_RESHAPE(OpReshape, INT8); DEF_FACTORY_RESHAPE(OpReshape, INT16); @@ -343,6 +380,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); @@ -351,6 +389,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_SLICE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); @@ -359,6 +398,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_TILE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); @@ -368,6 +408,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_TRANSPOSE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); @@ -380,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, INT16); DEF_FACTORY_ONE_TYPE(OpGather, INT32); DEF_FACTORY_ONE_TYPE(OpGather, FP16); + DEF_FACTORY_ONE_TYPE(OpGather, BF16); DEF_FACTORY_ONE_TYPE(OpGather, FP32); break; case Op_SCATTER: @@ -387,6 +429,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpScatter, INT16); DEF_FACTORY_ONE_TYPE(OpScatter, INT32); DEF_FACTORY_ONE_TYPE(OpScatter, FP16); + DEF_FACTORY_ONE_TYPE(OpScatter, BF16); DEF_FACTORY_ONE_TYPE(OpScatter, FP32); break; @@ -397,6 +440,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16); DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); + DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16); DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32); break; @@ -405,6 +449,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, return new OpConst(sgt, id); case Op_IDENTITY: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); @@ -435,6 +480,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); -- cgit v1.2.1