From 2c34b4616a10539211e7006bc43f3c71e86c30bb Mon Sep 17 00:00:00 2001 From: Won Jeon Date: Tue, 6 Feb 2024 18:37:00 +0000 Subject: Add support for FP8 to reference model Signed-off-by: Won Jeon Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08 --- reference_model/include/dtype.h | 16 +- reference_model/include/types.h | 30 +- reference_model/src/arith_util.h | 6 +- reference_model/src/float_utils.h | 533 ++++++++++++++++++++++ reference_model/src/generate/generate_utils.cc | 4 + reference_model/src/ops/data_layout.cc | 18 + reference_model/src/ops/data_nodes.cc | 4 +- reference_model/src/ops/op_factory.cc | 51 +++ reference_model/src/ops/scatter_gather.cc | 6 +- reference_model/src/ops/template_types.h | 24 +- reference_model/src/ops/tensor_ops.cc | 18 +- reference_model/src/ops/type_conversion.cc | 175 +++++++ reference_model/src/ops/type_conversion.h | 278 ++++++++++- reference_model/src/subgraph_traverser.cc | 18 +- reference_model/src/tensor.cc | 34 +- reference_model/src/tensor.h | 2 + reference_model/src/verify/verify_utils.cc | 15 +- thirdparty/serialization_lib | 2 +- verif/checker/tosa_result_checker.py | 13 + verif/conformance/tosa_main_profile_ops_info.json | 448 ++++++++++++++++++ verif/generator/tosa_arg_gen.py | 60 ++- verif/generator/tosa_error_if.py | 72 ++- verif/generator/tosa_test_gen.py | 95 +++- verif/generator/tosa_utils.py | 42 ++ 24 files changed, 1902 insertions(+), 62 deletions(-) create mode 100644 reference_model/src/float_utils.h diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h index 1b01a0e..3e8bdf5 100644 --- a/reference_model/include/dtype.h +++ b/reference_model/include/dtype.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -41,6 +41,8 @@ enum TOSA_REF_TYPE : uint32_t TOSA_REF_TYPE_FP16 = 10, TOSA_REF_TYPE_BF16 = 11, TOSA_REF_TYPE_SHAPE = 12, + TOSA_REF_TYPE_FP8E4M3 = 13, + TOSA_REF_TYPE_FP8E5M2 = 14, TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above }; @@ -74,6 +76,10 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) return EnumNameDType(DType_BF16); case TOSA_REF_TYPE_SHAPE: return EnumNameDType(DType_SHAPE); + case TOSA_REF_TYPE_FP8E4M3: + return EnumNameDType(DType_FP8E4M3); + case TOSA_REF_TYPE_FP8E5M2: + return EnumNameDType(DType_FP8E5M2); case TOSA_REF_TYPE_FP64: return "FP64"; default: @@ -85,7 +91,7 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) // return corresponding TOSA_REF_TYPE for DType inline TOSA_REF_TYPE ConvertDType(const DType dtype) { - assert(DType_MAX == DType_SHAPE); // must update whenever DType_MAX changes + assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes if (g_func_config.precise_mode) { @@ -95,6 +101,8 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) case DType_FP16: case DType_FP32: case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: return TOSA_REF_TYPE_FP64; default: break; @@ -127,6 +135,10 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) return TOSA_REF_TYPE_BF16; case DType_SHAPE: return TOSA_REF_TYPE_SHAPE; + case DType_FP8E4M3: + return TOSA_REF_TYPE_FP8E4M3; + case DType_FP8E5M2: + return TOSA_REF_TYPE_FP8E5M2; default: break; } diff --git a/reference_model/include/types.h b/reference_model/include/types.h index 15ee40c..32a8ce1 100644 --- a/reference_model/include/types.h +++ b/reference_model/include/types.h @@ -26,19 +26,21 @@ extern "C" enum tosa_datatype_t { - tosa_datatype_bf16_t = 0, - tosa_datatype_bool_t = 1, - tosa_datatype_fp16_t = 2, - tosa_datatype_fp32_t = 3, - tosa_datatype_int16_t = 4, - tosa_datatype_int32_t = 5, - tosa_datatype_int48_t = 6, - tosa_datatype_int4_t = 7, - tosa_datatype_int8_t = 8, - tosa_datatype_uint16_t = 9, - tosa_datatype_uint8_t = 10, - tosa_datatype_shape_t = 11, - tosa_datatype_fp64_t = 99 + tosa_datatype_bf16_t = 0, + tosa_datatype_bool_t = 1, + tosa_datatype_fp16_t = 2, + tosa_datatype_fp32_t = 3, + tosa_datatype_int16_t = 4, + tosa_datatype_int32_t = 5, + tosa_datatype_int48_t = 6, + tosa_datatype_int4_t = 7, + tosa_datatype_int8_t = 8, + tosa_datatype_uint16_t = 9, + tosa_datatype_uint8_t = 10, + tosa_datatype_shape_t = 11, + tosa_datatype_fp8e4m3_t = 12, + tosa_datatype_fp8e5m2_t = 13, + tosa_datatype_fp64_t = 99 }; struct tosa_tensor_t @@ -61,4 +63,4 @@ extern "C" } #endif /* __cplusplus */ -#endif // TYPES_H_ \ No newline at end of file +#endif // TYPES_H_ diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h index fb491db..f0d184c 100644 --- a/reference_model/src/arith_util.h +++ b/reference_model/src/arith_util.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
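For orientation, a minimal sketch of how the updated ConvertDType in dtype.h behaves for the new FP8 dtypes. Only the enum values and function visible in the hunk above are assumed; the example() wrapper is hypothetical.

#include "dtype.h"

void example()
{
    // Normal mode: each FP8 dtype maps to its own reference type.
    TOSA_REF_TYPE a = ConvertDType(DType_FP8E4M3);    // TOSA_REF_TYPE_FP8E4M3
    TOSA_REF_TYPE b = ConvertDType(DType_FP8E5M2);    // TOSA_REF_TYPE_FP8E5M2

    // When g_func_config.precise_mode is set, both instead map to
    // TOSA_REF_TYPE_FP64, matching the FP16/BF16/FP32 handling above.
    (void)a;
    (void)b;
}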
@@ -35,10 +35,8 @@ #include "func_debug.h" #include "half.hpp" #include "inttypes.h" -#include #include #include -#include #include #include #include @@ -244,7 +242,7 @@ float fpTrunc(float f_in) // No-op for fp32 break; default: - ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-truncated.", EnumNameTOSAREFTYPE(Dtype)); + ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-cast.", EnumNameTOSAREFTYPE(Dtype)); } return f_in; } diff --git a/reference_model/src/float_utils.h b/reference_model/src/float_utils.h new file mode 100644 index 0000000..b98c89b --- /dev/null +++ b/reference_model/src/float_utils.h @@ -0,0 +1,533 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FLOAT_UTILS_H_ +#define FLOAT_UTILS_H_ + +#include +#include +#include +#include +#if defined(__cpp_lib_bit_cast) +#include +#endif // defined(__cpp_lib_bit_cast) + +namespace tosa::reference::internal +{ + +namespace float_support +{ + +struct hidden +{}; + +#if defined(__cpp_lib_bit_cast) +#define BITCAST_CONSTEXPR constexpr inline + +constexpr inline int32_t get_bits(const float& f) +{ + return std::bit_cast(f); +} +constexpr inline float from_bits(const int32_t& i) +{ + return std::bit_cast(i); +} + +#else +#define BITCAST_CONSTEXPR inline + +union ufloat32 +{ + constexpr ufloat32(const float& x) + : f(x) + {} + constexpr ufloat32(const int32_t& x) + : i(x) + {} + + float f; + int32_t i; +}; + +inline int32_t get_bits(const float& f) +{ + return ufloat32(f).i; +} +inline float from_bits(const int32_t& i) +{ + return ufloat32(i).f; +} +#endif + +} // namespace float_support + +template = true> +class float_t +{ + storage_t m_data = 0; + +public: + static constexpr size_t n_exponent_bits = n_exp_bits; + static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits; + static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1; + + /// \brief Construct a floating point type with the given bit + /// representation. + static constexpr float_t from_bits(storage_t bits) + { + return float_t(float_support::hidden(), bits); + } + + /// \brief Construct a float from the given sign, exponent and significand + /// bits. + static constexpr float_t from_bits(bool pm, storage_t e, storage_t s) + { + storage_t bits = pm ? 1 : 0; + + bits <<= n_exp_bits; + bits |= e; + + bits <<= n_significand_bits; + if (with_denorm || e) + bits |= s; + + return float_t(float_support::hidden(), bits); + } + + /// \brief (Hidden) Construct a float type from a given bit pattern + constexpr float_t(const float_support::hidden&, storage_t bits) + : m_data(bits) + {} + + constexpr float_t() + : m_data(0) + {} + constexpr float_t(const float_t& other) + : m_data(other.m_data) + {} + + /// \brief Cast to a different floating point representation. 
+ template + constexpr inline + operator float_t() const + { + using other_float_t = + float_t; + + // Shortcut for types which are fundamentally similar (e.g., bf16 -> + // fp32) + if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) && + has_nan == other_has_nan) + { + return other_float_t::from_bits(static_cast(m_data) + << (sizeof(other_storage_t) - sizeof(storage_t)) * 8); + } + + // Get initial values for the new floating point type + const bool sign_bit = m_data < 0; + int64_t new_exponent_bits = 0; + uint64_t new_significand = 0; + + if (is_nan() || is_infinity()) + { + new_exponent_bits = (1 << other_n_exp_bits) - 1; + + if (is_nan()) + { + if constexpr (other_has_infinity) + { + // Copy across the `not_quiet bit`; set the LSB. Don't + // attempt to copy across any of the rest of the payload. + new_significand = + 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits); + } + else + { + new_significand = (1ul << other_float_t::n_significand_bits) - 1; + } + } + else if constexpr (!other_has_infinity) + { + new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); + } + } + else if (!is_zero()) + { + const int64_t this_exponent_bits = exponent_bits(); + { + constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias; + new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1); + } + new_significand = this->significand() << (64 - n_significand_bits); + + // Normalise subnormals + if (this_exponent_bits == 0) + { + // Shift the most-significant 1 out of the magnitude to convert + // it to a significand. Modify the exponent accordingly. + uint8_t shift = __builtin_clzl(new_significand) + 1; + new_exponent_bits -= shift; + new_significand <<= shift; + } + + // Align the significand for the output type + uint32_t shift = 64 - other_float_t::n_significand_bits; + const bool other_is_subnormal = new_exponent_bits <= 0; + if (other_is_subnormal) + { + shift += 1 - new_exponent_bits; + new_exponent_bits = 0; + } + + const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1); + new_significand = shift == 64 ? 0 : new_significand >> shift; + + // Reinsert the most-significant-one if this is a subnormal in the + // output type. + new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift); + + // Apply rounding based on the bits shifted out of the significand + const uint64_t shift_half = 1ll << (shift - 1); + if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1))) + { + new_significand += 1; + + // Handle the case that the significand overflowed due to + // rounding + constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1; + if (new_significand > max_significand) + { + new_significand = 0; + new_exponent_bits++; + } + } + + // Saturate to infinity if the exponent is larger than can be + // represented in the output type. This can only occur if the size + // of the exponent of the new type is not greater than the exponent + // of the old type. + if constexpr (other_n_exp_bits <= n_exp_bits) + { + constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1; + if (new_exponent_bits >= inf_exp_bits) + { + new_exponent_bits = inf_exp_bits; + new_significand = + other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 
2 : 1); + } + } + } + + return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand); + } + + /// \brief Convert from a 32-bit floating point value + BITCAST_CONSTEXPR + float_t(const float& f) + { + // If this format exactly represents the binary32 format then get + // the bits from the provided float; otherwise get a binary32 + // representation and then convert to this format. + if constexpr (represents_binary32()) + m_data = float_support::get_bits(f); + else + m_data = static_cast>( + static_cast>(f)) + .m_data; + } + + /// \brief Cast to a 32-bit floating point value + BITCAST_CONSTEXPR operator float() const + { + // If this format exactly represents the binary32 format then return + // a float; otherwise get a binary32 representation and then return + // a float. + if constexpr (represents_binary32()) + return float_support::from_bits(m_data); + else + return static_cast(this->operator float_t()); + } + + /// \brief Return whether this type represents the IEEE754 binary32 + /// format + constexpr static inline bool represents_binary32() + { + return std::is_same_v && n_exp_bits == 8 && has_nan && with_denorm && with_infinity; + } + + constexpr auto operator-() const + { + return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1))); + } + + constexpr bool is_subnormal() const + { + return exponent_bits() == 0 && significand() != 0; + } + + constexpr bool is_zero() const + { + return exponent_bits() == 0 && significand() == 0; + } + + constexpr bool is_nan() const + { + return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) && + ((with_infinity && significand()) || + (!with_infinity && significand() == (1ul << n_significand_bits) - 1)); + } + + constexpr bool is_infinity() const + { + return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand()); + } + + constexpr inline const storage_t& bits() const + { + return m_data; + } + + /// \brief Get the exponent + constexpr inline int64_t exponent() const + { + return std::max(exponent_bits(), 1ul) - exponent_bias; + } + + /// \brief Get the bits from the exponent + constexpr inline uint64_t exponent_bits() const + { + constexpr uint64_t mask = (1ul << n_exp_bits) - 1; + return (m_data >> n_significand_bits) & mask; + } + + constexpr inline uint64_t significand() const + { + return m_data & ((1ul << n_significand_bits) - 1); + } + + constexpr inline bool operator==(const float_t& other) const + { + return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits()); + } + + constexpr inline float_t& operator+=(const float_t& rhs) + { + this->m_data = static_cast(static_cast(*this) + static_cast(rhs)).bits(); + return *this; + } +}; + +// This should probably be exported so we can use it elsewhere +#undef BITCAST_CONSTEXPR + +namespace float_support +{ + +// Pre-C++23 these can't be computed as constexpr, so have to hardcode them + +template +struct digits10; // floor(log10(2) * (digits - 1) +template +struct max_digits10; // ceil(log10(2) * digits + 1) +template +struct min_exponent10; // floor(log10(2) * min_exponent) +template +struct max_exponent10; // floor(log10(2) * max_exponent) + +template <> +struct digits10<8> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<8> +{ + constexpr static inline int value = 4; +}; + +template <> +struct digits10<10> +{ + constexpr static inline int value = 2; +}; + +template <> +struct max_digits10<10> +{ + constexpr static inline int value = 5; +}; + +template <> +struct 
digits10<24> +{ + constexpr static inline int value = 6; +}; + +template <> +struct max_digits10<24> +{ + constexpr static inline int value = 9; +}; + +template <> +struct min_exponent10<-13> +{ + constexpr static inline int value = -3; +}; + +template <> +struct max_exponent10<16> +{ + constexpr static inline int value = 4; +}; + +template <> +struct min_exponent10<-125> +{ + constexpr static inline int value = -37; +}; + +template <> +struct max_exponent10<128> +{ + constexpr static inline int value = 38; +}; + +template +inline constexpr int digits10_v = digits10::value; +template +inline constexpr int max_digits10_v = max_digits10::value; + +template +inline constexpr int min_exponent10_v = min_exponent10::value; + +template +inline constexpr int max_exponent10_v = max_exponent10::value; + +} // namespace float_support + +} // namespace tosa::reference::internal + +namespace std +{ + +template +struct is_floating_point> + : std::integral_constant +{}; + +template +class numeric_limits> +{ + using this_float_t = tosa::reference::internal::float_t; + +public: + static constexpr bool is_specialized = true; + + static constexpr auto min() noexcept + { + return this_float_t::from_bits(false, 1, 0); + } + + static constexpr auto max() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2, + (1 << this_float_t::n_significand_bits) - 1); + } + + static constexpr auto lowest() noexcept + { + return -max(); + } + + static constexpr int digits = this_float_t::n_significand_bits + 1; + static constexpr int digits10 = tosa::reference::internal::float_support::digits10_v; + static constexpr int max_digits10 = tosa::reference::internal::float_support::max_digits10_v; + + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr int radix = 2; + + static constexpr auto epsilon() noexcept + { + return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0); + } + + static constexpr auto round_error() noexcept + { + return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0); + } + + static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1; + static constexpr int min_exponent10 = tosa::reference::internal::float_support::min_exponent10_v; + static constexpr int max_exponent = this_float_t::exponent_bias + 1; + static constexpr int max_exponent10 = tosa::reference::internal::float_support::max_exponent10_v; + + static constexpr bool has_infinity = with_inf; + static constexpr bool has_quiet_NaN = has_nan; + static constexpr bool has_signaling_NaN = true; + static constexpr float_denorm_style has_denorm = with_denorm ? 
denorm_present : denorm_absent; + static constexpr bool has_denorm_loss = false; + + static constexpr auto infinity() noexcept + { + if constexpr (with_inf) + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0); + } + else + { + return this_float_t::from_bits(false, 0, 0); + } + } + + static constexpr auto quiet_NaN() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, + 1 << (this_float_t::n_significand_bits - 1) | 1); + } + + static constexpr auto signaling_NaN() noexcept + { + return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1); + } + + static constexpr auto denorm_min() noexcept + { + return this_float_t::from_bits(false, 0, 1); + } + + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = false; + static constexpr bool is_modulo = false; + + static constexpr bool traps = false; + static constexpr bool tinyness_before = false; + static constexpr float_round_style round_style = round_to_nearest; +}; + +} // namespace std + +#endif // _FLOAT_UTILS_H_ diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index 8b16e97..271b7f5 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -34,6 +34,8 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType, { DType::DType_BF16, "BF16" }, { DType::DType_FP32, "FP32" }, { DType::DType_SHAPE, "SHAPE" }, + { DType::DType_FP8E4M3, "FP8E4M3" }, + { DType::DType_FP8E5M2, "FP8E5M2" }, }) NLOHMANN_JSON_SERIALIZE_ENUM(Op, @@ -225,6 +227,8 @@ size_t elementSizeFromType(DType type) case DType::DType_BOOL: case DType::DType_UINT8: case DType::DType_INT8: + case DType::DType_FP8E4M3: + case DType::DType_FP8E5M2: return 1; case DType::DType_UINT16: case DType::DType_INT16: diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index ec9614a..ddf0713 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -759,6 +759,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); @@ -768,6 +770,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BF16); @@ -776,6 +780,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2); DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); DEF_INSTANTIATE_RESHAPE(OpReshape, BF16); @@ -785,6 +791,8 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); 
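As a usage sketch of the new float_t template from float_utils.h: the template parameter lists are not visible in this rendering of the patch, so the aliases below assume the <storage_t, n_exp_bits, has_nan, with_denorm, with_infinity> ordering implied by represents_binary32(); treat the exact parameterisation as an assumption rather than the patch's definitive spelling.

#include "float_utils.h"

// Assumed aliases: FP8E4M3 has NaN but no infinity, FP8E5M2 has both.
using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>;
using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>;

float round_trip_e4m3(float in)
{
    // Narrow to FP8 (the conversion operator rounds to nearest, ties to
    // even, and saturates), then widen back to binary32 for storage.
    fp8e4m3 narrowed(in);
    return static_cast<float>(narrowed);
}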
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); DEF_INSTANTIATE_RESHAPE(OpReshape, FP64); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E4M3); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); @@ -794,6 +802,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); @@ -803,6 +813,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16); @@ -812,6 +824,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16); @@ -821,6 +835,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); @@ -830,3 +846,5 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2); diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index 705981c..64001a9 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -105,3 +105,5 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index af8332e..6d66c07 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -55,6 +55,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2); break; case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); @@ -64,6 +66,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E4M3, FP16); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16); break; case Op_CONV2D: DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); @@ -74,6 +78,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); break; case Op_CONV3D: DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); @@ -84,6 +91,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); break; case Op_DEPTHWISE_CONV2D: DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); @@ -94,6 +103,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); break; case Op_FFT2D: DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); @@ -117,6 +128,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32); DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48); DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E4M3, FP16); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E5M2, FP16); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); @@ -125,6 +138,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E4M3); + 
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E5M2); break; case Op_RFFT2D: DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32); @@ -139,6 +154,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); break; // activation_funcs @@ -409,6 +426,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2); break; case Op_PAD: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); @@ -419,6 +438,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2); break; case Op_DIM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16); @@ -428,6 +449,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2); break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); @@ -438,6 +461,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RESHAPE(OpReshape, INT32); DEF_FACTORY_RESHAPE(OpReshape, BOOL); DEF_FACTORY_RESHAPE(OpReshape, FP64); + DEF_FACTORY_RESHAPE(OpReshape, FP8E4M3); + DEF_FACTORY_RESHAPE(OpReshape, FP8E5M2); break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); @@ -448,6 +473,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2); break; case Op_SLICE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); @@ -458,6 +485,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2); break; case Op_TILE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); @@ -468,6 +497,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2); break; case Op_TRANSPOSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); @@ -478,6 +509,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, 
INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2); break; // scatter_gather @@ -489,6 +522,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, BF16); DEF_FACTORY_ONE_TYPE(OpGather, FP32); DEF_FACTORY_ONE_TYPE(OpGather, FP64); + DEF_FACTORY_ONE_TYPE(OpGather, FP8E4M3); + DEF_FACTORY_ONE_TYPE(OpGather, FP8E5M2); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); @@ -498,6 +533,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpScatter, BF16); DEF_FACTORY_ONE_TYPE(OpScatter, FP32); DEF_FACTORY_ONE_TYPE(OpScatter, FP64); + DEF_FACTORY_ONE_TYPE(OpScatter, FP8E4M3); + DEF_FACTORY_ONE_TYPE(OpScatter, FP8E5M2); break; // image @@ -524,6 +561,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2); break; // type_conversion @@ -569,6 +608,18 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2); break; case Op_RESCALE: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index bd16ad1..85397ae 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -236,6 +236,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E4M3); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E5M2); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); @@ -244,3 +246,5 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E4M3); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E5M2); diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 342d5c2..41e6061 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -88,6 +88,18 @@ struct GetEigenType using type = float; }; template <> +struct GetEigenType +{ + // NOTE: full precision used + using type = float; +}; +template <> +struct GetEigenType +{ + // NOTE: full precision used + using type = float; +}; +template <> struct GetEigenType { using type = int32_t; @@ -200,6 +212,16 @@ struct GetNumBits { static constexpr int32_t value = 16; }; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 8; +}; +template <> +struct GetNumBits +{ + static constexpr int32_t value = 8; +}; // Meta function to get quantized min/max in compile time template diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index dd66f79..124dc87 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -555,7 +555,7 @@ int OpAvgPool2d::eval() } } if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 && - Dtype != TOSA_REF_TYPE_FP64) + Dtype != TOSA_REF_TYPE_FP64 && Dtype != TOSA_REF_TYPE_FP8E4M3 && Dtype != TOSA_REF_TYPE_FP8E5M2) { try { @@ -2155,6 +2155,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32); @@ -2163,6 +2165,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16); // [in_t, weight_t, out_t] DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16); @@ -2173,6 +2177,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16); 
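The GetEigenType specializations added above mean FP8 tensor data is carried at full float precision inside the reference model. A small sketch, assuming GetEigenType is keyed on TOSA_REF_TYPE and lives in the TosaReference namespace as elsewhere in template_types.h:

#include <type_traits>
#include "template_types.h"

// Both FP8 formats use float as their working Eigen element type.
static_assert(std::is_same_v<TosaReference::GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type, float>);
static_assert(std::is_same_v<TosaReference::GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type, float>);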
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32); @@ -2182,6 +2188,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); @@ -2191,6 +2199,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64); @@ -2211,6 +2221,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E4M3, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E5M2, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16); @@ -2218,6 +2230,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E4M3); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64); @@ -2230,3 +2244,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 484f768..5dbc7bd 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -15,6 +15,7 @@ #include "type_conversion.h" #include "arith_util.h" +#include "float_utils.h" #include "half.hpp" #include "quant_util.h" #include "template_types.h" @@ -24,6 +25,12 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; +using fp16 = tosa::reference::internal::float_t; +using bf16 = tosa::reference::internal::float_t; +using fp32 = tosa::reference::internal::float_t; +using fp8e4m3 = tosa::reference::internal::float_t; +using fp8e5m2 = tosa::reference::internal::float_t; + template OpRescale::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_RESCALE, id_) @@ -526,6 +533,162 @@ CastHelper::CastHelper() }; } +template +CastHelper::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to integer + fcn = [](float in) -> OutEigenType { + if (in >= float(OutMax)) + return OutMax; + if (in <= float(OutMin)) + return OutMin; + + OutEigenType out = std::rint(in); + return out; + }; +} + +CastHelper::CastHelper() +{ + // 
fp8e4m3 data (stored as fp32) converted to fp16 (stored as fp32) + fcn = [](float in) -> float { + half_float::half h = half_float::half(in); + float out = half_float::half_cast(h); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to bf16 (stored as fp32) + fcn = [](float in) -> float { return (float)in; }; +} + +CastHelper::CastHelper() +{ + // fp8e4m3 data (stored as fp32) converted to fp32 + fcn = [](InEigenType in) -> OutEigenType { return in; }; +} + +template +CastHelper::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to integer + fcn = [](float in) -> OutEigenType { + if (in >= float(OutMax)) + return OutMax; + if (in <= float(OutMin)) + return OutMin; + + OutEigenType out = std::rint(in); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to fp16 (stored as fp32) + fcn = [](float in) -> float { + half_float::half h = half_float::half(in); + float out = half_float::half_cast(h); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to bf16 (stored as fp32) + fcn = [](float in) -> float { return (float)in; }; +} + +CastHelper::CastHelper() +{ + // fp8e5m2 data (stored as fp32) converted to fp32 + fcn = [](InEigenType in) -> OutEigenType { return in; }; +} + +template +CastHelper::CastHelper() +{ + // Integer data converted to fp8e4m3 (stored as fp32) + fcn = [](InEigenType in) -> float { + auto f = static_cast(static_cast(float(in))); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp16 data (stored as fp32) converted to fp8e4m3 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // bf16 data (stored as fp32) converted to fp8e4m3 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp32 data converted to fp8e4m3 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + +template +CastHelper::CastHelper() +{ + // Integer data converted to fp8e5m2 (stored as fp32) + fcn = [](InEigenType in) -> float { + auto f = static_cast(static_cast(float(in))); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp16 data (stored as fp32) converted to fp8e5m2 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // bf16 data (stored as fp32) converted to fp8e5m2 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + +CastHelper::CastHelper() +{ + // fp32 data converted to fp8e5m2 (stored as fp32) + fcn = [](float in) -> float { + auto f = static_cast(static_cast(in)); + float out = static_cast(f); + return out; + }; +} + template CastHelper::CastHelper() { @@ -597,6 +760,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3); 
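Spelled out, the FP32 to FP8E4M3 helper registered above reduces to the following, using the fp32/fp8e4m3 aliases declared at the top of type_conversion.cc (their template arguments are elided in this rendering):

float cast_fp32_to_fp8e4m3(float in)
{
    // Quantise through the fp8e4m3 soft-float type, then widen back to
    // float, since FP8 data is stored as fp32 inside the reference model.
    auto f = static_cast<fp8e4m3>(static_cast<fp32>(in));
    return static_cast<float>(f);
}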
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index 98799a0..75f244d 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -276,6 +276,282 @@ private: FcnType fcn; }; +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + static constexpr int32_t OutMin = GetQMin::value; + static constexpr int32_t OutMax = GetQMax::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + static constexpr int32_t OutMin = GetQMin::value; + static constexpr int32_t OutMax = GetQMax::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType 
fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <> +class CastHelper +{ +public: + using InEigenType = typename GetEigenType::type; + using OutEigenType = typename GetEigenType::type; + using FcnType = std::function; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + template class CastHelper { diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index fae0b30..33a9b94 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
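For context on how the CastHelper declarations in type_conversion.h above are consumed, the snippet below is a hedged sketch of the elementwise call pattern through get_fcn(); the function name, the Eigen tensor usage, and the TosaReference namespace qualification are assumptions, not taken from this patch.

#include <unsupported/Eigen/CXX11/Tensor>
#include "type_conversion.h"

// Apply the registered conversion lambda to every element of a tensor.
Eigen::Tensor<float, 1> cast_fp16_to_fp8e5m2(const Eigen::Tensor<float, 1>& in)
{
    TosaReference::CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2> helper;
    return in.unaryExpr(helper.get_fcn());
}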
@@ -595,6 +595,22 @@ int SubgraphTraverser::allocateTensor(std::string name) } } break; + case DType_FP8E4M3: + case DType_FP8E5M2: { + std::vector fp32_data; + TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + // Ensure valid fp8 stored in each float + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } + } + break; case DType_FP32: { std::vector fp32_data; TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index f9ec937..27f21f3 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -115,12 +115,15 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) assert(dtype == ConvertDType(serialization_dtype)); // if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16 assert(dtype != TOSA_REF_TYPE_FP64 || serialization_dtype == DType_FP32 || serialization_dtype == DType_FP16 || - serialization_dtype == DType_BF16); + serialization_dtype == DType_BF16 || serialization_dtype == DType_FP8E4M3 || + serialization_dtype == DType_FP8E5M2); switch (serialization_dtype) { case DType_FP32: case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); @@ -208,6 +211,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) return 1; } break; + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: + if (setTensorValueFloat(elements, f32databuf)) + { + free(f32databuf); + return 1; + } + break; case TOSA_REF_TYPE_FP32: if (setTensorValueFloat(elements, f32databuf)) { @@ -276,6 +287,23 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename) return 1; } break; + case DType_FP8E4M3: + case DType_FP8E5M2: + // FP8E4M3 -> FP64 + f64databuf = (double*)calloc(sizeof(double), elements); + ASSERT_MEM(f64databuf); + for (uint32_t i = 0; i < elements; i++) + { + //ASSERT_MSG(checkValidFloat8(f32databuf[i]), "Input float value not a valid float8 value."); + f64databuf[i] = static_cast(f32databuf[i]); + } + if (setTensorValueDouble(elements, f64databuf)) + { + free(f32databuf); + free(f64databuf); + return 1; + } + break; case DType_FP32: // FP32 -> FP64 f64databuf = (double*)calloc(sizeof(double), elements); @@ -349,6 +377,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const free(f32databuf); break; case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: f32databuf = (float*)calloc(sizeof(float), elements); ASSERT_MEM(f32databuf); f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements); @@ -631,6 +661,8 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy vals) // continue with setting float vals in the tensor case TOSA_REF_TYPE_FP16: case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: if (vals.size() != elements) { WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index 2c3be7f..1659a2f 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -863,6 +863,8 @@ public: case TOSA_REF_TYPE_FP32: case TOSA_REF_TYPE_FP16: case TOSA_REF_TYPE_BF16: 
+ case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP8E5M2: switch (rank) { case 0: diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc index 14bc6f1..4ae245b 100644 --- a/reference_model/src/verify/verify_utils.cc +++ b/reference_model/src/verify/verify_utils.cc @@ -36,6 +36,8 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType, { DType::DType_FP16, "FP16" }, { DType::DType_BF16, "BF16" }, { DType::DType_FP32, "FP32" }, + { DType::DType_FP8E4M3, "FP8E4M3" }, + { DType::DType_FP8E5M2, "FP8E5M2" }, }) } // namespace tosa @@ -177,12 +179,13 @@ std::string positionToString(const std::vector& pos) DType mapToDType(tosa_datatype_t dataType) { static std::map typeMap = { - { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 }, - { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 }, - { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 }, - { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 }, - { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 }, - { tosa_datatype_shape_t, DType_SHAPE }, + { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 }, + { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 }, + { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 }, + { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 }, + { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 }, + { tosa_datatype_shape_t, DType_SHAPE }, { tosa_datatype_fp8e4m3_t, DType_FP8E4M3 }, + { tosa_datatype_fp8e5m2_t, DType_FP8E5M2 }, }; if (typeMap.count(dataType)) diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index 8137a43..a029f1f 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit 8137a4369acefa4c01f08db95a2b1b290e8dd70a +Subproject commit a029f1f02707f40f6990df53fd4f56684490d58f diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py index 212c809..4d6d345 100644 --- a/verif/checker/tosa_result_checker.py +++ b/verif/checker/tosa_result_checker.py @@ -13,6 +13,7 @@ from checker.color_print import print_color from checker.verifier import VerifierError from checker.verifier import VerifierLibrary from generator.tosa_utils import float32_is_valid_bfloat16 +from generator.tosa_utils import float32_is_valid_float8 from schemavalidation.schemavalidation import TestDescSchemaValidator @@ -195,6 +196,18 @@ def test_check( "reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}" ) return (TestResult.INCORRECT_FORMAT, 0.0, msg) + if "fp8e4m3" in misc_checks or "fp8e5m2" in misc_checks: + # Ensure floats are valid float8 values + test_res_is_fp8 = all([float32_is_valid_float8(f) for f in test_result.flat]) + ref_res_is_fp8 = all( + [float32_is_valid_float8(f) for f in reference_result.flat] + ) + if not (test_res_is_fp8 and ref_res_is_fp8): + msg = ( + "All output values must be valid float8. 
" + "reference_result: {ref_res_is_float8}; test_result: {test_res_is_float8}" + ) + return (TestResult.INCORRECT_FLOAT, 0.0, msg) # for quantized test, allow +-(quantize_tolerance) error if reference_result.dtype in ( diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json index 7792417..7559c62 100644 --- a/verif/conformance/tosa_main_profile_ops_info.json +++ b/verif/conformance/tosa_main_profile_ops_info.json @@ -185,6 +185,30 @@ "2" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -233,6 +257,24 @@ "--allow-pooling-and-conv-oversizes" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -315,6 +357,30 @@ "2,65538,1,1" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -527,6 +593,30 @@ "2" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -592,6 +682,30 @@ "1,2,1,65529" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -647,6 +761,24 @@ "--allow-pooling-and-conv-oversizes" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -722,6 +854,28 @@ "--allow-pooling-and-conv-oversizes" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--target-shape", + "1,7,18,5,4", + "--target-shape", + "1,6,12,17,3", + "--tensor-dim-range", + "1,4", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -787,6 +941,24 @@ "--allow-pooling-and-conv-oversizes" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + 
"--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -840,6 +1012,30 @@ "3" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -1183,6 +1379,30 @@ "5000,1,1" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -1505,6 +1725,30 @@ "1,65538,3" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -1551,6 +1795,24 @@ "--allow-pooling-and-conv-oversizes" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -1699,6 +1961,30 @@ "1,1,65539,1" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -1889,6 +2175,30 @@ "2" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -1935,6 +2245,30 @@ "1,65535,1,2" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -2046,6 +2380,24 @@ "2989,6,1" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -2091,6 +2443,30 @@ "1,65543,2,1" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + 
"generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -2161,6 +2537,30 @@ "1" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -2214,6 +2614,30 @@ "1" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-shape", + "10,24,9,13", + "--target-shape", + "8,14,20,5", + "--tensor-dim-range", + "1,16", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { @@ -3111,6 +3535,30 @@ "2" ] ] + }, + "float8": { + "from_version" : "v0.100.0", + "no_negative_tests": "true", + "generator_args": [ + [ + "--target-dtype", + "fp8e4m3", + "--target-dtype", + "fp8e5m2", + "--fp-values-range", + "-max,max", + "--tensor-dim-range", + "32,64", + "--target-rank", + "1", + "--target-rank", + "2", + "--target-rank", + "3", + "--num-rand-permutations", + "2" + ] + ] } }, "selection": { diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py index 7ec0cfe..d0b9eb9 100644 --- a/verif/generator/tosa_arg_gen.py +++ b/verif/generator/tosa_arg_gen.py @@ -641,6 +641,8 @@ class TosaTensorValuesGen: DType.FP32: (1 << 128) - (1 << (127 - 23)), DType.FP16: (1 << 16) - (1 << (15 - 10)), DType.BF16: (1 << 128) - (1 << (127 - 7)), + DType.FP8E4M3: 448, + DType.FP8E5M2: 57344, } # Default lowest normal values for random numbers @@ -648,6 +650,8 @@ class TosaTensorValuesGen: DType.FP32: np.exp2(-126), DType.FP16: np.exp2(-14), DType.BF16: np.exp2(-126), + DType.FP8E4M3: np.exp2(-9), + DType.FP8E5M2: np.exp2(-16), } @staticmethod @@ -715,6 +719,8 @@ class TosaTensorValuesGen: DType.FP16, DType.FP32, DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, ): # Change from inclusive to exclusive range data_range = (data_range[0], data_range[1] + 1) @@ -1734,7 +1740,13 @@ class TosaArgGen: and "data_gen" in testGen.TOSA_OP_LIST[opName] and gtu.dtypeIsSupportedByCompliance(dtype) ): - if dtype in [DType.FP16, DType.FP32, DType.BF16]: + if dtype in [ + DType.FP16, + DType.FP32, + DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, + ]: dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"] else: dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"] @@ -2140,6 +2152,8 @@ class TosaArgGen: accum_dtypes = [DType.FP32] elif dtype == DType.FP32: accum_dtypes = [DType.FP32] + elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2: + accum_dtypes = [DType.FP16] elif error_name is None: assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}" @@ -2350,7 +2364,13 @@ class TosaArgGen: if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]: pad_const_int = testGen.getRandNumberDType(dtype) pad_const_fp = 0 - elif dtype in (DType.FP16, DType.BF16, DType.FP32): + elif dtype in ( + DType.FP16, + DType.BF16, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ): pad_const_int = 0 pad_const_fp = 
testGen.getRandNumberDType(dtype) else: @@ -2468,6 +2488,8 @@ class TosaArgGen: accum_dtypes = [DType.FP16, DType.FP32] elif dtype == DType.BF16 or dtype == DType.FP32: accum_dtypes = [DType.FP32] + elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2: + accum_dtypes = [DType.FP16] elif error_name is None: assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}" else: @@ -2646,11 +2668,35 @@ class TosaArgGen: elif inDtype == DType.BOOL: dtypeList = [DType.INT8, DType.INT16, DType.INT32] elif inDtype == DType.FP16: - dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32] + dtypeList = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] elif inDtype == DType.BF16: - dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32] + dtypeList = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] elif inDtype == DType.FP32: - dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16] + dtypeList = [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP16, + DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, + ] + elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]: + dtypeList = [DType.FP16, DType.BF16, DType.FP32] elif error_name == ErrorIf.WrongInputType: # Pick some potentially correct output type for incorrect input type dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32] @@ -3232,6 +3278,10 @@ class TosaArgGen: outputDTypeList = [DType.BF16] elif dtype == DType.FP32: outputDTypeList = [DType.FP32] + elif dtype == DType.FP8E4M3: + outputDTypeList = [DType.FP8E4M3] + elif dtype == DType.FP8E5M2: + outputDTypeList = [DType.FP8E5M2] elif error_name == ErrorIf.WrongInputType: # If an incorrect input type is used then we set a 'correct' # output type to avoid other errors diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index 9a88acb..7a4d0d6 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -325,12 +325,32 @@ class TosaErrorIfArgGen: @staticmethod def eiCastErrorIf(testGen, input_dtype): - if input_dtype in [DType.BOOL, DType.FP32]: + # if input_dtype in [DType.BOOL, DType.FP32]: + # outputDType = [DType.BOOL, DType.INT48, DType.FP32] + if input_dtype in [DType.BOOL]: + outputDType = [ + DType.BOOL, + DType.INT48, + DType.FP32, + DType.FP16, + DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, + ] + elif input_dtype in [DType.FP32]: outputDType = [DType.BOOL, DType.INT48, DType.FP32] elif input_dtype in [DType.FP16, DType.BF16]: outputDType = [DType.BOOL, DType.INT48] elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]: outputDType = [DType.INT48] + elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]: + outputDType = [ + DType.BOOL, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + ] else: assert False, f"input_dtype ({input_dtype}) not supported" return outputDType @@ -476,13 +496,23 @@ class TosaErrorValidator: ) or (input_dtype == DType.BF16 and output_dtype != DType.FP32) or (input_dtype == DType.FP32 and output_dtype != DType.FP32) + or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16) + or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16) ): error_result = True elif op["op"] == Op.ARGMAX: if ( input_dtype - in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] + in [ + DType.INT8, + DType.INT16, + DType.FP16, + DType.BF16, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] and output_dtype != DType.INT32 ): error_result = True @@ -555,12 +585,26 
@@ class TosaErrorValidator: or ( input_dtype == DType.FP16 and output_dtype - not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32] + not in [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] ) or ( input_dtype == DType.BF16 and output_dtype - not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32] + not in [ + DType.INT8, + DType.INT16, + DType.INT32, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ] ) or ( input_dtype == DType.FP32 @@ -571,6 +615,17 @@ class TosaErrorValidator: DType.INT32, DType.FP16, DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, + ] + ) + or ( + input_dtype in [DType.FP8E4M3, DType.FP8E5M2] + and output_dtype + not in [ + DType.FP16, + DType.BF16, + DType.FP32, ] ) ): @@ -597,6 +652,10 @@ class TosaErrorValidator: and output_dtype != DType.FP32 or input_dtype == DType.FP32 and output_dtype != DType.FP32 + or input_dtype == DType.FP8E4M3 + and output_dtype != DType.FP16 + or input_dtype == DType.FP8E5M2 + and output_dtype != DType.FP16 ): error_result = True # invalid input types are ignored, to avoid reporting multiple errors @@ -2615,6 +2674,11 @@ class TosaErrorValidator: DType.FP32, ): error_result = True + elif ( + input_dtype in (DType.FP8E4M3, DType.FP8E5M2) + and accum_dtype != DType.FP16 + ): + error_result = True info_dict = { "error_name": error_name, diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index 4ead982..bc931dc 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -76,7 +76,7 @@ class TosaTestGen: return tuple(sorted(vals)) self.random_float_range = {} - for dtype in (DType.FP32, DType.FP16, DType.BF16): + for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2): self.random_float_range[dtype] = convertFPRange( args.tensor_fp_value_range, TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype], @@ -152,7 +152,7 @@ class TosaTestGen: # Returns dtype value range boundaries (low, high) # The high boundary is excluded in the range # unless high_inclusive is True - if dtype in (DType.FP32, DType.FP16, DType.BF16): + if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2): return self.random_float_range[dtype] elif dtype == DType.BOOL: rng = (0, 2) @@ -197,7 +197,13 @@ class TosaTestGen: return np.uint8(self.rng.integers(low=low, high=high, size=shape)) elif dtype in (DType.INT48, DType.SHAPE): return np.int64(self.rng.integers(low=low, high=high, size=shape)) - elif dtype in (DType.FP16, DType.BF16, DType.FP32): + elif dtype in ( + DType.FP16, + DType.BF16, + DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, + ): f_tensor = self.rng.uniform(low=low, high=high, size=shape) if dtype == DType.FP16: @@ -207,6 +213,10 @@ class TosaTestGen: if dtype == DType.BF16: # Floor the last 16 bits of each f32 value return np.float32(gtu.vect_f32_to_bf16(f32_tensor)) + elif dtype == DType.FP8E4M3: + return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor)) + elif dtype == DType.FP8E5M2: + return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor)) else: return f32_tensor else: @@ -266,6 +276,12 @@ class TosaTestGen: elif dtype == DType.BF16: rand_f32 = np.float32(self.rng.uniform(low=low, high=high)) return gtu.vect_f32_to_bf16(rand_f32) + elif dtype == DType.FP8E4M3: + rand_f32 = np.float32(self.rng.uniform(low=low, high=high)) + return gtu.vect_f32_to_fp8e4m3(rand_f32) + elif dtype == DType.FP8E5M2: + rand_f32 = np.float32(self.rng.uniform(low=low, high=high)) + return gtu.vect_f32_to_fp8e5m2(rand_f32) elif dtype == 
DType.BOOL: return self.rng.choice([False, True]) elif dtype == DType.INT48 or dtype == DType.SHAPE: @@ -1408,8 +1424,11 @@ class TosaTestGen: max_val = max_val.astype(np.float32) attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val) - else: + elif a.dtype in (DType.INT8, DType.INT16): attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0) + else: + # to avoid internal error for incorrect input types + attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0) self.ser.addOperator(op["op"], input_list, output_list, attr) @@ -3190,7 +3209,13 @@ class TosaTestGen: ] TYPE_FI16 = [DType.FP32, DType.INT16] - TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32] + TYPE_NARROW_INT_FP = [ + DType.INT8, + DType.INT16, + DType.FP16, + DType.BF16, + DType.FP32, + ] # List of [Input Type 1, Input Type 2, Accumulator Type] TYPE_CONV = [ @@ -3201,6 +3226,8 @@ class TosaTestGen: [DType.FP16, DType.FP16, DType.FP32], [DType.BF16, DType.BF16, DType.FP32], [DType.FP32, DType.FP32, DType.FP32], + [DType.FP8E4M3, DType.FP8E4M3, DType.FP16], + [DType.FP8E5M2, DType.FP8E5M2, DType.FP16], ] DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK) @@ -3217,7 +3244,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agAxis, ), - "types": TYPE_NARROW_INT_FP, + "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, @@ -3244,7 +3271,7 @@ class TosaTestGen: TosaArgGen.agPooling, ), "qgen": TosaQuantGen.qgUnary, - "types": TYPE_NARROW_INT_FP, + "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2], "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), "error_if_validators": ( TosaErrorValidator.evKernelSmallerOne, @@ -3402,7 +3429,7 @@ class TosaTestGen: TosaArgGen.agMatMul, ), "qgen": TosaQuantGen.qgMatmul, - "types": TYPE_NARROW_INT_FP, + "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evInputZeroPointNotZero, TosaErrorValidator.evWrongRank, @@ -3425,7 +3452,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agPooling, ), - "types": TYPE_NARROW_INT_FP, + "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2], "invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,), "error_if_validators": ( TosaErrorValidator.evKernelSmallerOne, @@ -4389,7 +4416,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgConcat, TosaArgGen.agAxis, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -4413,7 +4440,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgPad, TosaArgGen.agPad, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evPadSmallerZero, @@ -4437,7 +4464,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agAxis, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evAxisLargerRank, TosaErrorValidator.evAxisSmallerZero, @@ -4456,7 +4483,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgReshape, TosaArgGen.agReshape, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evTensorSizeInputOutputMismatch, TosaErrorValidator.evWrongInputType, @@ -4477,7 +4504,7 @@ class 
TosaTestGen: TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agAxis, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evAxisSmallerZero, TosaErrorValidator.evAxisLargerRank, @@ -4500,7 +4527,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgSlice, TosaArgGen.agSlice, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( # TODO Turn off these error categories for now as the reference # model cannot allocate memory space for empty tensor. We probably @@ -4532,7 +4559,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgTile, TosaArgGen.agTile, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -4555,7 +4582,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agTranspose, ), - "types": TYPE_FIB, + "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evIndexOutsideBounds, TosaErrorValidator.evIndexUsedTwice, @@ -4581,7 +4608,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgLazyGenDefault, TosaArgGen.agNone, ), - "types": TYPE_FIB + [DType.INT48], + "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2], "data_gen": { "fp": (gtu.DataGenType.PSEUDO_RANDOM,), }, @@ -4618,6 +4645,8 @@ class TosaTestGen: DType.FP16, DType.BF16, DType.FP32, + DType.FP8E4M3, + DType.FP8E5M2, ), "error_if_validators": ( TosaErrorValidator.evWrongInputType, @@ -4640,7 +4669,7 @@ class TosaTestGen: TosaTensorValuesGen.tvgScatter, TosaArgGen.agNone, ), - "types": TYPE_INT_FP, + "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2], "error_if_validators": ( TosaErrorValidator.evWrongInputType, TosaErrorValidator.evWrongOutputType, @@ -4709,6 +4738,8 @@ class TosaTestGen: DType.INT16, DType.INT32, DType.BOOL, + DType.FP8E4M3, + DType.FP8E5M2, ), "error_if_validators": ( TosaErrorValidator.evWrongInputType, @@ -5141,6 +5172,8 @@ class OutputShaper: DType.FP32, DType.FP16, DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, ] wrong_dtypes = list(set(all_dtypes) - set([DType.INT32])) outputDType = rng.choice(wrong_dtypes) @@ -5194,6 +5227,8 @@ class OutputShaper: if error_name == ErrorIf.WrongOutputType: if ifm.dtype == DType.FP16: excludes = [DType.FP16, DType.FP32] + if ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]: + excludes = [DType.FP16] else: excludes = [out_dtype] wrong_dtypes = list(gtu.usableDTypes(excludes=excludes)) @@ -5344,6 +5379,8 @@ class OutputShaper: DType.FP32, DType.FP16, DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, ] wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype])) outputDType = rng.choice(wrong_dtypes) @@ -5383,6 +5420,8 @@ class OutputShaper: DType.FP32, DType.FP16, DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, ) elif a.dtype == DType.INT16: incorrect_types = ( @@ -5393,6 +5432,20 @@ class OutputShaper: DType.FP32, DType.FP16, DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, + ) + elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, ) elif ( a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16 @@ -5403,6 +5456,8 @@ class OutputShaper: DType.INT16, DType.INT32, DType.INT48, + DType.FP8E4M3, + DType.FP8E5M2, ) out_dtype = rng.choice(a=incorrect_types) elif error_name == ErrorIf.WrongInputType: @@ -5669,6 +5724,8 @@ 
class OutputShaper: DType.FP32, DType.FP16, DType.BF16, + DType.FP8E4M3, + DType.FP8E5M2, ] wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype])) outputDType = rng.choice(wrong_dtypes) diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py index 76e7388..31a0ff0 100644 --- a/verif/generator/tosa_utils.py +++ b/verif/generator/tosa_utils.py @@ -27,6 +27,8 @@ DTYPE_ATTRIBUTES = { DType.FP16: {"str": "f16", "width": 16, "json": "FP16"}, DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"}, DType.FP32: {"str": "f32", "width": 32, "json": "FP32"}, + DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"}, + DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"}, } @@ -186,6 +188,16 @@ def get_wrong_output_type(op_name, rng, input_dtype): DType.INT32, DType.INT48, ) + elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2: + incorrect_types = ( + DType.INT4, + DType.INT8, + DType.INT16, + DType.INT32, + DType.INT48, + DType.FP32, + DType.BF16, + ) else: # Assume all types but the input type are incorrect incorrect_types = list(usableDTypes(excludes=(input_dtype,))) @@ -209,6 +221,12 @@ def float32_is_valid_bfloat16(f): return f32_bits[16:] == "0" * 16 +def float32_is_valid_float8(f): + """Return True if float value is valid float8.""" + f32_bits = get_float32_bitstring(f) + return f32_bits[8:] == "0" * 24 + + def get_float32_bitstring(f): """Return a big-endian string of bits representing a 32 bit float.""" f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0] @@ -232,6 +250,30 @@ def float32_to_bfloat16(f): return struct.unpack("@f", fp_bytes)[0] # native byteorder +def float32_to_fp8e4m3(f): + """Turns fp32 value into fp8e4m3""" + f32_bits = get_float32_bitstring(f) + fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24 + fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder) + return struct.unpack("@f", fp_bytes)[0] # native byteorder + + +def float32_to_fp8e5m2(f): + """Turns fp32 value into fp8e5m2""" + f32_bits = get_float32_bitstring(f) + fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24 + fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder) + return struct.unpack("@f", fp_bytes)[0] + + vect_f32_to_bf16 = np.vectorize( float32_to_bfloat16, otypes=(np.float32,) ) # NumPy vectorize: applies function to vector faster than looping + +vect_f32_to_fp8e4m3 = np.vectorize( + float32_to_fp8e4m3, otypes=(np.float32,) +) # NumPy vectorize: applies function to vector faster than looping + +vect_f32_to_fp8e5m2 = np.vectorize( + float32_to_fp8e5m2, otypes=(np.float32,) +) # Numpy vectorize: applies function to vector faster than looping -- cgit v1.2.1
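
Reviewer sketch (not part of the patch): the vectorized helpers added to verif/generator/tosa_utils.py splice the fp8 sign, exponent, and mantissa bits into the top byte of an fp32 bit pattern and zero the remaining 24 bits, which is exactly the property that float32_is_valid_float8 checks in tosa_result_checker.py. A minimal usage sketch, assuming the verif directory is on PYTHONPATH and using arbitrary example values:

import numpy as np

from generator.tosa_utils import (
    float32_is_valid_float8,
    vect_f32_to_fp8e4m3,
    vect_f32_to_fp8e5m2,
)

# Hypothetical inputs spanning the finite fp8e4m3 range (+/-448); the values
# and the seed are illustrative only.
rng = np.random.default_rng(2024)
f32 = np.float32(rng.uniform(low=-448.0, high=448.0, size=8))

# fp8 values stay in fp32 containers, mirroring how the generator already
# handles bf16: only the top 8 bits of each fp32 pattern may be non-zero.
e4m3_vals = np.float32(vect_f32_to_fp8e4m3(f32))
e5m2_vals = np.float32(vect_f32_to_fp8e5m2(f32))

assert all(float32_is_valid_float8(v) for v in e4m3_vals.flat)
assert all(float32_is_valid_float8(v) for v in e5m2_vals.flat)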
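
The fp8 upper bounds added to TVG_FLOAT_HIGH_VALUE follow from the two formats' largest finite values; a quick arithmetic check (my own derivation, not part of the change):

# fp8e4m3: exponent field 1111 is finite except when the mantissa is 111
# (NaN), so the largest finite value is 1.75 * 2**(15 - 7) = 448.
assert (2 - 2**-2) * 2**8 == 448

# fp8e5m2: exponent field 11111 is reserved for inf/NaN, so the largest
# finite value is 1.75 * 2**(30 - 15) = 57344.
assert (2 - 2**-2) * 2**15 == 57344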