From 2c34b4616a10539211e7006bc43f3c71e86c30bb Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Tue, 6 Feb 2024 18:37:00 +0000 Subject: Add support for FP8 to reference model Signed-off-by: Won Jeon Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08 --- reference_model/include/dtype.h | 16 +- reference_model/include/types.h | 30 +- reference_model/src/arith_util.h | 6 +- reference_model/src/float_utils.h | 533 +++++++++++++++++++++++++ reference_model/src/generate/generate_utils.cc | 4 + reference_model/src/ops/data_layout.cc | 18 + reference_model/src/ops/data_nodes.cc | 4 +- reference_model/src/ops/op_factory.cc | 51 +++ reference_model/src/ops/scatter_gather.cc | 6 +- reference_model/src/ops/template_types.h | 24 +- reference_model/src/ops/tensor_ops.cc | 18 +- reference_model/src/ops/type_conversion.cc | 175 ++++++++ reference_model/src/ops/type_conversion.h | 278 ++++++++++++- reference_model/src/subgraph_traverser.cc | 18 +- reference_model/src/tensor.cc | 34 +- reference_model/src/tensor.h | 2 + reference_model/src/verify/verify_utils.cc | 15 +- 17 files changed, 1199 insertions(+), 33 deletions(-) create mode 100644 reference_model/src/float_utils.h (limited to 'reference_model') diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h index 1b01a0e..3e8bdf5 100644 --- a/reference_model/include/dtype.h +++ b/reference_model/include/dtype.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -41,6 +41,8 @@ enum TOSA_REF_TYPE : uint32_t TOSA_REF_TYPE_FP16 = 10, TOSA_REF_TYPE_BF16 = 11, TOSA_REF_TYPE_SHAPE = 12, + TOSA_REF_TYPE_FP8E4M3 = 13, + TOSA_REF_TYPE_FP8E5M2 = 14, TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above }; @@ -74,6 +76,10 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) return EnumNameDType(DType_BF16); case TOSA_REF_TYPE_SHAPE: return EnumNameDType(DType_SHAPE); + case TOSA_REF_TYPE_FP8E4M3: + return EnumNameDType(DType_FP8E4M3); + case TOSA_REF_TYPE_FP8E5M2: + return EnumNameDType(DType_FP8E5M2); case TOSA_REF_TYPE_FP64: return "FP64"; default: @@ -85,7 +91,7 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) // return corresponding TOSA_REF_TYPE for DType inline TOSA_REF_TYPE ConvertDType(const DType dtype) { - assert(DType_MAX == DType_SHAPE); // must update whenever DType_MAX changes + assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes if (g_func_config.precise_mode) { @@ -95,6 +101,8 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) case DType_FP16: case DType_FP32: case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: return TOSA_REF_TYPE_FP64; default: break; @@ -127,6 +135,10 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) return TOSA_REF_TYPE_BF16; case DType_SHAPE: return TOSA_REF_TYPE_SHAPE; + case DType_FP8E4M3: + return TOSA_REF_TYPE_FP8E4M3; + case DType_FP8E5M2: + return TOSA_REF_TYPE_FP8E5M2; default: break; } diff --git a/reference_model/include/types.h b/reference_model/include/types.h index 15ee40c..32a8ce1 100644 --- a/reference_model/include/types.h +++ b/reference_model/include/types.h @@ -26,19 +26,21 @@ extern "C" enum tosa_datatype_t { - tosa_datatype_bf16_t = 0, - tosa_datatype_bool_t = 1, - tosa_datatype_fp16_t = 2, - tosa_datatype_fp32_t = 3, - tosa_datatype_int16_t = 4, - tosa_datatype_int32_t = 5, - tosa_datatype_int48_t = 6, - tosa_datatype_int4_t = 7, - tosa_datatype_int8_t = 8, - tosa_datatype_uint16_t = 9, - 
tosa_datatype_uint8_t = 10, - tosa_datatype_shape_t = 11, - tosa_datatype_fp64_t = 99 + tosa_datatype_bf16_t = 0, + tosa_datatype_bool_t = 1, + tosa_datatype_fp16_t = 2, + tosa_datatype_fp32_t = 3, + tosa_datatype_int16_t = 4, + tosa_datatype_int32_t = 5, + tosa_datatype_int48_t = 6, + tosa_datatype_int4_t = 7, + tosa_datatype_int8_t = 8, + tosa_datatype_uint16_t = 9, + tosa_datatype_uint8_t = 10, + tosa_datatype_shape_t = 11, + tosa_datatype_fp8e4m3_t = 12, + tosa_datatype_fp8e5m2_t = 13, + tosa_datatype_fp64_t = 99 }; struct tosa_tensor_t @@ -61,4 +63,4 @@ extern "C" } #endif /* __cplusplus */ -#endif // TYPES_H_ \ No newline at end of file +#endif // TYPES_H_ diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h index fb491db..f0d184c 100644 --- a/reference_model/src/arith_util.h +++ b/reference_model/src/arith_util.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -35,10 +35,8 @@ #include "func_debug.h" #include "half.hpp" #include "inttypes.h" -#include #include #include -#include #include #include #include @@ -244,7 +242,7 @@ float fpTrunc(float f_in) // No-op for fp32 break; default: - ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-truncated.", EnumNameTOSAREFTYPE(Dtype)); + ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-cast.", EnumNameTOSAREFTYPE(Dtype)); } return f_in; } diff --git a/reference_model/src/float_utils.h b/reference_model/src/float_utils.h new file mode 100644 index 0000000..b98c89b --- /dev/null +++ b/reference_model/src/float_utils.h @@ -0,0 +1,533 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FLOAT_UTILS_H_ +#define FLOAT_UTILS_H_ + +#include +#include +#include +#include +#if defined(__cpp_lib_bit_cast) +#include +#endif // defined(__cpp_lib_bit_cast) + +namespace tosa::reference::internal +{ + +namespace float_support +{ + +struct hidden +{}; + +#if defined(__cpp_lib_bit_cast) +#define BITCAST_CONSTEXPR constexpr inline + +constexpr inline int32_t get_bits(const float& f) +{ + return std::bit_cast(f); +} +constexpr inline float from_bits(const int32_t& i) +{ + return std::bit_cast(i); +} + +#else +#define BITCAST_CONSTEXPR inline + +union ufloat32 +{ + constexpr ufloat32(const float& x) + : f(x) + {} + constexpr ufloat32(const int32_t& x) + : i(x) + {} + + float f; + int32_t i; +}; + +inline int32_t get_bits(const float& f) +{ + return ufloat32(f).i; +} +inline float from_bits(const int32_t& i) +{ + return ufloat32(i).f; +} +#endif + +} // namespace float_support + +template = true> +class float_t +{ + storage_t m_data = 0; + +public: + static constexpr size_t n_exponent_bits = n_exp_bits; + static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits; + static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1; + + /// \brief Construct a floating point type with the given bit + /// representation. + static constexpr float_t from_bits(storage_t bits) + { + return float_t(float_support::hidden(), bits); + } + + /// \brief Construct a float from the given sign, exponent and significand + /// bits. 
+ static constexpr float_t from_bits(bool pm, storage_t e, storage_t s) + { + storage_t bits = pm ? 1 : 0; + + bits <<= n_exp_bits; + bits |= e; + + bits <<= n_significand_bits; + if (with_denorm || e) + bits |= s; + + return float_t(float_support::hidden(), bits); + } + + /// \brief (Hidden) Construct a float type from a given bit pattern + constexpr float_t(const float_support::hidden&, storage_t bits) + : m_data(bits) + {} + + constexpr float_t() + : m_data(0) + {} + constexpr float_t(const float_t& other) + : m_data(other.m_data) + {} + + /// \brief Cast to a different floating point representation. + template + constexpr inline + operator float_t() const + { + using other_float_t = + float_t; + + // Shortcut for types which are fundamentally similar (e.g., bf16 -> + // fp32) + if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) && + has_nan == other_has_nan) + { + return other_float_t::from_bits(static_cast(m_data) + << (sizeof(other_storage_t) - sizeof(storage_t)) * 8); + } + + // Get initial values for the new floating point type + const bool sign_bit = m_data < 0; + int64_t new_exponent_bits = 0; + uint64_t new_significand = 0; + + if (is_nan() || is_infinity()) + { + new_exponent_bits = (1 << other_n_exp_bits) - 1; + + if (is_nan()) + { + if constexpr (other_has_infinity) + { + // Copy across the `not_quiet bit`; set the LSB. Don't + // attempt to copy across any of the rest of the payload. + new_significand = + 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits); + } + else + { + new_significand = (1ul << other_float_t::n_significand_bits) - 1; + } + } + else if constexpr (!other_has_infinity) + { + new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 
2 : 1); + } + } + else if (!is_zero()) + { + const int64_t this_exponent_bits = exponent_bits(); + { + constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias; + new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1); + } + new_significand = this->significand() << (64 - n_significand_bits); + + // Normalise subnormals + if (this_exponent_bits == 0) + { + // Shift the most-significant 1 out of the magnitude to convert + // it to a significand. Modify the exponent accordingly. + uint8_t shift = __builtin_clzl(new_significand) + 1; + new_exponent_bits -= shift; + new_significand <<= shift; + } + + // Align the significand for the output type + uint32_t shift = 64 - other_float_t::n_significand_bits; + const bool other_is_subnormal = new_exponent_bits <= 0; + if (other_is_subnormal) + { + shift += 1 - new_exponent_bits; + new_exponent_bits = 0; + } + + const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1); + new_significand = shift == 64 ? 0 : new_significand >> shift; + + // Reinsert the most-significant-one if this is a subnormal in the + // output type. + new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift); + + // Apply rounding based on the bits shifted out of the significand + const uint64_t shift_half = 1ll << (shift - 1); + if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1))) + { + new_significand += 1; + + // Handle the case that the significand overflowed due to + // rounding + constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1; + if (new_significand > max_significand) + { + new_significand = 0; + new_exponent_bits++; + } + } + + // Saturate to infinity if the exponent is larger than can be + // represented in the output type. This can only occur if the size + // of the exponent of the new type is not greater than the exponent + // of the old type. 
+ if constexpr (other_n_exp_bits <= n_exp_bits) + { + constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1; + if (new_exponent_bits >= inf_exp_bits) + { + new_exponent_bits = inf_exp_bits; + new_significand = + other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); + } + } + } + + return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand); + } + + /// \brief Convert from a 32-bit floating point value + BITCAST_CONSTEXPR + float_t(const float& f) + { + // If this format exactly represents the binary32 format then get + // the bits from the provided float; otherwise get a binary32 + // representation and then convert to this format. + if constexpr (represents_binary32()) + m_data = float_support::get_bits(f); + else + m_data = static_cast>( + static_cast>(f)) + .m_data; + } + + /// \brief Cast to a 32-bit floating point value + BITCAST_CONSTEXPR operator float() const + { + // If this format exactly represents the binary32 format then return + // a float; otherwise get a binary32 representation and then return + // a float. 
+ if constexpr (represents_binary32()) + return float_support::from_bits(m_data); + else + return static_cast(this->operator float_t()); + } + + /// \brief Return whether this type represents the IEEE754 binary32 + /// format + constexpr static inline bool represents_binary32() + { + return std::is_same_v && n_exp_bits == 8 && has_nan && with_denorm && with_infinity; + } + + constexpr auto operator-() const + { + return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1))); + } + + constexpr bool is_subnormal() const + { + return exponent_bits() == 0 && significand() != 0; + } + + constexpr bool is_zero() const + { + return exponent_bits() == 0 && significand() == 0; + } + + constexpr bool is_nan() const + { + return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) && + ((with_infinity && significand()) || + (!with_infinity && significand() == (1ul << n_significand_bits) - 1)); + } + + constexpr bool is_infinity() const + { + return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand()); + } + + constexpr inline const storage_t& bits() const + { + return m_data; + } + + /// \brief Get the exponent + constexpr inline int64_t exponent() const + { + return std::max(exponent_bits(), 1ul) - exponent_bias; + } + + /// \brief Get the bits from the exponent + constexpr inline uint64_t exponent_bits() const + { + constexpr uint64_t mask = (1ul << n_exp_bits) - 1; + return (m_data >> n_significand_bits) & mask; + } + + constexpr inline uint64_t significand() const + { + return m_data & ((1ul << n_significand_bits) - 1); + } + + constexpr inline bool operator==(const float_t& other) const + { + return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits()); + } + + constexpr inline float_t& operator+=(const float_t& rhs) + { + this->m_data = static_cast(static_cast(*this) + static_cast(rhs)).bits(); + return *this; + } +}; + +// This should probably be exported so we can use it elsewhere +#undef 
BITCAST_CONSTEXPR + +namespace float_support +{ + +// Pre-C++23 these can't be computed as constexpr, so have to hardcode them + +template +struct digits10; // floor(log10(2) * (digits - 1) +template +struct max_digits10; // ceil(log10(2) * digits + 1) +template +struct min_exponent10; // floor(log10(2) * min_exponent) +template +struct max_exponent10; // floor(log10(2) * max_exponent) + +template <> +struct digits10<8> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<8> +{ + constexpr static inline int value = 4; +}; + +template <> +struct digits10<10> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<10> +{ + constexpr static inline int value = 5; +}; + +template <> +struct digits10<24> +{ + constexpr static inline int value = 6; +}; + +template <> +struct max_digits10<24> +{ + constexpr static inline int value = 9; +}; + +template <> +struct min_exponent10<-13> +{ + constexpr static inline int value = -3; +}; + +template <> +struct max_exponent10<16> +{ + constexpr static inline int value = 4; +}; + +template <> +struct min_exponent10<-125> +{ + constexpr static inline int value = -37; +}; + +template <> +struct max_exponent10<128> +{ + constexpr static inline int value = 38; +}; + +template +inline constexpr int digits10_v = digits10::value; +template +inline constexpr int max_digits10_v = max_digits10::value; + +template +inline constexpr int min_exponent10_v = min_exponent10::value; + +template +inline constexpr int max_exponent10_v = max_exponent10::value; + +} // namespace float_support + +} // namespace tosa::reference::internal + +namespace std +{ + +template +struct is_floating_point> + : std::integral_constant +{}; + +template +class numeric_limits> +{ + using this_float_t = tosa::reference::internal::float_t; + +public: + static constexpr bool is_specialized = true; + + static constexpr auto min() noexcept + { + return this_float_t::from_bits(false, 1, 0); + } + + static constexpr 
auto max() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2, + (1 << this_float_t::n_significand_bits) - 1); + } + + static constexpr auto lowest() noexcept + { + return -max(); + } + + static constexpr int digits = this_float_t::n_significand_bits + 1; + static constexpr int digits10 = tosa::reference::internal::float_support::digits10_v; + static constexpr int max_digits10 = tosa::reference::internal::float_support::max_digits10_v; + + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr int radix = 2; + + static constexpr auto epsilon() noexcept + { + return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0); + } + + static constexpr auto round_error() noexcept + { + return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0); + } + + static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1; + static constexpr int min_exponent10 = tosa::reference::internal::float_support::min_exponent10_v; + static constexpr int max_exponent = this_float_t::exponent_bias + 1; + static constexpr int max_exponent10 = tosa::reference::internal::float_support::max_exponent10_v; + + static constexpr bool has_infinity = with_inf; + static constexpr bool has_quiet_NaN = has_nan; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = with_denorm ? 
denorm_present : denorm_absent;
+    static constexpr bool has_denorm_loss = false;
+
+    static constexpr auto infinity() noexcept
+    {
+        if constexpr (with_inf)
+        {
+            return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0);
+        }
+        else
+        {
+            return this_float_t::from_bits(false, 0, 0);
+        }
+    }
+
+    static constexpr auto quiet_NaN() noexcept
+    {
+        return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1,
+                                       1 << (this_float_t::n_significand_bits - 1) | 1);
+    }
+
+    static constexpr auto signaling_NaN() noexcept
+    {
+        return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1);
+    }
+
+    static constexpr auto denorm_min() noexcept
+    {
+        return this_float_t::from_bits(false, 0, 1);
+    }
+
+    static constexpr bool is_iec559 = false;
+    static constexpr bool is_bounded = false;
+    static constexpr bool is_modulo = false;
+
+    static constexpr bool traps = false;
+    static constexpr bool tinyness_before = false;
+    static constexpr float_round_style round_style = round_to_nearest;
+};
+
+} // namespace std
+
+#endif // FLOAT_UTILS_H_
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index 8b16e97..271b7f5 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -34,6 +34,8 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType,
                                  { DType::DType_BF16, "BF16" },
                                  { DType::DType_FP32, "FP32" },
                                  { DType::DType_SHAPE, "SHAPE" },
