From 24dbc420aae556649f50e645bd94489dab2cc75a Mon Sep 17 00:00:00 2001 From: James Ward Date: Wed, 19 Oct 2022 12:20:31 +0100 Subject: Add BF16 support to reference model * Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work- arounds for reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0) * Truncation to bfloat16 now performed in eval() methods Signed-off-by: James Ward Signed-off-by: Jeremy Johnson Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe --- reference_model/src/ops/activation_funcs.cc | 13 ++++++-- reference_model/src/ops/comparison.cc | 6 ++++ reference_model/src/ops/data_layout.cc | 7 ++++ reference_model/src/ops/data_nodes.cc | 1 + reference_model/src/ops/ewise_binary.cc | 21 +++++++++--- reference_model/src/ops/ewise_ternary.cc | 1 + reference_model/src/ops/ewise_unary.cc | 36 +++++++++++++++------ reference_model/src/ops/image.cc | 29 ++++++++++++----- reference_model/src/ops/op_factory.cc | 48 +++++++++++++++++++++++++++ reference_model/src/ops/op_factory.h | 6 ++++ reference_model/src/ops/reduction.cc | 50 ++++++++++++++++++++++++++--- reference_model/src/ops/scatter_gather.cc | 2 ++ reference_model/src/ops/template_types.h | 14 ++++---- reference_model/src/ops/tensor_ops.cc | 18 ++++++++--- reference_model/src/ops/type_conversion.cc | 3 ++ 15 files changed, 217 insertions(+), 38 deletions(-) (limited to 'reference_model/src/ops') diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index 61f7df6..46234e2 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -16,6 +16,7 @@ #include "activation_funcs.h" #include "quant_util.h" #include "template_types.h" +#include "arith_util.h" #include using namespace TosaReference; @@ -28,13 +29,14 @@ int OpClamp::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: { InEigenType min = (InEigenType)attribute->min_fp(); InEigenType max = (InEigenType)attribute->max_fp(); ERROR_IF(max < min, "OpClamp: max smaller than min"); - this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; }; + this->fcn = [min, max](InEigenType a) -> OutEigenType { return fpTrunc(a <= min ? min : a >= max ? max : a); }; } break; case DType_INT8: @@ -59,8 +61,9 @@ int OpSigmoid::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / (1.0 + (expf(-1.0 * a)))); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -75,8 +78,9 @@ int OpTanh::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(tanhf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -87,12 +91,15 @@ int OpTanh::register_fcn() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc index f240aa5..5b78a4f 100644 --- a/reference_model/src/ops/comparison.cc +++ b/reference_model/src/ops/comparison.cc @@ -28,6 +28,7 @@ int OpEqual::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; @@ -45,6 +46,7 @@ int OpGreater::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; @@ -62,6 +64,7 @@ int OpGreaterEqual::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; @@ -75,13 +78,16 @@ int OpGreaterEqual::register_fcn() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 69b6a65..bffd659 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -639,6 +639,7 @@ int OpTranspose::eval() // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) @@ -646,6 +647,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); @@ -653,6 +655,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); +DEF_INSTANTIATE_RESHAPE(OpReshape, BF16); DEF_INSTANTIATE_RESHAPE(OpReshape, FP32); DEF_INSTANTIATE_RESHAPE(OpReshape, INT8); DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); @@ -660,6 +663,7 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); @@ -667,6 +671,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); @@ -674,6 +679,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); @@ -681,6 +687,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index 5709a92..f5304a5 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -90,6 +90,7 @@ int OpIdentity::eval() // note OpConst is not templated DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 098b0ea..e4c0ee0 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -143,8 +143,9 @@ int OpAdd::register_fcn() }; break; case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a + b); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); @@ -371,6 +372,7 @@ int OpMaximum::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; @@ -388,6 +390,7 @@ int OpMinimum::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; @@ -407,8 +410,9 @@ int OpMul::register_fcn() switch (InDtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a * b); }; break; case DType_INT32: this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType { @@ -457,8 +461,9 @@ int OpPow::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(powf(a, b)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -482,8 +487,9 @@ int OpSub::register_fcn() }; break; case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a - b); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); @@ -581,6 +587,7 @@ int OpTable::eval() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); @@ -617,23 +624,28 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); @@ -643,5 +655,6 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); // Instantiation of nodes for comparison operators opEqual, opGreater // and opGreaterEqual DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index d85da1a..677a4e2 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -108,6 +108,7 @@ int OpSelect<0, Dtype>::eval() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 00897cc..5347b8c 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -78,11 +78,14 @@ int OpAbs::register_fcn() { switch (Dtype) { - case DType_FP32: - case DType_FP16: + case DType_FP32: // No fpTrunc for FP32 as it is a no-op case DType_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; + case DType_FP16: + case DType_BF16: + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(a > (InEigenType)0 ? a : (-a)); }; + break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } @@ -113,8 +116,9 @@ int OpCeil::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(ceilf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -161,8 +165,9 @@ int OpExp::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(expf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -177,8 +182,9 @@ int OpFloor::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(floorf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -193,8 +199,9 @@ int OpLog::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(logf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -245,10 +252,11 @@ int OpNegate::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); - return result; + return fpTrunc(result); }; break; case DType_INT16: @@ -297,8 +305,9 @@ int OpReciprocal::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / a); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -313,8 +322,9 @@ int OpRsqrt::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / sqrtf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -325,6 +335,7 @@ int OpRsqrt::register_fcn() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); @@ -333,29 +344,36 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index cf1d9f7..66efee0 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -63,7 +63,7 @@ int OpResize::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16) + if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -71,7 +71,7 @@ int OpResize::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16) + if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; @@ -159,15 +159,15 @@ int OpResize::eval() resize_t dy; resize_t dx; - if (std::is_floating_point::value) + if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) { - dy = fy - iy; - dx = fx - ix; + dy = (resize_t)(fy - iy); + dx = (resize_t)(fx - ix); } else { - dy = y - (iy * scale_y_n); - dx = x - (ix * scale_x_n); + dy = (resize_t)(y - (iy * scale_y_n)); + dx = (resize_t)(x - (ix * scale_x_n)); } int32_t iy0 = MAX(iy, 0); @@ -190,6 +190,15 @@ int OpResize::eval() acc += (OutEigenType)v10 * dy * (1.0 - dx); acc += (OutEigenType)v11 * dy * dx; } + else if ((typeid(resize_t) == typeid(Eigen::bfloat16))) + { + Eigen::bfloat16 bf16_acc; + bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx); + bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx; + bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx); + bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx; + acc = (float)bf16_acc; + } else { acc = (OutEigenType)v00 * (scale_y_n - dy) * (scale_x_n - dx); @@ -201,7 +210,7 @@ int OpResize::eval() else { ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode"); - if (std::is_floating_point::value) + if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) { iy = (dy >= 0.5) ? iy1 : iy0; ix = (dx >= 0.5) ? ix1 : ix0; @@ -213,6 +222,9 @@ int OpResize::eval() } acc = in->getTensor()(b, iy, ix, c); } + if ((typeid(resize_t) == typeid(Eigen::bfloat16))) { + ASSERT_MSG(checkValidBFloat(acc), "Resize accumulator float value is not a valid bfloat16 value."); + } out->getTensor()(b, oy, ox, c) = acc; } @@ -225,4 +237,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float); +DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16); DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 1ff8229..0121ccf 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -49,6 +49,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // tensor_ops case Op_ARGMAX: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); @@ -56,6 +57,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, BF16, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); @@ -63,6 +65,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32); @@ -71,6 +74,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CONV3D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32); @@ -79,6 +83,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_DEPTHWISE_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32); @@ -87,6 +92,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_FULLY_CONNECTED: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32); @@ -95,12 +101,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_MATMUL: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, BF16); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); @@ -108,6 +116,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_TRANSPOSE_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32); @@ -117,22 +126,26 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // activation_funcs case Op_CLAMP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); break; case Op_SIGMOID: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); break; case Op_TANH: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); break; // ewise_binary case Op_ADD: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); break; @@ -180,16 +193,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_MAXIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); break; case Op_MINIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); break; case Op_MUL: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); @@ -197,10 +213,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_POW: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); break; case Op_SUB: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); break; @@ -212,6 +230,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // ewise_unary case Op_ABS: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); break; @@ -222,6 +241,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_CEIL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); break; case Op_CLZ: @@ -229,14 +249,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_EXP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); break; case Op_FLOOR: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); break; case Op_LOG: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); break; case Op_LOGICAL_NOT: @@ -244,6 +267,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_NEGATE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); @@ -251,16 +275,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_RECIPROCAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); break; case Op_RSQRT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); break; // ewise_ternary case Op_SELECT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); @@ -271,16 +298,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // comparison case Op_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); break; case Op_GREATER: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); break; case Op_GREATER_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); break; @@ -294,6 +324,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_MAX: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); @@ -301,6 +332,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_MIN: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); @@ -308,10 +340,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_PRODUCT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); break; case Op_REDUCE_SUM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; @@ -319,6 +353,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // data layout case Op_CONCAT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); @@ -327,6 +362,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_PAD: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); @@ -335,6 +371,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); + DEF_FACTORY_RESHAPE(OpReshape, BF16); DEF_FACTORY_RESHAPE(OpReshape, FP32); DEF_FACTORY_RESHAPE(OpReshape, INT8); DEF_FACTORY_RESHAPE(OpReshape, INT16); @@ -343,6 +380,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); @@ -351,6 +389,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_SLICE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); @@ -359,6 +398,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_TILE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); @@ -368,6 +408,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_TRANSPOSE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); @@ -380,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, INT16); DEF_FACTORY_ONE_TYPE(OpGather, INT32); DEF_FACTORY_ONE_TYPE(OpGather, FP16); + DEF_FACTORY_ONE_TYPE(OpGather, BF16); DEF_FACTORY_ONE_TYPE(OpGather, FP32); break; case Op_SCATTER: @@ -387,6 +429,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpScatter, INT16); DEF_FACTORY_ONE_TYPE(OpScatter, INT32); DEF_FACTORY_ONE_TYPE(OpScatter, FP16); + DEF_FACTORY_ONE_TYPE(OpScatter, BF16); DEF_FACTORY_ONE_TYPE(OpScatter, FP32); break; @@ -397,6 +440,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16); DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); + DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16); DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32); break; @@ -405,6 +449,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, return new OpConst(sgt, id); case Op_IDENTITY: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); @@ -435,6 +480,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index b525e69..f399bd1 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -111,6 +111,12 @@ return new OP(sgt, attribute, id); \ } +#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + { \ + return new OP(sgt, attribute, id); \ + } + #define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index eccba09..cd9d55f 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -80,10 +80,30 @@ int ReduceNode::checkTensorAttributes() return 0; } +// These 2 reducers are to overcome a bug introduced in Eigen between 3.3.7 and 3.4.0 +// The in-built .any and .all operations now fail on an assert in TensorMorphing.h:150 +// which seems to be due to incorrect data being passed internally as m_impl +struct AllReducer { + static const bool PacketAccess = false; + void reduce(const bool val, bool* accum) { + *accum = *accum && val; + } + bool initialize() const { return true; } + bool finalize(const bool accum) const { return accum; } +}; +struct AnyReducer { + static const bool PacketAccess = false; + void reduce(const bool val, bool* accum) { + *accum = *accum || val; + } + bool initialize() const { return false; } + bool finalize(const bool accum) const { return accum; } +}; + template int OpReduceAll::eval() { - this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions()); + this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions()); return GraphNode::eval(); } @@ -91,7 +111,7 @@ int OpReduceAll::eval() template int OpReduceAny::eval() { - this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions()); + this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions()); return GraphNode::eval(); } @@ -115,7 +135,16 @@ int OpReduceMin::eval() template int OpReduceProduct::eval() { - this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); + switch(Dtype) + { + case DType_FP16: + case DType_BF16: + this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc(f);}); + break; + default: + this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); + break; + } return GraphNode::eval(); } @@ -123,7 +152,16 @@ int OpReduceProduct::eval() template int OpReduceSum::eval() { - this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); + switch(Dtype) + { + case DType_FP16: + case DType_BF16: + this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc(f);}); + break; + default: + this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); + break; + } return GraphNode::eval(); } @@ -159,20 +197,24 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index b6c4043..bcd8ce5 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -227,10 +227,12 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8); DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16); DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); +DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 3de4899..647ca84 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -19,6 +19,8 @@ #include "tosa_generated.h" #include #include "half.hpp" +#include +#include "arith_util.h" using namespace tosa; @@ -76,6 +78,12 @@ struct GetEigenType using type = float; }; template <> +struct GetEigenType +{ + // NOTE: full precision used + using type = float; +}; +template <> struct GetEigenType { using type = int32_t; @@ -132,12 +140,6 @@ struct GetAccEigenType using type = typename GetEigenType::type; }; -template -struct GetHalfEigenType -{ - using type = half_float::half; -}; - // Meta function to get number of bits template struct GetNumBits diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 7db5182..b9ac94a 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -507,12 +507,13 @@ int OpAvgPool2d::eval() Eigen::array, 1> contract_dims = { Eigen::IndexPair(1, 0) }; Eigen::array bcast{ out_batch, 1, 1, out_channels }; + ETensor2 dm2_w = div_map_w.reshape(Eigen::array{ 1, out_width }); + ETensor2 dm2_h = div_map_h.reshape(Eigen::array{ out_height, 1 }); ETensor4 div_map = - div_map_h.reshape(Eigen::array{ out_height, 1 }) - .contract(div_map_w.reshape(Eigen::array{ 1, out_width }), contract_dims) + dm2_h.contract(dm2_w, contract_dims) .reshape(Eigen::array{ 1, out_height, out_width, 1 }) .broadcast(bcast); - if (Dtype != DType_FP32 && Dtype != DType_FP16) + if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16) { try { @@ -533,7 +534,7 @@ int OpAvgPool2d::eval() } else { - // Case for float-type resizes + // Case for float-types this->out->getTensor() = (sum / div_map.template cast()).template cast(); } @@ -1679,12 +1680,14 @@ int OpTransposeConv2d::eval() // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, BF16, FP32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32); @@ -1692,6 +1695,7 @@ DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32); // [in_t, weight_t, acc_t] DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32); @@ -1699,6 +1703,7 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32); @@ -1706,6 +1711,7 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32); @@ -1713,6 +1719,7 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32); @@ -1722,15 +1729,18 @@ DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, BF16, FP32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index f51c38c..e30c7bd 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -353,6 +353,9 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); -- cgit v1.2.1