author    James Ward <james.ward@arm.com>            2022-10-19 12:20:31 +0100
committer Jeremy Johnson <jeremy.johnson@arm.com>    2022-11-09 12:19:51 +0000
commit    24dbc420aae556649f50e645bd94489dab2cc75a (patch)
tree      490345da43e9c5bae0f450ba05ffe85874077e0a
parent    3b0544c1e7463295c49a48a162ebb9a546326829 (diff)
download  reference_model-24dbc420aae556649f50e645bd94489dab2cc75a.tar.gz
Add BF16 support to reference model
* Upgrade Eigen to 3.4.0 (for bfloat16 support) and add workarounds for
  reduce.any() and reduce.all() bugs (introduced between 3.3.7 and 3.4.0)
* Truncation to bfloat16 now performed in eval() methods

Signed-off-by: James Ward <james.ward@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: If5f5c988d76d3d30790acf3b97081726b89205fe
-rw-r--r--  reference_model/include/func_config.h       |  1
-rw-r--r--  reference_model/src/arith_util.h            | 89
-rw-r--r--  reference_model/src/main.cpp                |  3
-rw-r--r--  reference_model/src/ops/activation_funcs.cc | 13
-rw-r--r--  reference_model/src/ops/comparison.cc       |  6
-rw-r--r--  reference_model/src/ops/data_layout.cc      |  7
-rw-r--r--  reference_model/src/ops/data_nodes.cc       |  1
-rw-r--r--  reference_model/src/ops/ewise_binary.cc     | 21
-rw-r--r--  reference_model/src/ops/ewise_ternary.cc    |  1
-rw-r--r--  reference_model/src/ops/ewise_unary.cc      | 36
-rw-r--r--  reference_model/src/ops/image.cc            | 29
-rw-r--r--  reference_model/src/ops/op_factory.cc       | 48
-rw-r--r--  reference_model/src/ops/op_factory.h        |  6
-rw-r--r--  reference_model/src/ops/reduction.cc        | 50
-rw-r--r--  reference_model/src/ops/scatter_gather.cc   |  2
-rw-r--r--  reference_model/src/ops/template_types.h    | 14
-rw-r--r--  reference_model/src/ops/tensor_ops.cc       | 18
-rw-r--r--  reference_model/src/ops/type_conversion.cc  |  3
-rw-r--r--  reference_model/src/subgraph_traverser.cc   | 11
-rw-r--r--  reference_model/src/tensor.cc               | 36
-rw-r--r--  reference_model/src/tensor.h                |  1
m---------  thirdparty/eigen                            |  0
m---------  thirdparty/serialization_lib                |  0
-rw-r--r--  verif/checker/tosa_result_checker.py        | 22
-rw-r--r--  verif/generator/tosa_arg_gen.py             | 10
-rw-r--r--  verif/generator/tosa_error_if.py            | 35
-rw-r--r--  verif/generator/tosa_test_gen.py            | 80
-rw-r--r--  verif/generator/tosa_verif_build_tests.py   |  4
-rw-r--r--  verif/tests/test_tosa_refmodel.py           | 16
30 files changed, 542 insertions(+), 66 deletions(-)
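
For orientation: the "truncation to bfloat16" performed in the eval() methods below amounts to clearing the low 16 bits of the IEEE-754 single-precision encoding, keeping the sign, the 8 exponent bits, and the top 7 mantissa bits. A minimal, endian-independent sketch of the idea (illustrative only; truncate_to_bf16 is a local stand-in, not code from this patch, which zeroes the two least-significant bytes under a runtime endianness check instead):

```cpp
// Sketch: bfloat16 truncation as a bit mask on the fp32 encoding.
#include <cstdint>
#include <cstdio>
#include <cstring>

float truncate_to_bf16(float f)
{
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits)); // type-pun safely via memcpy
    bits &= 0xFFFF0000u;             // keep sign + exponent + top 7 mantissa bits
    memcpy(&f, &bits, sizeof(bits));
    return f;
}

int main()
{
    // 1.0f is exactly representable in bfloat16; pi is not.
    printf("%.8f -> %.8f\n", 1.0f, truncate_to_bf16(1.0f));
    printf("%.8f -> %.8f\n", 3.14159265f, truncate_to_bf16(3.14159265f));
    return 0;
}
```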
diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h
index 41df135..d9b51d5 100644
--- a/reference_model/include/func_config.h
+++ b/reference_model/include/func_config.h
@@ -36,6 +36,7 @@ struct func_config_t
uint32_t tosa_profile = 1;
uint32_t dump_intermediates = 0;
std::string fp_format = "0.5";
+ bool float_is_big_endian = false; // Set at startup in main.cpp via float_is_big_endian() (arith_util.h)
};
#endif
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index 554a7a2..33bdeed 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -31,13 +31,18 @@
#include <math.h>
#define __STDC_LIMIT_MACROS //enable min/max of plain data type
#include "func_debug.h"
+#include "func_config.h"
#include "inttypes.h"
+#include "tosa_generated.h"
#include <cassert>
#include <iostream>
#include <limits>
#include <stdint.h>
#include <typeinfo>
+#include <Eigen/Core>
+#include <bitset>
+using namespace tosa;
using namespace std;
inline size_t _count_one(uint64_t val)
@@ -191,4 +196,88 @@ constexpr T saturate(const uint32_t width, const intmax_t value)
// clang-format on
}
+inline void float_trunc_bytes(float* src)
+{
+ /* Set the least significant two bytes to zero for the input float value.*/
+ char src_as_bytes[sizeof(float)];
+ memcpy(src_as_bytes, src, sizeof(float));
+
+ if (g_func_config.float_is_big_endian)
+ {
+ src_as_bytes[2] = '\000';
+ src_as_bytes[3] = '\000';
+ }
+ else
+ {
+ src_as_bytes[0] = '\000';
+ src_as_bytes[1] = '\000';
+ }
+
+ memcpy(src, src_as_bytes, sizeof(float));
+}
+
+inline void truncateFloatToBFloat(float* src, int64_t size) {
+ /* Set the least significant two bytes to zero for each float
+ value in the input src buffer. */
+ ASSERT_MEM(src);
+ ASSERT_MSG(size > 0, "Size of src (representing number of values in src) must be a positive integer.");
+ for (; size != 0; src++, size--)
+ {
+ float_trunc_bytes(src);
+ }
+}
+
+inline bool checkValidBFloat(float src)
+{
+ /* Checks if the least significant two bytes are zero. */
+ char src_as_bytes[sizeof(float)];
+ memcpy(src_as_bytes, &src, sizeof(float));
+
+ if (g_func_config.float_is_big_endian)
+ {
+ return (src_as_bytes[2] == '\000' && src_as_bytes[3] == '\000');
+ }
+ else
+ {
+ return (src_as_bytes[0] == '\000' && src_as_bytes[1] == '\000');
+ }
+}
+
+inline bool float_is_big_endian()
+{
+ /* Compares float values 1.0 and -1.0 by checking whether the
+ negation causes the first or the last byte to change.
+ First byte changing would indicate the float representation
+ is big-endian.*/
+ float f = 1.0;
+ char f_as_bytes[sizeof(float)];
+ memcpy(f_as_bytes, &f, sizeof(float));
+ f = -f;
+ char f_neg_as_bytes[sizeof(float)];
+ memcpy(f_neg_as_bytes, &f, sizeof(float));
+ return f_as_bytes[0] != f_neg_as_bytes[0];
+}
+
+template <DType Dtype>
+float fpTrunc(float f_in)
+{
+ /* Truncates a float value based on the DType it represents.*/
+ switch (Dtype)
+ {
+ case DType_BF16:
+ truncateFloatToBFloat(&f_in, 1);
+ break;
+ case DType_FP16:
+ // TODO(jw): implement FP16 truncate function (no-op placeholder for now)
+ break;
+ case DType_FP32:
+ // No-op for fp32
+ break;
+ default:
+ ASSERT_MSG(false, "DType %s should not be float-truncated.", EnumNameDType(Dtype));
+ }
+ return f_in;
+}
+
#endif /* _ARITH_UTIL_H */
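
A caveat worth noting about the helpers above: zeroing the low bytes truncates the mantissa, whereas converting through Eigen::bfloat16 rounds to nearest-even, so the two can differ by one bfloat16 ULP. A standalone sketch of the difference (assumes Eigen >= 3.4.0 on the include path; not code from this patch):

```cpp
// Sketch: fpTrunc-style truncation vs. Eigen's round-to-nearest-even
// bfloat16 conversion. They disagree when the discarded 16 bits >= 0x8000.
#include <Eigen/Core>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
{
    uint32_t bits = 0x3F80FFFFu; // 1.0f with all 16 low mantissa bits set
    float f;
    memcpy(&f, &bits, sizeof(f));

    uint32_t tbits = bits & 0xFFFF0000u; // truncate: collapses to exactly 1.0f
    float truncated;
    memcpy(&truncated, &tbits, sizeof(truncated));

    // Round-to-nearest-even, as a cast through Eigen::bfloat16 performs.
    float rounded = static_cast<float>(static_cast<Eigen::bfloat16>(f));

    printf("input=%.9f truncated=%.9f rounded=%.9f\n", f, truncated, rounded);
    // Expected: truncated = 1.0, rounded = 1.0078125 (one bf16 ULP higher).
    return 0;
}
```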
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 776fbf3..5c2735d 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -20,6 +20,7 @@
#include "ops/op_factory.h"
#include "subgraph_traverser.h"
#include "tosa_serialization_handler.h"
+#include "arith_util.h"
#include <fstream>
#include <iostream>
@@ -67,6 +68,8 @@ int main(int argc, char** argv)
return TOSA_VERSION_MISMATCH;
}
+ g_func_config.float_is_big_endian = float_is_big_endian();
+
json test_desc;
// Initialize test descriptor
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 61f7df6..46234e2 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -16,6 +16,7 @@
#include "activation_funcs.h"
#include "quant_util.h"
#include "template_types.h"
+#include "arith_util.h"
#include <cmath>
using namespace TosaReference;
@@ -28,13 +29,14 @@ int OpClamp<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
{
InEigenType min = (InEigenType)attribute->min_fp();
InEigenType max = (InEigenType)attribute->max_fp();
ERROR_IF(max < min, "OpClamp: max smaller than min");
- this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a); };
}
break;
case DType_INT8:
@@ -59,8 +61,9 @@ int OpSigmoid<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); };
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / (1.0 + (expf(-1.0 * a)))); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -75,8 +78,9 @@ int OpTanh<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); };
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(tanhf(a)); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -87,12 +91,15 @@ int OpTanh<Rank, Dtype>::register_fcn()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
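
The pattern above — compute in fp32, then narrow the result through fpTrunc<Dtype> inside the registered lambda — recurs throughout ewise_binary.cc and ewise_unary.cc below. A self-contained sketch of the wrapping (truncate_to_bf16 is a local stand-in for fpTrunc<DType_BF16>, not the patch's function):

```cpp
// Sketch of the register_fcn() pattern: each FP16/BF16 kernel runs in fp32
// and the lambda narrows its result before it is stored.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>

static float truncate_to_bf16(float f) // stand-in for fpTrunc<DType_BF16>
{
    uint32_t b;
    memcpy(&b, &f, sizeof(b));
    b &= 0xFFFF0000u; // drop the low 16 bits of the fp32 encoding
    memcpy(&f, &b, sizeof(b));
    return f;
}

int main()
{
    std::function<float(float)> tanh_bf16 =
        [](float a) { return truncate_to_bf16(tanhf(a)); };
    printf("tanh(0.5) narrowed to bf16: %.9f\n", tanh_bf16(0.5f));
    return 0;
}
```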
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index f240aa5..5b78a4f 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -28,6 +28,7 @@ int OpEqual<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
@@ -45,6 +46,7 @@ int OpGreater<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
@@ -62,6 +64,7 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
@@ -75,13 +78,16 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 69b6a65..bffd659 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -639,6 +639,7 @@ int OpTranspose<Rank, Dtype>::eval()
// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
@@ -646,6 +647,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
@@ -653,6 +655,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
+DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP32);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
@@ -660,6 +663,7 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
@@ -667,6 +671,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
@@ -674,6 +679,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
@@ -681,6 +687,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index 5709a92..f5304a5 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -90,6 +90,7 @@ int OpIdentity<Rank, Dtype>::eval()
// note OpConst is not templated
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 098b0ea..e4c0ee0 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -143,8 +143,9 @@ int OpAdd<Rank, Dtype>::register_fcn()
};
break;
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
@@ -371,6 +372,7 @@ int OpMaximum<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
@@ -388,6 +390,7 @@ int OpMinimum<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
@@ -407,8 +410,9 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
switch (InDtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
break;
case DType_INT32:
this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
@@ -457,8 +461,9 @@ int OpPow<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -482,8 +487,9 @@ int OpSub<Rank, Dtype>::register_fcn()
};
break;
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
@@ -581,6 +587,7 @@ int OpTable<Rank, InDtype>::eval()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
@@ -617,23 +624,28 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
@@ -643,5 +655,6 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
// Instantiation of nodes for comparison operators opEqual, opGreater
// and opGreaterEqual
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index d85da1a..677a4e2 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -108,6 +108,7 @@ int OpSelect<0, Dtype>::eval()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 00897cc..5347b8c 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -78,11 +78,14 @@ int OpAbs<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP32:
- case DType_FP16:
+ case DType_FP32: // No fpTrunc for FP32 as it is a no-op
case DType_INT32:
this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
break;
+ case DType_FP16:
+ case DType_BF16:
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a > (InEigenType)0 ? a : (-a)); };
+ break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
}
@@ -113,8 +116,9 @@ int OpCeil<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(ceilf(a)); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -161,8 +165,9 @@ int OpExp<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(expf(a)); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -177,8 +182,9 @@ int OpFloor<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(floorf(a)); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -193,8 +199,9 @@ int OpLog<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(logf(a)); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -245,10 +252,11 @@ int OpNegate<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType {
InEigenType result = -(a);
- return result;
+ return fpTrunc<Dtype>(result);
};
break;
case DType_INT16:
@@ -297,8 +305,9 @@ int OpReciprocal<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / a); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -313,8 +322,9 @@ int OpRsqrt<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
+ case DType_BF16:
case DType_FP32:
- this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
+ this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / sqrtf(a)); };
break;
default:
ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
@@ -325,6 +335,7 @@ int OpRsqrt<Rank, Dtype>::register_fcn()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
@@ -333,29 +344,36 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index cf1d9f7..66efee0 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -63,7 +63,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -71,7 +71,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -159,15 +159,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
resize_t dy;
resize_t dx;
- if (std::is_floating_point<resize_t>::value)
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
{
- dy = fy - iy;
- dx = fx - ix;
+ dy = (resize_t)(fy - iy);
+ dx = (resize_t)(fx - ix);
}
else
{
- dy = y - (iy * scale_y_n);
- dx = x - (ix * scale_x_n);
+ dy = (resize_t)(y - (iy * scale_y_n));
+ dx = (resize_t)(x - (ix * scale_x_n));
}
int32_t iy0 = MAX(iy, 0);
@@ -190,6 +190,15 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
acc += (OutEigenType)v10 * dy * (1.0 - dx);
acc += (OutEigenType)v11 * dy * dx;
}
+ else if ((typeid(resize_t) == typeid(Eigen::bfloat16)))
+ {
+ Eigen::bfloat16 bf16_acc;
+ bf16_acc = (Eigen::bfloat16)v00 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)(1.0 - dx);
+ bf16_acc += (Eigen::bfloat16)v01 * (Eigen::bfloat16)(1.0 - dy) * (Eigen::bfloat16)dx;
+ bf16_acc += (Eigen::bfloat16)v10 * (Eigen::bfloat16)dy * (Eigen::bfloat16)(1.0 - dx);
+ bf16_acc += (Eigen::bfloat16)v11 * (Eigen::bfloat16)dy * (Eigen::bfloat16)dx;
+ acc = (float)bf16_acc;
+ }
else
{
acc = (OutEigenType)v00 * (scale_y_n - dy) * (scale_x_n - dx);
@@ -201,7 +210,7 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
else
{
ASSERT_MSG(mode == ResizeMode_NEAREST, "OpResize: invalid mode");
- if (std::is_floating_point<resize_t>::value)
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)))
{
iy = (dy >= 0.5) ? iy1 : iy0;
ix = (dx >= 0.5) ? ix1 : ix0;
@@ -213,6 +222,9 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
}
acc = in->getTensor()(b, iy, ix, c);
}
+ if ((typeid(resize_t) == typeid(Eigen::bfloat16))) {
+ ASSERT_MSG(checkValidBFloat(acc), "Resize accumulator float value is not a valid bfloat16 value.");
+ }
out->getTensor()(b, oy, ox, c) = acc;
}
@@ -225,4 +237,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float);
+DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float);
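
Unlike the elementwise ops, the BILINEAR path above accumulates directly in Eigen::bfloat16, so the interpolation weights and partial products lose precision exactly where a native bf16 device would, before widening back to float. A standalone sketch of that accumulation (illustrative values, not code from this patch):

```cpp
// Sketch: bilinear interpolation accumulated in Eigen::bfloat16, then widened.
#include <Eigen/Core> // Eigen >= 3.4.0 provides Eigen::bfloat16
#include <cstdio>

int main()
{
    using bf16 = Eigen::bfloat16;
    const bf16 v00(1.0f), v01(2.0f), v10(3.0f), v11(4.0f); // 2x2 neighborhood
    const bf16 dy(0.25f), dx(0.5f), one(1.0f);             // fractional offsets

    bf16 acc = v00 * (one - dy) * (one - dx);
    acc += v01 * (one - dy) * dx;
    acc += v10 * dy * (one - dx);
    acc += v11 * dy * dx;

    printf("bilinear sample = %f\n", static_cast<float>(acc)); // expect 2.0
    return 0;
}
```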
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 1ff8229..0121ccf 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -49,6 +49,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// tensor_ops
case Op_ARGMAX:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
@@ -56,6 +57,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, BF16, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
@@ -63,6 +65,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_CONV2D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32);
@@ -71,6 +74,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_CONV3D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32);
@@ -79,6 +83,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32);
@@ -87,6 +92,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_FULLY_CONNECTED:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32);
@@ -95,12 +101,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_MATMUL:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, BF16);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
@@ -108,6 +116,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_TRANSPOSE_CONV2D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32);
@@ -117,22 +126,26 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// activation_funcs
case Op_CLAMP:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
break;
case Op_SIGMOID:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
break;
case Op_TANH:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
break;
// ewise_binary
case Op_ADD:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
break;
@@ -180,16 +193,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_MAXIMUM:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
break;
case Op_MINIMUM:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
break;
case Op_MUL:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
@@ -197,10 +213,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_POW:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
break;
case Op_SUB:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
break;
@@ -212,6 +230,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// ewise_unary
case Op_ABS:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
break;
@@ -222,6 +241,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_CEIL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
break;
case Op_CLZ:
@@ -229,14 +249,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_EXP:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
break;
case Op_FLOOR:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
break;
case Op_LOG:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
break;
case Op_LOGICAL_NOT:
@@ -244,6 +267,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_NEGATE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
@@ -251,16 +275,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_RECIPROCAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
break;
case Op_RSQRT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
break;
// ewise_ternary
case Op_SELECT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
@@ -271,16 +298,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// comparison
case Op_EQUAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
break;
case Op_GREATER:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
break;
case Op_GREATER_EQUAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
break;
@@ -294,6 +324,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REDUCE_MAX:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
@@ -301,6 +332,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REDUCE_MIN:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
@@ -308,10 +340,12 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REDUCE_PRODUCT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
break;
case Op_REDUCE_SUM:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
break;
@@ -319,6 +353,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// data layout
case Op_CONCAT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
@@ -327,6 +362,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
@@ -335,6 +371,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FP16);
+ DEF_FACTORY_RESHAPE(OpReshape, BF16);
DEF_FACTORY_RESHAPE(OpReshape, FP32);
DEF_FACTORY_RESHAPE(OpReshape, INT8);
DEF_FACTORY_RESHAPE(OpReshape, INT16);
@@ -343,6 +380,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
@@ -351,6 +389,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_SLICE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
@@ -359,6 +398,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_TILE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
@@ -368,6 +408,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_TRANSPOSE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
@@ -380,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, INT16);
DEF_FACTORY_ONE_TYPE(OpGather, INT32);
DEF_FACTORY_ONE_TYPE(OpGather, FP16);
+ DEF_FACTORY_ONE_TYPE(OpGather, BF16);
DEF_FACTORY_ONE_TYPE(OpGather, FP32);
break;
case Op_SCATTER:
@@ -387,6 +429,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
DEF_FACTORY_ONE_TYPE(OpScatter, FP16);
+ DEF_FACTORY_ONE_TYPE(OpScatter, BF16);
DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
break;
@@ -397,6 +440,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48);
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16);
DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16);
DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32);
break;
@@ -405,6 +449,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
return new OpConst(sgt, id);
case Op_IDENTITY:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
@@ -435,6 +480,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index b525e69..f399bd1 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -111,6 +111,12 @@
return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
}
+#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ }
+
#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
{ \
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index eccba09..cd9d55f 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -80,10 +80,30 @@ int ReduceNode<Rank, Dtype>::checkTensorAttributes()
return 0;
}
+// These two reducers work around a bug introduced in Eigen between 3.3.7 and
+// 3.4.0: the built-in .any() and .all() operations now fail an assert in
+// TensorMorphing.h:150, which appears to be caused by incorrect data being
+// passed internally as m_impl
+struct AllReducer {
+ static const bool PacketAccess = false;
+ void reduce(const bool val, bool* accum) {
+ *accum = *accum && val;
+ }
+ bool initialize() const { return true; }
+ bool finalize(const bool accum) const { return accum; }
+};
+struct AnyReducer {
+ static const bool PacketAccess = false;
+ void reduce(const bool val, bool* accum) {
+ *accum = *accum || val;
+ }
+ bool initialize() const { return false; }
+ bool finalize(const bool accum) const { return accum; }
+};
+
template <int Rank, DType Dtype>
int OpReduceAll<Rank, Dtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions());
+ this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());
return GraphNode::eval();
}
@@ -91,7 +111,7 @@ int OpReduceAll<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
int OpReduceAny<Rank, Dtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions());
+ this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());
return GraphNode::eval();
}
@@ -115,7 +135,16 @@ int OpReduceMin<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
int OpReduceProduct<Rank, Dtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+ switch(Dtype)
+ {
+ case DType_FP16:
+ case DType_BF16:
+ this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
+ break;
+ default:
+ this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+ break;
+ }
return GraphNode::eval();
}
@@ -123,7 +152,16 @@ int OpReduceProduct<Rank, Dtype>::eval()
template <int Rank, DType Dtype>
int OpReduceSum<Rank, Dtype>::eval()
{
- this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+ switch(Dtype)
+ {
+ case DType_FP16:
+ case DType_BF16:
+ this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
+ break;
+ default:
+ this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+ break;
+ }
return GraphNode::eval();
}
@@ -159,20 +197,24 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
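
The custom reducers above lean on Eigen's generic Tensor::reduce() entry point, which accepts any user-supplied reducer exposing initialize()/reduce()/finalize(). A standalone sketch of the workaround (assumes the unsupported Eigen Tensor headers are available; not code from this patch):

```cpp
// Sketch: replacing Tensor::all() with a user-supplied reducer, as the patch
// does to dodge the Eigen 3.4.0 assert in TensorMorphing.h.
#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

struct AllReducer
{
    static const bool PacketAccess = false;
    void reduce(const bool val, bool* accum) { *accum = *accum && val; }
    bool initialize() const { return true; }
    bool finalize(const bool accum) const { return accum; }
};

int main()
{
    Eigen::Tensor<bool, 2> t(2, 3);
    t.setConstant(true);
    t(1, 2) = false;

    const Eigen::array<Eigen::Index, 1> dims{ 1 }; // reduce across columns
    Eigen::Tensor<bool, 1> row_all = t.reduce(dims, AllReducer());

    std::cout << std::boolalpha
              << "row 0: " << row_all(0)   // true
              << ", row 1: " << row_all(1) // false
              << std::endl;
    return 0;
}
```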
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index b6c4043..bcd8ce5 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -227,10 +227,12 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32);
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 3de4899..647ca84 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -19,6 +19,8 @@
#include "tosa_generated.h"
#include <Eigen/CXX11/Tensor>
#include "half.hpp"
+#include <Eigen/Core>
+#include "arith_util.h"
using namespace tosa;
@@ -76,6 +78,12 @@ struct GetEigenType<DType_FP16>
using type = float;
};
template <>
+struct GetEigenType<DType_BF16>
+{
+ // NOTE: full precision used
+ using type = float;
+};
+template <>
struct GetEigenType<DType_INT32>
{
using type = int32_t;
@@ -132,12 +140,6 @@ struct GetAccEigenType
using type = typename GetEigenType<Dtype>::type;
};
-template <DType Dtype>
-struct GetHalfEigenType
-{
- using type = half_float::half;
-};
-
// Meta function to get number of bits
template <DType T>
struct GetNumBits
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 7db5182..b9ac94a 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -507,12 +507,13 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };
+ ETensor2<int32_t> dm2_w = div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width });
+ ETensor2<int32_t> dm2_h = div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 });
ETensor4<int32_t> div_map =
- div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 })
- .contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
+ dm2_h.contract(dm2_w, contract_dims)
.reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
.broadcast(bcast);
- if (Dtype != DType_FP32 && Dtype != DType_FP16)
+ if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16)
{
try
{
@@ -533,7 +534,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
}
else
{
- // Case for float-type resizes
+ // Case for float-types
this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
}
@@ -1679,12 +1680,14 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);
@@ -1692,6 +1695,7 @@ DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);
// [in_t, weight_t, acc_t]
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32);
@@ -1699,6 +1703,7 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, BF16, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32);
@@ -1706,6 +1711,7 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32);
@@ -1713,6 +1719,7 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, BF16, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32);
@@ -1722,15 +1729,18 @@ DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32);
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index f51c38c..e30c7bd 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -353,6 +353,9 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index ae216d8..112e641 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -15,6 +15,7 @@
#include "subgraph_traverser.h"
#include "tosa_model_types.h"
+#include "arith_util.h"
#ifndef SUBGRAPH_ERROR_IF
#define SUBGRAPH_ERROR_IF(COND, fmt, ...) \
@@ -403,6 +404,16 @@ int SubgraphTraverser::allocateTensor()
tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
}
break;
+ case DType_BF16:
+ {
+ std::vector<float> fp32_data;
+ TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+ // Ensure valid bfloat16 stored in each float
+ for (auto f : fp32_data)
+ ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
+ break;
case DType_FP32:
{
std::vector<float> fp32_data;
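
The reference model keeps BF16 data in float (fp32) buffers; this is lossless because widening bfloat16 to fp32 only appends 16 zero mantissa bits, and truncation is its inverse. A minimal Python sketch of that round trip (illustrative only, not part of this patch):

import struct

def f32_bits(f):
    # IEEE-754 bit pattern of a float32, big-endian for readability
    return struct.unpack(">I", struct.pack(">f", f))[0]

def bf16_round_trip(f):
    # Narrow to bfloat16 by dropping the 16 low bits, then widen back to fp32
    return struct.unpack(">f", struct.pack(">I", f32_bits(f) & 0xFFFF0000))[0]

y = bf16_round_trip(3.14159)    # -> 3.140625, the truncated bf16 value
assert bf16_round_trip(y) == y  # bf16-valid floats are fixed points of the trip
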
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 8d192ca..4eaf21d 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -90,10 +90,12 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
int64_t* i64databuf = nullptr;
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
+ DType dtype = getDtype();
- switch (getDtype())
+ switch (dtype)
{
case DType_FP32:
+ case DType_BF16:
fdatabuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(fdatabuf);
@@ -154,19 +156,38 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
FATAL_ERROR("Unknown error parsing Numpy file: %s", filename);
}
- switch (getDtype())
+ switch (dtype)
{
case DType_FP16:
// Convert from fp16 to fp32
+ //TODO(jw): remove this once we cast to fp16 in register_fcn/eval
for (uint32_t i=0; i < elements; i++) {
fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
}
- // Fall through to DType_FP32 case
+ if (setTensorValueFloat(elements, fdatabuf))
+ {
+ free(f16databuf);
+ free(fdatabuf);
+ return 1;
+ }
+ break;
+ case DType_BF16:
+ for (uint32_t i=0; i < elements; i++)
+ {
+ ASSERT_MSG(
+ checkValidBFloat(fdatabuf[i]),
+                "Input float value is not a valid bfloat16 value."
+ );
+ }
+ if (setTensorValueFloat(elements, fdatabuf))
+ {
+ free(fdatabuf);
+ return 1;
+ }
+ break;
case DType_FP32:
if (setTensorValueFloat(elements, fdatabuf))
{
- if (f16databuf)
- free(f16databuf);
free(fdatabuf);
return 1;
}
@@ -226,10 +247,12 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
bool* bdatabuf = nullptr;
NumpyUtilities::NPError nperror;
uint32_t elements = getElementCount();
+ DType dtype = getDtype();
- switch (getDtype())
+ switch (dtype)
{
case DType_FP32:
+ case DType_BF16:
fdatabuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(fdatabuf);
@@ -238,7 +261,6 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(fdatabuf);
return 1;
}
-
nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf);
free(fdatabuf);
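
Because BF16 tensors share the FP32 read/write paths above, a .npy file backing a BF16 tensor is plain float32 data whose every element must already have its low 16 bits clear. A hedged numpy sketch of that invariant (file name hypothetical):

import numpy as np

data = np.array([1.0, 0.5, 3.140625], dtype=np.float32)  # all bf16-exact values
np.save("bf16_tensor.npy", data)                         # hypothetical path
loaded = np.load("bf16_tensor.npy")
assert np.all(loaded.view(np.uint32) & 0xFFFF == 0)      # valid bfloat16 payload
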
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 4efbf84..a3ce4bb 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -646,6 +646,7 @@ public:
{
case DType_FP32:
case DType_FP16:
+ case DType_BF16:
switch (rank)
{
case 0:
diff --git a/thirdparty/eigen b/thirdparty/eigen
-Subproject 21ae2afd4edaa1b69782c67a54182d34efe43f9
+Subproject 3147391d946bb4b6c68edd901f2add6ac1f31f8
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject e1072a9ed871fd474e7b09b7a74ae7be5f0a6f7
+Subproject 34a627959a61b4eccbeea4400cf9684debb331d
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 8ae3218..b7a76b6 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -9,6 +9,7 @@ from enum import unique
from pathlib import Path
import numpy as np
+from generator.tosa_utils import float32_is_valid_bfloat16
##################################
color_printing = True
@@ -63,7 +64,12 @@ TestResultErrorStr = [
def test_check(
- reference, result, test_name="test", quantize_tolerance=0, float_tolerance=1e-3
+ reference,
+ result,
+ test_name="test",
+ quantize_tolerance=0,
+ float_tolerance=1e-3,
+    misc_checks=None,
):
"""Check if the result is the same as the expected reference."""
if not os.path.isfile(reference):
@@ -111,6 +117,20 @@ def test_check(
)
return (TestResult.MISMATCH, 0.0, msg)
+    # Perform miscellaneous checks
+    if misc_checks and "bf16" in misc_checks:
+ # Ensure floats are valid bfloat16 values
+        test_res_is_bf16 = all(float32_is_valid_bfloat16(f) for f in test_result.flat)
+        ref_res_is_bf16 = all(
+            float32_is_valid_bfloat16(f) for f in reference_result.flat
+        )
+ if not (test_res_is_bf16 and ref_res_is_bf16):
+ msg = (
+ "All output values must be valid bfloat16. "
+                f"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
+ )
+ return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+
# for quantized test, allow +-(quantize_tolerance) error
if reference_result.dtype == np.int32 or reference_result.dtype == np.int64:
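
Callers opt into the new check via misc_checks; a hedged sketch of an invocation (file names hypothetical, and assuming TestResult.PASS denotes success, as the test changes further below suggest):

from checker.tosa_result_checker import TestResult, test_check

check_result, tolerance, msg = test_check(
    "ref_out.npy",         # hypothetical reference .npy
    "model_out.npy",       # hypothetical result .npy
    test_name="abs_bf16",
    misc_checks=["bf16"],  # additionally require bf16-valid outputs
)
assert check_result == TestResult.PASS
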
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 0203513..932ad55 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -776,7 +776,7 @@ class TosaTensorValuesGen:
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
- if dtypeList[0] in (DType.FP16, DType.FP32):
+ if dtypeList[0] in (DType.FP16, DType.BF16, DType.FP32):
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
@@ -1130,6 +1130,8 @@ class TosaArgGen:
accum_dtypes = [DType.INT48]
elif dtype == DType.FP16:
accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.BF16:
+ accum_dtypes = [DType.FP32]
elif dtype == DType.FP32:
accum_dtypes = [DType.FP32]
elif error_name is None:
@@ -1304,7 +1306,7 @@ class TosaArgGen:
accum_dtypes = [DType.INT32]
elif dtype == DType.FP16:
accum_dtypes = [DType.FP16, DType.FP32]
- elif dtype == DType.FP32:
+ elif dtype == DType.BF16 or dtype == DType.FP32:
accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
@@ -1417,6 +1419,8 @@ class TosaArgGen:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
+ elif inDtype == DType.BF16:
+ dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP32:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
@@ -1826,6 +1830,8 @@ class TosaArgGen:
outputDTypeList = [DType.INT48]
elif dtype == DType.FP16:
outputDTypeList = [DType.FP16]
+ elif dtype == DType.BF16:
+ outputDTypeList = [DType.BF16]
elif dtype == DType.FP32:
outputDTypeList = [DType.FP32]
elif error_name == ErrorIf.WrongInputType:
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index abe1a97..a850699 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -158,6 +158,15 @@ class TosaErrorIfArgGen:
DType.INT48,
DType.FP32,
)
+ elif dtype == DType.BF16:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ )
elif dtype == DType.FP32:
incorrect_types = (
DType.INT4,
@@ -299,8 +308,8 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]:
- outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32]
+ if input_dtype in [DType.BOOL, DType.FP16, DType.BF16, DType.FP32]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.BF16, DType.FP32]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
@@ -425,6 +434,7 @@ class TosaErrorValidator:
and output_dtype != DType.INT48
)
or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
+ or (input_dtype == DType.BF16 and output_dtype != DType.BF16)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
@@ -442,25 +452,29 @@ class TosaErrorValidator:
input_dtype == DType.FP16
and output_dtype not in (DType.FP16, DType.FP32)
)
+ or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
- input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ input_dtype
+ in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
and output_dtype != DType.INT32
):
error_result = True
elif op["op"] == Op.MUL:
if (
- input_dtype not in (DType.FP16, DType.FP32)
+ input_dtype not in (DType.FP16, DType.BF16, DType.FP32)
and output_dtype != DType.INT32
):
error_result = True
elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
error_result = True
+ elif input_dtype == DType.BF16 and output_dtype != DType.BF16:
+ error_result = True
elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
error_result = True
@@ -489,6 +503,7 @@ class TosaErrorValidator:
DType.INT32,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
)
or (
@@ -500,6 +515,7 @@ class TosaErrorValidator:
DType.INT32,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
)
or (
@@ -511,6 +527,7 @@ class TosaErrorValidator:
DType.INT16,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
)
or (
@@ -518,6 +535,10 @@ class TosaErrorValidator:
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
or (
+ input_dtype == DType.BF16
+ and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
+ )
+ or (
input_dtype == DType.FP32
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
@@ -537,6 +558,8 @@ class TosaErrorValidator:
and output_dtype != DType.INT48
or input_dtype == DType.FP16
and output_dtype not in (DType.FP16, DType.FP32)
+ or input_dtype == DType.BF16
+ and output_dtype != DType.FP32
or input_dtype == DType.FP32
and output_dtype != DType.FP32
):
@@ -2316,12 +2339,14 @@ class TosaInvalidValidator:
not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
+ and not (input_dtype == DType.BF16 and output_dtype == DType.BF16)
and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
)
elif mode == ResizeMode.NEAREST:
# Invalid output data type / Invalid input datatype
return (input_dtype != output_dtype) or (
- input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ input_dtype
+ not in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
)
else:
# Invalid resize mode
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 78d86cd..95e06ed 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -16,6 +16,7 @@ from generator.tosa_error_if import TosaInvalidValidator
from generator.tosa_utils import DTYPE_ATTRIBUTES
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
+from generator.tosa_utils import vect_f32_to_bf16
from tosa.DType import DType
from tosa.Op import Op
@@ -84,6 +85,10 @@ class TosaTestGen:
)
elif dtype == DType.FP16:
return np.float16(self.rng.random(size=shape))
+ elif dtype == DType.BF16:
+ f32_tensor = np.float32(self.rng.random(size=shape))
+            # Zero the 16 least significant bits of each f32 value
+ return np.float32(vect_f32_to_bf16(f32_tensor))
elif dtype == DType.FP32:
return np.float32(self.rng.random(size=shape))
else:
@@ -134,6 +139,9 @@ class TosaTestGen:
elif dtype == DType.FP16:
rand_f32 = self.rng.random()
return np.float16(rand_f32)
+ elif dtype == DType.BF16:
+ rand_f32 = self.rng.random()
+ return vect_f32_to_bf16(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
# TOSA specific INT4 weight range from -7 to 7
@@ -324,7 +332,7 @@ class TosaTestGen:
# Special for multiply:
# Force the result to INT32 for INT types
- if a.dtype not in (DType.FP16, DType.FP32):
+ if a.dtype not in (DType.FP16, DType.BF16, DType.FP32):
result_tens.setDtype(DType.INT32)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -1043,7 +1051,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- if a.dtype in (DType.FP16, DType.FP32):
+ if a.dtype in (DType.FP16, DType.BF16, DType.FP32):
attr.ClampAttribute(0, 0, min_val, max_val)
else:
attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1859,7 +1867,7 @@ class TosaTestGen:
op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
)
- if a.dtype in (DType.FP32, DType.FP16, DType.INT32):
+ if a.dtype in (DType.FP32, DType.BF16, DType.FP16, DType.INT32):
then_op, else_op = Op.ADD, Op.SUB
elif a.dtype in (DType.INT8, DType.INT16):
then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2398,7 +2406,7 @@ class TosaTestGen:
# if not specified, defaults to (1, 4)
# 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
# 'types': array of datatypes to be tested
- TYPE_FP = [DType.FP32, DType.FP16]
+ TYPE_FP = [DType.FP32, DType.FP16, DType.BF16]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
TYPE_INT_FP = [
@@ -2406,13 +2414,20 @@ class TosaTestGen:
DType.INT16,
DType.INT32,
DType.FP16,
+ DType.BF16,
DType.FP32,
] # Excludes INT4
TYPE_BOOL = [DType.BOOL]
- TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32] # floating-types and INT32
+ TYPE_FI32 = [
+ DType.FP32,
+ DType.FP16,
+ DType.BF16,
+ DType.INT32,
+ ] # floating-types and INT32
TYPE_FIB = [
DType.FP16,
+ DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
@@ -2421,7 +2436,7 @@ class TosaTestGen:
]
TYPE_FI16 = [DType.FP32, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
@@ -2430,6 +2445,7 @@ class TosaTestGen:
[DType.INT16, DType.INT8, DType.INT48],
[DType.FP16, DType.FP16, DType.FP16],
[DType.FP16, DType.FP16, DType.FP32],
+ [DType.BF16, DType.BF16, DType.FP32],
[DType.FP32, DType.FP32, DType.FP32],
]
@@ -3448,7 +3464,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReduceSum,
TosaArgGen.agAxis,
),
- "types": (DType.FP16, DType.FP32, DType.INT32),
+ "types": (DType.FP16, DType.BF16, DType.FP32, DType.INT32),
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -3635,7 +3651,14 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
None,
),
- "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32),
+ "types": (
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -3676,7 +3699,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agResize,
),
- "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32),
+ "types": (DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32),
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
),
@@ -3712,6 +3735,7 @@ class TosaTestGen:
),
"types": (
DType.FP16,
+ DType.BF16,
DType.FP32,
DType.INT8,
DType.INT16,
@@ -3842,6 +3866,8 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
+ DType.FP16,
+ DType.BF16,
DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
@@ -3872,6 +3898,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3900,6 +3928,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3929,6 +3959,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
outputDType = rng.choice(wrong_dtypes)
else:
@@ -3955,6 +3987,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3987,6 +4021,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
@@ -4189,6 +4225,7 @@ class OutputShaper:
DType.INT48,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4226,6 +4263,8 @@ class OutputShaper:
DType.INT16,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
)
elif a.dtype == DType.INT16:
incorrect_types = (
@@ -4234,8 +4273,12 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
)
- elif a.dtype == DType.FP32 or a.dtype == DType.FP16:
+ elif (
+ a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
+ ):
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -4278,6 +4321,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
}
wrong_dtypes = list(all_dtypes - set([input1.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4306,6 +4351,7 @@ class OutputShaper:
DType.INT48,
DType.FP32,
DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4329,6 +4375,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4347,6 +4395,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4383,6 +4433,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4411,6 +4463,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4435,6 +4489,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4462,6 +4518,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4483,6 +4541,8 @@ class OutputShaper:
DType.INT32,
DType.INT48,
DType.FP32,
+ DType.FP16,
+ DType.BF16,
]
wrong_dtypes.remove(output_dtype)
output_dtype = rng.choice(wrong_dtypes)
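
On the generator side, every random BF16 tensor is produced by truncating fp32 noise (the DType.BF16 branch added above), so its elements always satisfy the bf16 validity check. A small numpy sketch under that assumption (seed illustrative):

import numpy as np
from generator.tosa_utils import float32_is_valid_bfloat16, vect_f32_to_bf16

rng = np.random.default_rng(0)                         # illustrative seed
t = vect_f32_to_bf16(np.float32(rng.random((2, 3))))   # mirrors the BF16 branch
assert t.dtype == np.float32
assert all(float32_is_valid_bfloat16(f) for f in t.flat)
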
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 104d9bb..d79ab3c 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -1,5 +1,9 @@
# Copyright (c) 2021-2022, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
+import struct
+import sys
+
+import numpy as np
from tosa.DType import DType
# Maximum dimension size for output and inputs for RESIZE
@@ -15,6 +19,7 @@ DTYPE_ATTRIBUTES = {
DType.INT32: {"str": "i32", "width": 32},
DType.INT48: {"str": "i48", "width": 48},
DType.FP16: {"str": "f16", "width": 16},
+ DType.BF16: {"str": "bf16", "width": 16},
DType.FP32: {"str": "f32", "width": 32},
}
@@ -125,7 +130,11 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.FP32,
DType.FP16,
)
- elif input_dtype == DType.FP32 or input_dtype == DType.FP16:
+ elif (
+ input_dtype == DType.FP32
+ or input_dtype == DType.FP16
+ or input_dtype == DType.BF16
+ ):
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -134,3 +143,37 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT48,
)
return rng.choice(a=incorrect_types)
+
+
+def float32_is_valid_bfloat16(f):
+ """Return True if float value is valid bfloat16."""
+ f32_bits = get_float32_bitstring(f)
+ return f32_bits[16:] == "0" * 16
+
+
+def get_float32_bitstring(f):
+ """Return a big-endian string of bits representing a 32 bit float."""
+ f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
+ return f"{f32_bits_as_int:032b}"
+
+
+def float32_to_bfloat16(f):
+    """Truncate an fp32 value to bfloat16.
+
+    Zeroes the 16 least significant bits of the input fp32 value and
+    returns the resulting bfloat16-valid number, still stored as fp32.
+    For simplicity, the bit string is assembled in big-endian order
+    regardless of system endianness; the result is then repacked using
+    the system's native byte order.
+ """
+ f32_bits = get_float32_bitstring(f)
+ f32_floored_bits = f32_bits[:16] + "0" * 16
+
+ # Assume sys.byteorder matches system's underlying float byteorder
+ fp_bytes = int(f32_floored_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0] # native byteorder
+
+
+vect_f32_to_bf16 = np.vectorize(
+ float32_to_bfloat16, otypes=(np.float32,)
+)  # np.vectorize is a convenience wrapper (essentially a Python-level loop), not a speed-up
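
The string slicing above is equivalent to masking the low 16 bits of the fp32 bit pattern; a minimal sketch of that equivalence, plus one worked value, assuming these helpers are importable:

import struct

from generator.tosa_utils import (
    float32_is_valid_bfloat16,
    float32_to_bfloat16,
    get_float32_bitstring,
)

def f32_to_bf16_via_mask(f):
    # Same truncation as float32_to_bfloat16, done with an integer mask
    (bits,) = struct.unpack("=I", struct.pack("=f", f))
    return struct.unpack("=f", struct.pack("=I", bits & 0xFFFF0000))[0]

x = 2.718281828
assert float32_to_bfloat16(x) == f32_to_bf16_via_mask(x)
assert float32_is_valid_bfloat16(float32_to_bfloat16(x))
assert get_float32_bitstring(1.0) == "0011111110000000" + "0" * 16  # 1.0 == 0x3F800000
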
diff --git a/verif/generator/tosa_verif_build_tests.py b/verif/generator/tosa_verif_build_tests.py
index 2fafacb..ab78b1a 100644
--- a/verif/generator/tosa_verif_build_tests.py
+++ b/verif/generator/tosa_verif_build_tests.py
@@ -5,6 +5,7 @@ import re
from generator.tosa_test_gen import TosaTestGen
from serializer.tosa_serializer import dtype_str_to_val
+from serializer.tosa_serializer import DTypeNames
# Used for parsing a comma-separated list of integers in a string
@@ -150,13 +151,14 @@ def parseArgs(argv):
help="Create tests with a particular input tensor rank",
)
+    # Each use of this flag appends one DType, parsed from its string name
parser.add_argument(
"--target-dtype",
dest="target_dtypes",
action="append",
default=None,
type=lambda x: dtype_str_to_val(x),
- help="Create test with a particular DType (may be repeated)",
+ help=f"Create test with a particular DType: [{', '.join([d.lower() for d in DTypeNames[1:]])}] (may be repeated)",
)
parser.add_argument(
diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py
index b608fd8..50ff1ab 100644
--- a/verif/tests/test_tosa_refmodel.py
+++ b/verif/tests/test_tosa_refmodel.py
@@ -47,6 +47,7 @@ REF_MODEL_TYPE_TO_OUT = {
"int32": "i32",
"fp32": "f32",
"fp16": "f16",
+ "bf16": "bf16",
}
@@ -127,11 +128,13 @@ TEST_PARAMS = [
("abs", "int32", 1),
("abs", "fp32", 1),
("abs", "fp16", 1),
+ ("abs", "bf16", 1),
("negate", "int8", 1),
("negate", "int16", 1),
("negate", "int32", 1),
("negate", "fp32", 1),
("negate", "fp16", 1),
+ ("negate", "bf16", 1),
# One test per axis (shape dimensions)
("concat", "bool", SHAPE_DIMS),
("concat", "int8", SHAPE_DIMS),
@@ -139,6 +142,7 @@ TEST_PARAMS = [
("concat", "int32", SHAPE_DIMS),
("concat", "fp32", SHAPE_DIMS),
("concat", "fp16", SHAPE_DIMS),
+ ("concat", "bf16", SHAPE_DIMS),
]
@@ -165,6 +169,9 @@ def test_refmodel_simple_op(tosaTest):
# Generate TOSA test(s) (mostly should be single test)
test_dirs = tosaTest.create_test()
+ # Indicate miscellaneous checks to run in tosa_check
+ misc_checks = []
+
for test_dir in test_dirs:
# Run ref model
desc_file = test_dir / TEST_DESC_FILENAME
@@ -227,8 +234,15 @@ def test_refmodel_simple_op(tosaTest):
np.save(str(result_file), result)
assert result_file.is_file()
+ # Ensure valid bf16
+ if tosaTest.ref_model_type == "bf16":
+ misc_checks.append("bf16")
+
# Check Numpy result versus refmodel
check_result, tolerance, msg = tosa_check(
- str(result_file), str(ofm_file), test_name=test_dir.name
+ str(result_file),
+ str(ofm_file),
+ test_name=test_dir.name,
+ misc_checks=misc_checks,
)
assert check_result == TosaResult.PASS