+                                 { DType::DType_FP8E4M3, "FP8E4M3" },
+                                 { DType::DType_FP8E5M2, "FP8E5M2" },
                              })
 
 NLOHMANN_JSON_SERIALIZE_ENUM(Op,
@@ -225,6 +227,8 @@ size_t elementSizeFromType(DType type)
         case DType::DType_BOOL:
         case DType::DType_UINT8:
         case DType::DType_INT8:
+        case DType::DType_FP8E4M3:
+        case DType::DType_FP8E5M2:
             return 1;
         case DType::DType_UINT16:
         case DType::DType_INT16:
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index ec9614a..ddf0713 100644
--- 
a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -759,6 +759,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); @@ -768,6 +770,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BF16); @@ -776,6 +780,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2); DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); DEF_INSTANTIATE_RESHAPE(OpReshape, BF16); @@ -785,6 +791,8 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); DEF_INSTANTIATE_RESHAPE(OpReshape, FP64); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E4M3); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); @@ -794,6 +802,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); 
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); @@ -803,6 +813,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16); @@ -812,6 +824,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16); @@ -821,6 +835,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); @@ -830,3 +846,5 @@ 
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index 705981c..64001a9 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -105,3 +105,5 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index af8332e..6d66c07 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -55,6 +55,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2); break; case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); @@ -64,6 +66,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, 
INT8, INT32);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64);
+            DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E4M3, FP16);
+            DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16);
             break;
         case Op_CONV2D:
             DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -74,6 +78,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
             DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
             DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
             DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
