From 24dbc420aae556649f50e645bd94489dab2cc75a Mon Sep 17 00:00:00 2001
From: James Ward
Date: Wed, 19 Oct 2022 12:20:31 +0100
Subject: Add BF16 support to reference model

* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add work-arounds
  for reduce.any() and reduce.all() bugs (introduced between 3.3.7
  and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods

Signed-off-by: James Ward
Signed-off-by: Jeremy Johnson
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
---
 reference_model/include/func_config.h       |  1 +
 reference_model/src/arith_util.h            | 89 +++++++++++++++++++++++++++++
 reference_model/src/main.cpp                |  3 +
 reference_model/src/ops/activation_funcs.cc | 13 ++++-
 reference_model/src/ops/comparison.cc       |  6 ++
 reference_model/src/ops/data_layout.cc      |  7 +++
 reference_model/src/ops/data_nodes.cc       |  1 +
 reference_model/src/ops/ewise_binary.cc     | 21 +++++--
 reference_model/src/ops/ewise_ternary.cc    |  1 +
 reference_model/src/ops/ewise_unary.cc      | 36 +++++++++---
 reference_model/src/ops/image.cc            | 29 +++++++---
 reference_model/src/ops/op_factory.cc       | 48 ++++++++++++++++
 reference_model/src/ops/op_factory.h        |  6 ++
 reference_model/src/ops/reduction.cc        | 50 ++++++++++++++--
 reference_model/src/ops/scatter_gather.cc   |  2 +
 reference_model/src/ops/template_types.h    | 14 +++--
 reference_model/src/ops/tensor_ops.cc       | 18 ++++--
 reference_model/src/ops/type_conversion.cc  |  3 +
 reference_model/src/subgraph_traverser.cc   | 11 ++++
 reference_model/src/tensor.cc               | 36 +++++++++---
 reference_model/src/tensor.h                |  1 +
 thirdparty/eigen                            |  2 +-
 thirdparty/serialization_lib                |  2 +-
 verif/checker/tosa_result_checker.py        | 22 ++++++-
 verif/generator/tosa_arg_gen.py             | 10 +++-
 verif/generator/tosa_error_if.py            | 35 ++++++++++--
 verif/generator/tosa_test_gen.py            | 80 ++++++++++++++++++++++----
 verif/generator/tosa_utils.py               | 45 ++++++++++++++-
 verif/generator/tosa_verif_build_tests.py   |  4 +-
 verif/tests/test_tosa_refmodel.py           | 16 +++++-
 30 files changed, 544 insertions(+), 68 deletions(-)

diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h
index 41df135..d9b51d5 100644
--- a/reference_model/include/func_config.h
+++ b/reference_model/include/func_config.h
@@ -36,6 +36,7 @@ struct func_config_t
     uint32_t tosa_profile = 1;
     uint32_t dump_intermediates = 0;
     std::string fp_format = "0.5";
+    bool float_is_big_endian = false;    // Set in main.cpp via float_is_big_endian() from arith_util.h
 };

 #endif
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index 554a7a2..33bdeed 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -31,13 +31,18 @@
 #include
 #define __STDC_LIMIT_MACROS //enable min/max of plain data type
 #include "func_debug.h"
+#include "func_config.h"
 #include "inttypes.h"
+#include "tosa_generated.h"
 #include
 #include
 #include
 #include
 #include
+#include
+#include
+using namespace tosa;
 using namespace std;

 inline size_t _count_one(uint64_t val)
@@ -191,4 +196,88 @@ constexpr T saturate(const uint32_t width, const intmax_t value)
     // clang-format on
 }

+inline void float_trunc_bytes(float* src)
+{
+    /* Set the least significant two bytes to zero for the input float value. */
+    char src_as_bytes[sizeof(float)];
+    memcpy(src_as_bytes, src, sizeof(float));
+
+    if (g_func_config.float_is_big_endian)
+    {
+        src_as_bytes[2] = '\000';
+        src_as_bytes[3] = '\000';
+    }
+    else
+    {
+        src_as_bytes[0] = '\000';
+        src_as_bytes[1] = '\000';
+    }
+
+    memcpy(src, src_as_bytes, sizeof(float));
+}
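+
+// Example: 3.14159274f (0x40490FDB) truncates to 3.140625f (0x40490000);
+// the low 16 bits of the significand are simply dropped. Note this
+// truncates toward zero rather than rounding to nearest-even, so results
+// can differ from a round-to-nearest bfloat16 cast by one unit in the
+// last place.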
+
+inline void truncateFloatToBFloat(float* src, int64_t size)
+{
+    /* Set the least significant two bytes to zero for each float
+       value in the input src buffer. */
+    ASSERT_MEM(src);
+    ASSERT_MSG(size > 0, "Size of src (representing number of values in src) must be a positive integer.");
+    for (; size != 0; src++, size--)
+    {
+        float_trunc_bytes(src);
+    }
+}
+
+inline bool checkValidBFloat(float src)
+{
+    /* Checks if the least significant two bytes are zero. */
+    char src_as_bytes[sizeof(float)];
+    memcpy(src_as_bytes, &src, sizeof(float));
+
+    if (g_func_config.float_is_big_endian)
+    {
+        return (src_as_bytes[2] == '\000' && src_as_bytes[3] == '\000');
+    }
+    else
+    {
+        return (src_as_bytes[0] == '\000' && src_as_bytes[1] == '\000');
+    }
+}
+
+inline bool float_is_big_endian()
+{
+    /* Compares float values 1.0 and -1.0 by checking whether the
+       negation causes the first or the last byte to change.
+       The first byte changing indicates that the float
+       representation is big-endian. */
+    float f = 1.0;
+    char f_as_bytes[sizeof(float)];
+    memcpy(f_as_bytes, &f, sizeof(float));
+    f = -f;
+    char f_neg_as_bytes[sizeof(float)];
+    memcpy(f_neg_as_bytes, &f, sizeof(float));
+    return f_as_bytes[0] != f_neg_as_bytes[0];
+}
+
+template <DType Dtype>
+float fpTrunc(float f_in)
+{
+    /* Truncates a float value to the precision of the DType it represents. */
+    switch (Dtype)
+    {
+        case DType_BF16:
+            truncateFloatToBFloat(&f_in, 1);
+            break;
+        case DType_FP16:
+            // TODO(jw): implement FP16 truncate function (no-op placeholder for now)
+            break;
+        case DType_FP32:
+            // No-op for fp32
+            break;
+        default:
+            ASSERT_MSG(false, "DType %s should not be float-truncated.", EnumNameDType(Dtype));
+    }
+    return f_in;
+}
+
 #endif /* _ARITH_UTIL_H */
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 776fbf3..5c2735d 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -20,6 +20,7 @@
 #include "ops/op_factory.h"
 #include "subgraph_traverser.h"
 #include "tosa_serialization_handler.h"
+#include "arith_util.h"
 #include
 #include
@@ -67,6 +68,8 @@ int main(int argc, char** argv)
         return TOSA_VERSION_MISMATCH;
     }

+    g_func_config.float_is_big_endian = float_is_big_endian();
+
     json test_desc;

     // Initialize test descriptor
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 61f7df6..46234e2 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -16,6 +16,7 @@
 #include "activation_funcs.h"
 #include "quant_util.h"
 #include "template_types.h"
+#include "arith_util.h"
 #include

 using namespace TosaReference;
@@ -28,13 +29,14 @@ int OpClamp<Rank, Dtype>::register_fcn()
     switch (Dtype)
     {
         case DType_FP16:
+        case DType_BF16:
         case DType_FP32:
         {
             InEigenType min = (InEigenType)attribute->min_fp();
             InEigenType max = (InEigenType)attribute->max_fp();
             ERROR_IF(max < min, "OpClamp: max smaller than min");

-            this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+            this->fcn = [min, max](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a <= min ? min : a >= max ?
max : a); }; } break; case DType_INT8: @@ -59,8 +61,9 @@ int OpSigmoid::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / (1.0 + (expf(-1.0 * a)))); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -75,8 +78,9 @@ int OpTanh::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(tanhf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -87,12 +91,15 @@ int OpTanh::register_fcn() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc index f240aa5..5b78a4f 100644 --- a/reference_model/src/ops/comparison.cc +++ b/reference_model/src/ops/comparison.cc @@ -28,6 +28,7 @@ int OpEqual::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; @@ -45,6 +46,7 @@ int OpGreater::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; @@ -62,6 +64,7 @@ int OpGreaterEqual::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; @@ -75,13 +78,16 @@ int OpGreaterEqual::register_fcn() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index 69b6a65..bffd659 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -639,6 +639,7 @@ int OpTranspose::eval() // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16) 
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) @@ -646,6 +647,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); @@ -653,6 +655,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); +DEF_INSTANTIATE_RESHAPE(OpReshape, BF16); DEF_INSTANTIATE_RESHAPE(OpReshape, FP32); DEF_INSTANTIATE_RESHAPE(OpReshape, INT8); DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); @@ -660,6 +663,7 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); @@ -667,6 +671,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); @@ -674,6 +679,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); @@ -681,6 +687,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index 5709a92..f5304a5 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -90,6 +90,7 @@ int OpIdentity::eval() // note OpConst is not templated DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 098b0ea..e4c0ee0 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -143,8 +143,9 @@ int 
OpAdd::register_fcn() }; break; case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a + b); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); @@ -371,6 +372,7 @@ int OpMaximum::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; @@ -388,6 +390,7 @@ int OpMinimum::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: case DType_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; }; @@ -407,8 +410,9 @@ int OpMul::register_fcn() switch (InDtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a * b); }; break; case DType_INT32: this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType { @@ -457,8 +461,9 @@ int OpPow::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(powf(a, b)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -482,8 +487,9 @@ int OpSub::register_fcn() }; break; case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc(a - b); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); @@ -581,6 +587,7 @@ int OpTable::eval() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); @@ -617,23 +624,28 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); 
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); @@ -643,5 +655,6 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); // Instantiation of nodes for comparison operators opEqual, opGreater // and opGreaterEqual DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index d85da1a..677a4e2 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -108,6 +108,7 @@ int OpSelect<0, Dtype>::eval() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 00897cc..5347b8c 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -78,11 +78,14 @@ int OpAbs::register_fcn() { switch (Dtype) { - case DType_FP32: - case DType_FP16: + case DType_FP32: // No fpTrunc for FP32 as it is a no-op case DType_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; + case DType_FP16: + case DType_BF16: + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(a > (InEigenType)0 ? a : (-a)); }; + break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); } @@ -113,8 +116,9 @@ int OpCeil::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(ceilf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -161,8 +165,9 @@ int OpExp::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(expf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -177,8 +182,9 @@ int OpFloor::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(floorf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -193,8 +199,9 @@ int OpLog::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(logf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -245,10 +252,11 @@ int OpNegate::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); - return result; + return fpTrunc(result); }; break; case DType_INT16: @@ -297,8 +305,9 @@ int 
OpReciprocal::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / a); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -313,8 +322,9 @@ int OpRsqrt::register_fcn() switch (Dtype) { case DType_FP16: + case DType_BF16: case DType_FP32: - this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; + this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc(1.0 / sqrtf(a)); }; break; default: ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); @@ -325,6 +335,7 @@ int OpRsqrt::register_fcn() // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); @@ -333,29 +344,36 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index cf1d9f7..66efee0 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -63,7 +63,7 @@ int OpResize::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16) + if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -71,7 +71,7 @@ int OpResize::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16) + if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype 
!= DType_FP16 && OutDtype != DType_BF16) { printNodeValidationError("OpResize: invalid data type for NEAREST"); return 1; @@ -159,15 +159,15 @@ int OpResize::eval() resize_t dy; resize_t dx; - if (std::is_floating_point::value) + if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) { - dy = fy - iy; - dx = fx - ix; + dy = (resize_t)(fy - iy); + dx = (resize_t)(fx - ix); } else { - dy = y - (iy * scale_y_n); - dx = x - (ix * scale_x_n); + dy = (resize_t)(y - (iy * scale_y_n)); + dx = (resize_t)(x - (ix * scale_x_n)); } int32_t iy0 = MAX(iy, 0); @@ -190,6 +190,15 @@ int OpResize::eval() acc += (OutEigenType)v10 * dy * (1.0 - dx); acc += (OutEigenType)v11 * dy * dx; } + else if ((typeid(resize_t) == typeid(Eigen::bfloat16))) + { + Eigen::bfloat16 bf16_acc; + bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx); + bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx; + bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx); + bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx; + acc = (float)bf16_acc; + } else { acc = (OutEigenType)v00 * (scale_y_n - dy) * (scale_x_n - dx); @@ -201,7 +210,7 @@ int OpResize::eval() else { ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode"); - if (std::is_floating_point::value) + if (std::is_floating_point::value || (typeid(resize_t) == typeid(Eigen::bfloat16))) { iy = (dy >= 0.5) ? iy1 : iy0; ix = (dx >= 0.5) ? ix1 : ix0; @@ -213,6 +222,9 @@ int OpResize::eval() } acc = in->getTensor()(b, iy, ix, c); } + if ((typeid(resize_t) == typeid(Eigen::bfloat16))) { + ASSERT_MSG(checkValidBFloat(acc), "Resize accumulator float value is not a valid bfloat16 value."); + } out->getTensor()(b, oy, ox, c) = acc; } @@ -225,4 +237,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t); DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float); +DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16); DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 1ff8229..0121ccf 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -49,6 +49,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // tensor_ops case Op_ARGMAX: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); @@ -56,6 +57,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, BF16, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); @@ -63,6 +65,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, 
FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32); @@ -71,6 +74,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_CONV3D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32); @@ -79,6 +83,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_DEPTHWISE_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32); @@ -87,6 +92,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_FULLY_CONNECTED: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32); @@ -95,12 +101,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_MATMUL: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, BF16); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); @@ -108,6 +116,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_TRANSPOSE_CONV2D: DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32); + DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32); DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32); @@ -117,22 +126,26 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // activation_funcs case Op_CLAMP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); break; case 
Op_SIGMOID: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); break; case Op_TANH: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); break; // ewise_binary case Op_ADD: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); break; @@ -180,16 +193,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_MAXIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); break; case Op_MINIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); break; case Op_MUL: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); @@ -197,10 +213,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_POW: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); break; case Op_SUB: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); break; @@ -212,6 +230,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // ewise_unary case Op_ABS: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); break; @@ -222,6 +241,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_CEIL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); break; case Op_CLZ: @@ -229,14 +249,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_EXP: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); break; case Op_FLOOR: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); break; case Op_LOG: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); break; case Op_LOGICAL_NOT: @@ -244,6 +267,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_NEGATE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); @@ -251,16 +275,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_RECIPROCAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); break; case Op_RSQRT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); break; // ewise_ternary case Op_SELECT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); @@ -271,16 +298,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // comparison case Op_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); break; case Op_GREATER: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); break; case Op_GREATER_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); break; @@ -294,6 +324,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_MAX: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); @@ -301,6 +332,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_MIN: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); @@ -308,10 +340,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REDUCE_PRODUCT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); break; case Op_REDUCE_SUM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; @@ -319,6 +353,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, // data layout case Op_CONCAT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); @@ -327,6 +362,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_PAD: 
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); @@ -335,6 +371,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); + DEF_FACTORY_RESHAPE(OpReshape, BF16); DEF_FACTORY_RESHAPE(OpReshape, FP32); DEF_FACTORY_RESHAPE(OpReshape, INT8); DEF_FACTORY_RESHAPE(OpReshape, INT16); @@ -343,6 +380,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); @@ -351,6 +389,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_SLICE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); @@ -359,6 +398,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, break; case Op_TILE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); @@ -368,6 +408,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, case Op_TRANSPOSE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); @@ -380,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, INT16); DEF_FACTORY_ONE_TYPE(OpGather, INT32); DEF_FACTORY_ONE_TYPE(OpGather, FP16); + DEF_FACTORY_ONE_TYPE(OpGather, BF16); DEF_FACTORY_ONE_TYPE(OpGather, FP32); break; case Op_SCATTER: @@ -387,6 +429,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpScatter, INT16); DEF_FACTORY_ONE_TYPE(OpScatter, INT32); DEF_FACTORY_ONE_TYPE(OpScatter, FP16); + DEF_FACTORY_ONE_TYPE(OpScatter, BF16); DEF_FACTORY_ONE_TYPE(OpScatter, FP32); break; @@ -397,6 +440,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48); DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16); DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); + DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16); DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32); break; @@ -405,6 +449,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, return new OpConst(sgt, id); case Op_IDENTITY: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); @@ -435,6 +480,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, 
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index b525e69..f399bd1 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -111,6 +111,12 @@ return new OP(sgt, attribute, id); \ } +#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + { \ + return new OP(sgt, attribute, id); \ + } + #define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ { \ diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index eccba09..cd9d55f 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -80,10 +80,30 @@ int ReduceNode::checkTensorAttributes() return 0; } +// These 2 reducers are to overcome a bug introduced in Eigen between 3.3.7 and 3.4.0 +// The in-built .any and .all operations now fail on an assert in TensorMorphing.h:150 +// which seems to be due to incorrect data being passed internally as m_impl +struct AllReducer { + static const bool PacketAccess = false; + void reduce(const bool val, bool* accum) { + *accum = *accum && val; + } + bool initialize() const { return true; } + bool finalize(const bool accum) const { return accum; } +}; +struct AnyReducer { + static const bool PacketAccess = false; + void reduce(const bool val, bool* accum) { + *accum = *accum || val; + } + bool initialize() const { return false; } + bool finalize(const bool accum) const { return accum; } +}; + template int OpReduceAll::eval() { - this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions()); + this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions()); return GraphNode::eval(); } @@ -91,7 +111,7 @@ int OpReduceAll::eval() template int OpReduceAny::eval() { - this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions()); + this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions()); return GraphNode::eval(); } @@ -115,7 +135,16 @@ int OpReduceMin::eval() template int OpReduceProduct::eval() { - this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); + switch(Dtype) + { + case DType_FP16: + case DType_BF16: + this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc(f);}); + break; + default: + this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); + break; + } return GraphNode::eval(); } @@ -123,7 +152,16 @@ int OpReduceProduct::eval() template int OpReduceSum::eval() { - this->out->getTensor() = 
this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); + switch(Dtype) + { + case DType_FP16: + case DType_BF16: + this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc(f);}); + break; + default: + this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); + break; + } return GraphNode::eval(); } @@ -159,20 +197,24 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index b6c4043..bcd8ce5 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -227,10 +227,12 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8); DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16); DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); +DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 3de4899..647ca84 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -19,6 +19,8 @@ #include "tosa_generated.h" #include #include "half.hpp" +#include +#include "arith_util.h" using namespace tosa; @@ -76,6 +78,12 @@ struct GetEigenType using type = float; }; template <> +struct GetEigenType +{ + // NOTE: full precision used + using type = float; +}; +template <> struct GetEigenType { using type = int32_t; @@ -132,12 +140,6 @@ struct GetAccEigenType using type = typename GetEigenType::type; }; -template -struct GetHalfEigenType -{ - using type = half_float::half; -}; - // Meta function to get number of bits template struct GetNumBits diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 7db5182..b9ac94a 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ 
b/reference_model/src/ops/tensor_ops.cc @@ -507,12 +507,13 @@ int OpAvgPool2d::eval() Eigen::array, 1> contract_dims = { Eigen::IndexPair(1, 0) }; Eigen::array bcast{ out_batch, 1, 1, out_channels }; + ETensor2 dm2_w = div_map_w.reshape(Eigen::array{ 1, out_width }); + ETensor2 dm2_h = div_map_h.reshape(Eigen::array{ out_height, 1 }); ETensor4 div_map = - div_map_h.reshape(Eigen::array{ out_height, 1 }) - .contract(div_map_w.reshape(Eigen::array{ 1, out_width }), contract_dims) + dm2_h.contract(dm2_w, contract_dims) .reshape(Eigen::array{ 1, out_height, out_width, 1 }) .broadcast(bcast); - if (Dtype != DType_FP32 && Dtype != DType_FP16) + if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16) { try { @@ -533,7 +534,7 @@ int OpAvgPool2d::eval() } else { - // Case for float-type resizes + // Case for float-types this->out->getTensor() = (sum / div_map.template cast()).template cast(); } @@ -1679,12 +1680,14 @@ int OpTransposeConv2d::eval() // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, BF16, FP32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32); @@ -1692,6 +1695,7 @@ DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32); // [in_t, weight_t, acc_t] DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32); @@ -1699,6 +1703,7 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32); @@ -1706,6 +1711,7 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32); @@ -1713,6 +1719,7 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, 
FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32); @@ -1722,15 +1729,18 @@ DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32); +DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, BF16, FP32); DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index f51c38c..e30c7bd 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -353,6 +353,9 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index ae216d8..112e641 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -15,6 +15,7 @@ #include "subgraph_traverser.h" #include "tosa_model_types.h" +#include "arith_util.h" #ifndef SUBGRAPH_ERROR_IF #define SUBGRAPH_ERROR_IF(COND, fmt, ...) 
\ @@ -403,6 +404,16 @@ int SubgraphTraverser::allocateTensor() tensor->setTensorValueFloat(f16_data.size(), f16_data.data()); } break; + case DType_BF16: + { + std::vector fp32_data; + TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + // Ensure valid bfloat16 stored in each float + for (auto f : fp32_data) + ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f); + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } + break; case DType_FP32: { std::vector fp32_data; diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 8d192ca..4eaf21d 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -90,10 +90,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) int64_t* i64databuf = nullptr; bool* bdatabuf = nullptr; NumpyUtilities::NPError nperror; + DType dtype = getDtype(); - switch (getDtype()) + switch (dtype) { case DType_FP32: + case DType_BF16: fdatabuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(fdatabuf); @@ -154,19 +156,38 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) FATAL_ERROR("Unknown error parsing Numpy file: %s", filename); } - switch (getDtype()) + switch (dtype) { case DType_FP16: // Convert from fp16 to fp32 + //TODO(jw): remove this once we cast to fp16 in register_fcn/eval for (uint32_t i=0; i < elements; i++) { fdatabuf[i] = half_float::half_cast(f16databuf[i]); } - // Fall through to DType_FP32 case + if (setTensorValueFloat(elements, fdatabuf)) + { + free(f16databuf); + free(fdatabuf); + return 1; + } + break; + case DType_BF16: + for (uint32_t i=0; i < elements; i++) + { + ASSERT_MSG( + checkValidBFloat(fdatabuf[i]), + "Input float value not a valid bfloat16 value." 
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 8d192ca..4eaf21d 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -90,10 +90,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
     int64_t* i64databuf = nullptr;
     bool* bdatabuf = nullptr;
     NumpyUtilities::NPError nperror;
+    DType dtype = getDtype();
 
-    switch (getDtype())
+    switch (dtype)
     {
         case DType_FP32:
+        case DType_BF16:
             fdatabuf = (float*)calloc(sizeof(float), elements);
             ASSERT_MEM(fdatabuf);
 
@@ -154,19 +156,38 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
             FATAL_ERROR("Unknown error parsing Numpy file: %s", filename);
     }
 
-    switch (getDtype())
+    switch (dtype)
     {
         case DType_FP16:
             // Convert from fp16 to fp32
+            //TODO(jw): remove this once we cast to fp16 in register_fcn/eval
            for (uint32_t i=0; i < elements; i++)
             {
                 fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
             }
-            // Fall through to DType_FP32 case
+            if (setTensorValueFloat(elements, fdatabuf))
+            {
+                free(f16databuf);
+                free(fdatabuf);
+                return 1;
+            }
+            break;
+        case DType_BF16:
+            for (uint32_t i=0; i < elements; i++)
+            {
+                ASSERT_MSG(
+                    checkValidBFloat(fdatabuf[i]),
+                    "Input float value not a valid bfloat16 value."
+                );
+            }
+            if (setTensorValueFloat(elements, fdatabuf))
+            {
+                free(fdatabuf);
+                return 1;
+            }
+            break;
         case DType_FP32:
             if (setTensorValueFloat(elements, fdatabuf))
             {
-                if (f16databuf)
-                    free(f16databuf);
                 free(fdatabuf);
                 return 1;
             }
@@ -226,10 +247,12 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
     bool* bdatabuf = nullptr;
     NumpyUtilities::NPError nperror;
     uint32_t elements = getElementCount();
+    DType dtype = getDtype();
 
-    switch (getDtype())
+    switch (dtype)
    {
         case DType_FP32:
+        case DType_BF16:
             fdatabuf = (float*)calloc(sizeof(float), elements);
             ASSERT_MEM(fdatabuf);
 
@@ -238,7 +261,6 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
                 free(fdatabuf);
                 return 1;
             }
-
             nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf);
 
             free(fdatabuf);
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 4efbf84..a3ce4bb 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -646,6 +646,7 @@ public:
         {
             case DType_FP32:
             case DType_FP16:
+            case DType_BF16:
                 switch (rank)
                 {
                     case 0:
diff --git a/thirdparty/eigen b/thirdparty/eigen
index 21ae2af..3147391 160000
--- a/thirdparty/eigen
+++ b/thirdparty/eigen
@@ -1 +1 @@
-Subproject commit 21ae2afd4edaa1b69782c67a54182d34efe43f9c
+Subproject commit 3147391d946bb4b6c68edd901f2add6ac1f31f8c
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
index e1072a9..34a6279 160000
--- a/thirdparty/serialization_lib
+++ b/thirdparty/serialization_lib
@@ -1 +1 @@
-Subproject commit e1072a9ed871fd474e7b09b7a74ae7be5f0a6f78
+Subproject commit 34a627959a61b4eccbeea4400cf9684debb331dc
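NumPy has no native bfloat16 dtype, so BF16 tensors share the DType_FP32 .npy
path above: the file holds plain float32 values whose low 16 bits are zero. A
minimal sketch of producing such an input file (the filename is hypothetical):

    import numpy as np

    data = np.random.rand(2, 3).astype(np.float32)
    # Vectorised flooring: clear the low half of each float32 bit pattern.
    data = (data.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)
    np.save("input_bf16.npy", data)

    # readFromNpyFile() parses this as float32, then asserts every value
    # passes the checkValidBFloat() test before storing it in the tensor.
    assert ((data.view(np.uint32) & 0xFFFF) == 0).all()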
" + "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}" + ) + return (TestResult.INCORRECT_FORMAT, 0.0, msg) + # for quantized test, allow +-(quantize_tolerance) error if reference_result.dtype == np.int32 or reference_result.dtype == np.int64: diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 0203513..932ad55 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -776,7 +776,7 @@ class TosaTensorValuesGen: ), "Op.MUL must have 2 placeholders, 0 consts" tens = [] - if dtypeList[0] in (DType.FP16, DType.FP32): + if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32): tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:])) else: placeholders = [] @@ -1130,6 +1130,8 @@ class TosaArgGen: accum_dtypes = [DType.INT48] elif dtype == DType.FP16: accum_dtypes = [DType.FP16, DType.FP32] + elif dtype == DType.BF16: + accum_dtypes = [DType.FP32] elif dtype == DType.FP32: accum_dtypes = [DType.FP32] elif error_name is None: @@ -1304,7 +1306,7 @@ class TosaArgGen: accum_dtypes = [DType.INT32] elif dtype == DType.FP16: accum_dtypes = [DType.FP16, DType.FP32] - elif dtype == DType.FP32: + elif dtype == DType.BF16 or dtype == DType.FP32: accum_dtypes = [DType.FP32] elif error_name is None: assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" @@ -1417,6 +1419,8 @@ class TosaArgGen: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP16: dtypeList = [DType.INT8, DType.INT16, DType.INT32] + elif inDtype == DType.BF16: + dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP32: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif error_name == ErrorIf.WrongInputType: @@ -1826,6 +1830,8 @@ class TosaArgGen: outputDTypeList = [DType.INT48] elif dtype == DType.FP16: outputDTypeList = [DType.FP16] + elif dtype == DType.BF16: + outputDTypeList = [DType.BF16] elif dtype == DType.FP32: outputDTypeList = [DType.FP32] elif error_name == ErrorIf.WrongInputType: diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index abe1a97..a850699 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -158,6 +158,15 @@ class TosaErrorIfArgGen: DType.INT48, DType.FP32, ) + elif dtype == DType.BF16: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + ) elif dtype == DType.FP32: incorrect_types = ( DType.INT4, @@ -299,8 +308,8 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]: - outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32] + if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]: + outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] else: @@ -425,6 +434,7 @@ class TosaErrorValidator: and output_dtype != DType.INT48 ) or (input_dtype == DType.FP16 and output_dtype != DType.FP16) + or (input_dtype == DType.BF16 and output_dtype != DType.BF16) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True @@ -442,25 +452,29 @@ class TosaErrorValidator: input_dtype == DType.FP16 and output_dtype not in (DType.FP16, DType.FP32) ) + or (input_dtype == DType.BF16 and output_dtype != DType.FP32) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) ): error_result = True elif op["op"] == Op.ARGMAX: if ( - 
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 0203513..932ad55 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -776,7 +776,7 @@ class TosaTensorValuesGen:
         ), "Op.MUL must have 2 placeholders, 0 consts"
 
         tens = []
-        if dtypeList[0] in (DType.FP16, DType.FP32):
+        if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
             tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
         else:
             placeholders = []
@@ -1130,6 +1130,8 @@
             accum_dtypes = [DType.INT48]
         elif dtype == DType.FP16:
             accum_dtypes = [DType.FP16, DType.FP32]
+        elif dtype == DType.BF16:
+            accum_dtypes = [DType.FP32]
         elif dtype == DType.FP32:
             accum_dtypes = [DType.FP32]
         elif error_name is None:
@@ -1304,7 +1306,7 @@
             accum_dtypes = [DType.INT32]
         elif dtype == DType.FP16:
             accum_dtypes = [DType.FP16, DType.FP32]
-        elif dtype == DType.FP32:
+        elif dtype == DType.BF16 or dtype == DType.FP32:
             accum_dtypes = [DType.FP32]
         elif error_name is None:
             assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
@@ -1417,6 +1419,8 @@
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif inDtype == DType.FP16:
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+        elif inDtype == DType.BF16:
+            dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif inDtype == DType.FP32:
             dtypeList = [DType.INT8, DType.INT16, DType.INT32]
         elif error_name == ErrorIf.WrongInputType:
@@ -1826,6 +1830,8 @@
             outputDTypeList = [DType.INT48]
         elif dtype == DType.FP16:
             outputDTypeList = [DType.FP16]
+        elif dtype == DType.BF16:
+            outputDTypeList = [DType.BF16]
         elif dtype == DType.FP32:
             outputDTypeList = [DType.FP32]
         elif error_name == ErrorIf.WrongInputType:
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index abe1a97..a850699 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -158,6 +158,15 @@ class TosaErrorIfArgGen:
                 DType.INT48,
                 DType.FP32,
             )
+        elif dtype == DType.BF16:
+            incorrect_types = (
+                DType.INT4,
+                DType.INT8,
+                DType.INT16,
+                DType.INT32,
+                DType.INT48,
+                DType.FP32,
+            )
         elif dtype == DType.FP32:
             incorrect_types = (
                 DType.INT4,
@@ -299,8 +308,8 @@ class TosaErrorIfArgGen:
 
     @staticmethod
     def eiCastErrorIf(testGen, input_dtype):
-        if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]:
-            outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32]
+        if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
+            outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
         elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
             outputDType = [DType.INT48]
         else:
@@ -425,6 +434,7 @@
                     and output_dtype != DType.INT48
                 )
                 or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
+                or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
                 or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
             ):
                 error_result = True
@@ -442,25 +452,29 @@
                     input_dtype == DType.FP16
                     and output_dtype not in (DType.FP16, DType.FP32)
                 )
+                or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
                 or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
             ):
                 error_result = True
 
         elif op["op"] == Op.ARGMAX:
             if (
-                input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+                input_dtype
+                in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
                 and output_dtype != DType.INT32
             ):
                 error_result = True
 
         elif op["op"] == Op.MUL:
             if (
-                input_dtype not in (DType.FP16, DType.FP32)
+                input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
                 and output_dtype != DType.INT32
             ):
                 error_result = True
             elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
                 error_result = True
+            elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
+                error_result = True
             elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
                 error_result = True
 
@@ -489,6 +503,7 @@
                     DType.INT32,
                     DType.FP32,
                     DType.FP16,
+                    DType.BF16,
                 ]
             )
             or (
@@ -500,6 +515,7 @@
                     DType.INT32,
                     DType.FP32,
                     DType.FP16,
+                    DType.BF16,
                 ]
             )
             or (
@@ -511,12 +527,17 @@
                     DType.INT16,
                     DType.FP32,
                     DType.FP16,
+                    DType.BF16,
                 ]
             )
             or (
                 input_dtype == DType.FP16
                 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
             )
+            or (
+                input_dtype == DType.BF16
+                and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+            )
             or (
                 input_dtype == DType.FP32
                 and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
@@ -537,6 +558,8 @@
                 and output_dtype != DType.INT48
                 or input_dtype == DType.FP16
                 and output_dtype not in (DType.FP16, DType.FP32)
+                or input_dtype == DType.BF16
+                and output_dtype != DType.FP32
                 or input_dtype == DType.FP32
                 and output_dtype != DType.FP32
             ):
@@ -2316,12 +2339,14 @@
                 not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
                 and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
                 and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
+                and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
                 and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
             )
         elif mode == ResizeMode.NEAREST:
             # Invalid output data type / Invalid input datatype
             return (input_dtype != output_dtype) or (
-                input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+                input_dtype
+                not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
             )
         else:
             # Invalid resize mode
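Pulling the CAST conditions above together: for each input dtype, any output
dtype outside its valid set raises the wrong-output-type error, and BF16 now
behaves exactly like FP16. A summary table - partly inferred, since the heads
of some lists are elided in these hunks, so treat it as illustrative rather
than generated code:

    from tosa.DType import DType

    # Valid CAST output dtypes per input dtype (BF16 rows/entries are new).
    VALID_CAST_OUTPUTS = {
        DType.BOOL: [DType.INT8, DType.INT16, DType.INT32],
        DType.INT8: [DType.BOOL, DType.INT16, DType.INT32,
                     DType.FP16, DType.BF16, DType.FP32],
        DType.INT16: [DType.BOOL, DType.INT8, DType.INT32,
                      DType.FP16, DType.BF16, DType.FP32],
        DType.INT32: [DType.BOOL, DType.INT8, DType.INT16,
                      DType.FP16, DType.BF16, DType.FP32],
        DType.FP16: [DType.INT8, DType.INT16, DType.INT32],
        DType.BF16: [DType.INT8, DType.INT16, DType.INT32],
        DType.FP32: [DType.INT8, DType.INT16, DType.INT32],
    }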
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 78d86cd..95e06ed 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -16,6 +16,7 @@ from generator.tosa_error_if import TosaInvalidValidator
 from generator.tosa_utils import DTYPE_ATTRIBUTES
 from generator.tosa_utils import MAX_RESIZE_DIMENSION
 from generator.tosa_utils import usableDTypes
+from generator.tosa_utils import vect_f32_to_bf16
 from tosa.DType import DType
 from tosa.Op import Op
 
@@ -84,6 +85,10 @@ class TosaTestGen:
             )
         elif dtype == DType.FP16:
             return np.float16(self.rng.random(size=shape))
+        elif dtype == DType.BF16:
+            f32_tensor = np.float32(self.rng.random(size=shape))
+            # Floor the last 16 bits of each f32 value
+            return np.float32(vect_f32_to_bf16(f32_tensor))
         elif dtype == DType.FP32:
             return np.float32(self.rng.random(size=shape))
         else:
@@ -134,6 +139,9 @@
         elif dtype == DType.FP16:
             rand_f32 = self.rng.random()
             return np.float16(rand_f32)
+        elif dtype == DType.BF16:
+            rand_f32 = self.rng.random()
+            return vect_f32_to_bf16(rand_f32)
         elif dtype == DType.BOOL:
             return self.rng.choice([False, True])
         # TOSA specific INT4 weight range from -7 to 7
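The TosaTestGen hunks above build BF16 test data by drawing FP32 values and
flooring each one with vect_f32_to_bf16 (defined in tosa_utils.py, shown
further down). A standalone check of the property the checker relies on -
generated data is already bf16-valid, and flooring is idempotent (assumes
verif/ is on PYTHONPATH):

    import numpy as np

    from generator.tosa_utils import float32_is_valid_bfloat16, vect_f32_to_bf16

    rng = np.random.default_rng(42)
    f32_tensor = np.float32(rng.random(size=(3, 4)))
    bf16_tensor = np.float32(vect_f32_to_bf16(f32_tensor))

    assert all(float32_is_valid_bfloat16(f) for f in bf16_tensor.flat)
    # Flooring twice changes nothing.
    assert (np.float32(vect_f32_to_bf16(bf16_tensor)) == bf16_tensor).all()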
@@ -324,7 +332,7 @@
 
         # Special for multiply:
         # Force the result to INT32 for INT types
-        if a.dtype not in (DType.FP16, DType.FP32):
+        if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
             result_tens.setDtype(DType.INT32)
         if error_name == ErrorIf.WrongOutputType:
             all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -1043,7 +1051,7 @@
             return None
 
         attr = ts.TosaSerializerAttribute()
-        if a.dtype in (DType.FP16, DType.FP32):
+        if a.dtype in (DType.FP16, DType.BF16, DType.FP32):
             attr.ClampAttribute(0, 0, min_val, max_val)
         else:
             attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1859,7 +1867,7 @@
             op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
         )
 
-        if a.dtype in (DType.FP32, DType.FP16, DType.INT32):
+        if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
             then_op, else_op = Op.ADD, Op.SUB
         elif a.dtype in (DType.INT8, DType.INT16):
             then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2398,7 +2406,7 @@
     #   if not specified, defaults to (1, 4)
     # 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
     # 'types': array of datatypes to be tested
-    TYPE_FP = [DType.FP32, DType.FP16]
+    TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
 
     TYPE_INT = [DType.INT8, DType.INT16, DType.INT32]  # Excludes INT4
     TYPE_INT_FP = [
@@ -2406,13 +2414,20 @@
         DType.INT16,
         DType.INT32,
         DType.FP16,
+        DType.BF16,
         DType.FP32,
     ]  # Excludes INT4
 
     TYPE_BOOL = [DType.BOOL]
-    TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32]  # floating-types and INT32
+    TYPE_FI32 = [
+        DType.FP32,
+        DType.FP16,
+        DType.BF16,
+        DType.INT32,
+    ]  # floating-types and INT32
     TYPE_FIB = [
         DType.FP16,
+        DType.BF16,
         DType.FP32,
         DType.INT8,
         DType.INT16,
@@ -2421,7 +2436,7 @@
     ]
     TYPE_FI16 = [DType.FP32, DType.INT16]
 
-    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+    TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
 
     # List of [Input Type 1, Input Type 2, Accumulator Type]
     TYPE_CONV = [
@@ -2430,6 +2445,7 @@
         [DType.INT16, DType.INT8, DType.INT48],
         [DType.FP16, DType.FP16, DType.FP16],
         [DType.FP16, DType.FP16, DType.FP32],
+        [DType.BF16, DType.BF16, DType.FP32],
         [DType.FP32, DType.FP32, DType.FP32],
     ]
 
@@ -3448,7 +3464,7 @@
             TosaTensorValuesGen.tvgReduceSum,
             TosaArgGen.agAxis,
         ),
-        "types": (DType.FP16, DType.FP32, DType.INT32),
+        "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
         "error_if_validators": (
             TosaErrorValidator.evAxisLargerRank,
             TosaErrorValidator.evAxisSmallerZero,
@@ -3635,7 +3651,14 @@
             TosaTensorValuesGen.tvgDefault,
             None,
         ),
-        "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32),
+        "types": (
+            DType.INT8,
+            DType.INT16,
+            DType.INT32,
+            DType.FP16,
+            DType.BF16,
+            DType.FP32,
+        ),
         "error_if_validators": (
             TosaErrorValidator.evWrongInputType,
             TosaErrorValidator.evWrongOutputType,
@@ -3676,7 +3699,7 @@
             TosaTensorValuesGen.tvgDefault,
             TosaArgGen.agResize,
         ),
-        "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32),
+        "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
         "invalid_test_validators": (
             TosaInvalidValidator.ivWrongDataTypeOrModeResize,
         ),
@@ -3712,6 +3735,7 @@
         ),
         "types": (
             DType.FP16,
+            DType.BF16,
             DType.FP32,
             DType.INT8,
             DType.INT16,
@@ -3842,6 +3866,8 @@ DType.INT16,
             DType.INT32,
             DType.INT48,
+            DType.FP16,
+            DType.BF16,
             DType.FP32,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -3872,6 +3898,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -3900,6 +3928,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -3929,6 +3959,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         outputDType = rng.choice(wrong_dtypes)
     else:
@@ -3955,6 +3987,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -3987,6 +4021,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4189,6 +4225,7 @@
             DType.INT48,
             DType.FP32,
             DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4226,6 +4263,8 @@
                 DType.INT16,
                 DType.INT48,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             )
         elif a.dtype == DType.INT16:
             incorrect_types = (
@@ -4234,8 +4273,12 @@
                 DType.INT16,
                 DType.INT32,
                 DType.FP32,
+                DType.FP16,
+                DType.BF16,
             )
-        elif a.dtype == DType.FP32 or a.dtype == DType.FP16:
+        elif (
+            a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
+        ):
             incorrect_types = (
                 DType.INT4,
                 DType.INT8,
@@ -4278,6 +4321,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         }
         wrong_dtypes = list(all_dtypes - set([input1.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4306,6 +4351,7 @@
             DType.INT48,
             DType.FP32,
             DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4329,6 +4375,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4347,6 +4395,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4383,6 +4433,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4411,6 +4463,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4435,6 +4489,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4462,6 +4518,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
         outputDType = rng.choice(wrong_dtypes)
@@ -4483,6 +4541,8 @@
             DType.INT32,
             DType.INT48,
             DType.FP32,
+            DType.FP16,
+            DType.BF16,
         ]
         wrong_dtypes.remove(output_dtype)
         output_dtype = rng.choice(wrong_dtypes)
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 104d9bb..d79ab3c 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -1,5 +1,9 @@
 # Copyright (c) 2021-2022, ARM Limited.
 # SPDX-License-Identifier: Apache-2.0
+import struct
+import sys
+
+import numpy as np
 from tosa.DType import DType
 
 # Maximum dimension size for output and inputs for RESIZE
@@ -15,6 +19,7 @@ DTYPE_ATTRIBUTES = {
     DType.INT32: {"str": "i32", "width": 32},
     DType.INT48: {"str": "i48", "width": 48},
     DType.FP16: {"str": "f16", "width": 16},
+    DType.BF16: {"str": "bf16", "width": 16},
     DType.FP32: {"str": "f32", "width": 32},
 }
 
@@ -125,7 +130,11 @@
             DType.FP32,
             DType.FP16,
         )
-    elif input_dtype == DType.FP32 or input_dtype == DType.FP16:
+    elif (
+        input_dtype == DType.FP32
+        or input_dtype == DType.FP16
+        or input_dtype == DType.BF16
+    ):
         incorrect_types = (
             DType.INT4,
             DType.INT8,
@@ -134,3 +143,37 @@
             DType.INT48,
         )
     return rng.choice(a=incorrect_types)
+
+
+def float32_is_valid_bfloat16(f):
+    """Return True if float value is valid bfloat16."""
+    f32_bits = get_float32_bitstring(f)
+    return f32_bits[16:] == "0" * 16
+
+
+def get_float32_bitstring(f):
+    """Return a big-endian string of bits representing a 32 bit float."""
+    f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
+    return f"{f32_bits_as_int:032b}"
+
+
+def float32_to_bfloat16(f):
+    """Turns fp32 value into bfloat16 by flooring.
+
+    Floors the least significant 16 bits of the input
+    fp32 value and returns this valid bfloat16 representation as fp32.
+    For simplicity during bit-wrangling, ignores underlying system
+    endianness and interprets as big-endian.
+    Returns a bf16-valid float following system's native byte order.
+    """
+    f32_bits = get_float32_bitstring(f)
+    f32_floored_bits = f32_bits[:16] + "0" * 16
+
+    # Assume sys.byteorder matches system's underlying float byteorder
+    fp_bytes = int(f32_floored_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+    return struct.unpack("@f", fp_bytes)[0]  # native byteorder
+
+
+vect_f32_to_bf16 = np.vectorize(
+    float32_to_bfloat16, otypes=(np.float32,)
+)  # NumPy vectorize: applies function to vector faster than looping
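A short worked example of the three helpers just added (the values shown are
exact):

    from generator.tosa_utils import (
        float32_is_valid_bfloat16,
        float32_to_bfloat16,
        get_float32_bitstring,
    )

    x = 1.1  # float32 bit pattern 0x3F8CCCCD: not bf16-valid
    assert not float32_is_valid_bfloat16(x)

    y = float32_to_bfloat16(x)  # floors the low 16 bits -> 0x3F8C0000
    assert get_float32_bitstring(y) == "0011111110001100" + "0" * 16
    assert float32_is_valid_bfloat16(y)
    assert y == 1.09375  # the floored value sits just below 1.1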
("negate", "int8", 1), ("negate", "int16", 1), ("negate", "int32", 1), ("negate", "fp32", 1), ("negate", "fp16", 1), + ("negate", "bf16", 1), # One test per axis (shape dimensions) ("concat", "bool", SHAPE_DIMS), ("concat", "int8", SHAPE_DIMS), @@ -139,6 +142,7 @@ TEST_PARAMS = [ ("concat", "int32", SHAPE_DIMS), ("concat", "fp32", SHAPE_DIMS), ("concat", "fp16", SHAPE_DIMS), + ("concat", "bf16", SHAPE_DIMS), ] @@ -165,6 +169,9 @@ def test_refmodel_simple_op(tosaTest): # Generate TOSA test(s) (mostly should be single test) test_dirs = tosaTest.create_test() + # Indicate miscellaneous checks to run in tosa_check + misc_checks = [] + for test_dir in test_dirs: # Run ref model desc_file = test_dir / TEST_DESC_FILENAME @@ -227,8 +234,15 @@ def test_refmodel_simple_op(tosaTest): np.save(str(result_file), result) assert result_file.is_file() + # Ensure valid bf16 + if tosaTest.ref_model_type == "bf16": + misc_checks.append("bf16") + # Check Numpy result versus refmodel check_result, tolerance, msg = tosa_check( - str(result_file), str(ofm_file), test_name=test_dir.name + str(result_file), + str(ofm_file), + test_name=test_dir.name, + misc_checks=misc_checks, ) assert check_result == TosaResult.PASS -- cgit v1.2.1