+            DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
             break;
         case Op_CONV3D:
             DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
@@ -84,6 +91,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
             DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
             DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
             DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
+            DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
             break;
         case Op_DEPTHWISE_CONV2D:
             DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
@@ -94,6 +103,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
             DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
             DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
             DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
+            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
             break;
         case Op_FFT2D:
             DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
@@ -117,6 +128,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
             DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32);
             DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48);
             DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64);
+            DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, 
FP8E4M3, FP16); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E5M2, FP16); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); @@ -125,6 +138,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E4M3); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E5M2); break; case Op_RFFT2D: DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32); @@ -139,6 +154,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); break; // activation_funcs @@ -409,6 +426,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2); break; case Op_PAD: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); @@ -419,6 +438,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2); break; case Op_DIM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16); @@ -428,6 +449,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL); + 
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2); break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); @@ -438,6 +461,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RESHAPE(OpReshape, INT32); DEF_FACTORY_RESHAPE(OpReshape, BOOL); DEF_FACTORY_RESHAPE(OpReshape, FP64); + DEF_FACTORY_RESHAPE(OpReshape, FP8E4M3); + DEF_FACTORY_RESHAPE(OpReshape, FP8E5M2); break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); @@ -448,6 +473,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2); break; case Op_SLICE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); @@ -458,6 +485,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2); break; case Op_TILE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); @@ -468,6 +497,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2); break; case Op_TRANSPOSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); @@ -478,6 +509,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); 
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2); break; // scatter_gather @@ -489,6 +522,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, BF16); DEF_FACTORY_ONE_TYPE(OpGather, FP32); DEF_FACTORY_ONE_TYPE(OpGather, FP64); + DEF_FACTORY_ONE_TYPE(OpGather, FP8E4M3); + DEF_FACTORY_ONE_TYPE(OpGather, FP8E5M2); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); @@ -498,6 +533,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpScatter, BF16); DEF_FACTORY_ONE_TYPE(OpScatter, FP32); DEF_FACTORY_ONE_TYPE(OpScatter, FP64); + DEF_FACTORY_ONE_TYPE(OpScatter, FP8E4M3); + DEF_FACTORY_ONE_TYPE(OpScatter, FP8E5M2); break; // image @@ -524,6 +561,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2); break; // type_conversion @@ -569,6 +608,18 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16); + 
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2); break; case Op_RESCALE: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index bd16ad1..85397ae 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -236,6 +236,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E4M3); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E5M2); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); @@ -244,3 +246,5 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E4M3); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E5M2); diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 342d5c2..41e6061 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -88,6 +88,18 @@ struct GetEigenType using type = float; }; template <> +struct GetEigenType +{ + // NOTE: full precision used + using type = float; +}; +template <> +struct GetEigenType +{ + // NOTE: full precision used + using type = float; +}; +template <> struct GetEigenType { using type = int32_t; @@ -200,6 +212,16 @@ struct GetNumBits { static constexpr int32_t value = 16; }; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 8; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 8; +}; // Meta function to get quantized min/max in compile time template diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index dd66f79..124dc87 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -555,7 +555,7 @@ int OpAvgPool2d::eval() } } if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 && - Dtype != TOSA_REF_TYPE_FP64) + Dtype != TOSA_REF_TYPE_FP64 && Dtype != TOSA_REF_TYPE_FP8E4M3 && Dtype != TOSA_REF_TYPE_FP8E5M2) { try { @@ -2155,6 +2155,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32); @@ -2163,6 +2165,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16); // [in_t, weight_t, out_t] DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, 
FP16, FP16); @@ -2173,6 +2177,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32); @@ -2182,6 +2188,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); @@ -2191,6 +2199,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64); @@ -2211,6 +2221,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E4M3, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E5M2, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16); @@ -2218,6 +2230,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); 
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E4M3); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64); @@ -2230,3 +2244,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 484f768..5dbc7bd 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -15,6 +15,7 @@ #include "type_conversion.h" #include "arith_util.h" +#include "float_utils.h" #include "half.hpp" #include "quant_util.h" #include "template_types.h" @@ -24,6 +25,12 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; +using fp16 = tosa::reference::internal::float_t; +using bf16 = tosa::reference::internal::float_t; +using fp32 = tosa::reference::internal::float_t; +using fp8e4m3 = tosa::reference::internal::float_t; +using fp8e5m2 = tosa::reference::internal::float_t; + template OpRescale::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_RESCALE, id_) @@ -526,6 +533,162 @@ CastHelper::CastHelper() }; } +template +CastHelper::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to integer + fcn = [](float in) -> OutEigenType { + if (in >= float(OutMax)) + return OutMax; + if (in <= float(OutMin)) + return OutMin; + + OutEigenType out = std::rint(in); + return out; + }; +} + +CastHelper::CastHelper() +{ 
+ // fp8e4m3 data (stored as fp32) converted to fp16 (stored as fp32) + fcn = [](float in) -> float { + half_float::half h = half_float::half(in); + float out = half_float::half_cast(h); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to bf16 (stored as fp32) + fcn = [](float in) -> float { return (float)in; }; +} + +CastHelper::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to fp32 + fcn = [](InEigenType in) -> OutEigenType { return in; }; +} + +template +CastHelper::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to integer + fcn = [](float in) -> OutEigenType { + if (in >= float(OutMax)) + return OutMax; + if (in <= float(OutMin)) + return OutMin; + + OutEigenType out = std::rint(in); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to fp16 (stored as fp32) + fcn = [](float in) -> float { + half_float::half h = half_float::half(in); + float out = half_float::half_cast(h); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to bf16 (stored as fp32) + fcn = [](float in) -> float { return (float)in; }; +} + +CastHelper::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to fp32 + fcn = [](InEigenType in) -> OutEigenType { return in; }; +} + +template +CastHelper::CastHelper() +{ + // Integer data converted to fp8e4m3 (stored as fp32) + fcn = [](InEigenType in) -> float { + auto f = static_cast(static_cast(float(in))); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp16 data (stored as fp32) converted to fp8e4m3 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // bf16 data (stored as fp32) converted to fp8e4m3 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = 
static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp32 data converted to fp8e4m3 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + +template +CastHelper::CastHelper() +{ + // Integer data converted to fp8e5m2 (stored as fp32) + fcn = [](InEigenType in) -> float { + auto f = static_cast(static_cast(float(in))); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp16 data (stored as fp32) converted to fp8e5m2 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // bf16 data (stored as fp32) converted to fp8e5m2 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp32 data converted to fp8e5m2 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + template CastHelper::CastHelper() { @@ -597,6 +760,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32); 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 98799a0..75f244d 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -276,6 +276,282 @@ private: FcnType fcn; }; +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + static constexpr int32_t OutMin = GetQMin::value; + static constexpr int32_t OutMax = GetQMax::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename 
GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + static constexpr int32_t OutMin = GetQMin::value; + static constexpr int32_t OutMax = GetQMax::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + 
using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template class CastHelper { diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index fae0b30..33a9b94 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. 
+// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -595,6 +595,22 @@ int SubgraphTraverser::allocateTensor(std::string name) } } break; + case DType_FP8E4M3: + case DType_FP8E5M2: { + std::vector fp32_data; + TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + // Ensure valid fp8 stored in each float + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } + } + break; case DType_FP32: { std::vector fp32_data; TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index f9ec937..27f21f3 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -115,12 +115,15 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) assert(dtype == ConvertDType(serialization_dtype)); // if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16 assert(dtype != TOSA_REF_TYPE_FP64 || serialization_dtype == DType_FP32 || serialization_dtype == DType_FP16 || - serialization_dtype == DType_BF16); + serialization_dtype == DType_BF16 || serialization_dtype == DType_FP8E4M3 || + serialization_dtype == DType_FP8E5M2); switch (serialization_dtype) { case DType_FP32: case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); @@ -208,6 +211,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) return 1; } break; + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: + if (setTensorValueFloat(elements, f32databuf)) + { + free(f32databuf); + return 1; + } + 
break; case TOSA_REF_TYPE_FP32: if (setTensorValueFloat(elements, f32databuf)) { @@ -276,6 +287,23 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) return 1; } break; + case DType_FP8E4M3: + case DType_FP8E5M2: + // FP8E4M3 -> FP64 + f64databuf = (double*)calloc(sizeof(double), elements); + ASSERT_MEM(f64databuf); + for (uint32_t i = 0; i < elements; i++) + { + //ASSERT_MSG(checkValidFloat8(f32databuf[i]), "Input float value not a valid float8 value."); + f64databuf[i] = static_cast(f32databuf[i]); + } + if (setTensorValueDouble(elements, f64databuf)) + { + free(f32databuf); + free(f64databuf); + return 1; + } + break; case DType_FP32: // FP32 -> FP64 f64databuf = (double*)calloc(sizeof(double), elements); @@ -349,6 +377,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(f32databuf); break; case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); @@ -631,6 +661,8 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) // continue with setting float vals in the tensor case TOSA_REF_TYPE_FP16: case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 2c3be7f..1659a2f 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -863,6 +863,8 @@ public: case TOSA_REF_TYPE_FP32: case TOSA_REF_TYPE_FP16: case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: switch (rank) { case 0: diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc index 14bc6f1..4ae245b 100644 --- 
a/reference_model/src/verify/verify_utils.cc +++ b/reference_model/src/verify/verify_utils.cc @@ -36,6 +36,8 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType, { DType::DType_FP16, "FP16" }, { DType::DType_BF16, "BF16" }, { DType::DType_FP32, "FP32" }, + { DType::DType_FP8E4M3, "FP8E4M3" }, + { DType::DType_FP8E5M2, "FP8E5M2" }, }) } // namespace tosa @@ -177,12 +179,13 @@ std::string positionToString(const std::vector& pos) DType mapToDType(tosa_datatype_t dataType) { static std::map typeMap = { - { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 }, - { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 }, - { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 }, - { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 }, - { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 }, - { tosa_datatype_shape_t, DType_SHAPE }, + { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 }, + { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 }, + { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 }, + { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 }, + { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 }, + { tosa_datatype_shape_t, DType_SHAPE }, { tosa_datatype_fp8e4m3_t, DType_FP8E4M3 }, + { tosa_datatype_fp8e5m2_t, DType_FP8E5M2 }, }; if (typeMap.count(dataType)) -- cgit v1.2.1