author     Won Jeon <won.jeon@arm.com>    2024-02-06 18:37:00 +0000
committer  Won Jeon <won.jeon@arm.com>    2024-02-21 19:38:55 +0000
commit     2c34b4616a10539211e7006bc43f3c71e86c30bb (patch)
tree       aa4043a610ecd4c6d35b876cfb013dbe7dd0ab01
parent     587cc84c2b8c4b0d030b5e257c9a32461c0969b9 (diff)
download   reference_model-2c34b4616a10539211e7006bc43f3c71e86c30bb.tar.gz
Add support for FP8 to reference model
Signed-off-by: Won Jeon <won.jeon@arm.com>
Change-Id: I99b70f94aff2ccd4af64875697e124eb60bc5b08
-rw-r--r--  reference_model/include/dtype.h | 16
-rw-r--r--  reference_model/include/types.h | 30
-rw-r--r--  reference_model/src/arith_util.h | 6
-rw-r--r--  reference_model/src/float_utils.h | 533
-rw-r--r--  reference_model/src/generate/generate_utils.cc | 4
-rw-r--r--  reference_model/src/ops/data_layout.cc | 18
-rw-r--r--  reference_model/src/ops/data_nodes.cc | 4
-rw-r--r--  reference_model/src/ops/op_factory.cc | 51
-rw-r--r--  reference_model/src/ops/scatter_gather.cc | 6
-rw-r--r--  reference_model/src/ops/template_types.h | 24
-rw-r--r--  reference_model/src/ops/tensor_ops.cc | 18
-rw-r--r--  reference_model/src/ops/type_conversion.cc | 175
-rw-r--r--  reference_model/src/ops/type_conversion.h | 278
-rw-r--r--  reference_model/src/subgraph_traverser.cc | 18
-rw-r--r--  reference_model/src/tensor.cc | 34
-rw-r--r--  reference_model/src/tensor.h | 2
-rw-r--r--  reference_model/src/verify/verify_utils.cc | 15
m---------  thirdparty/serialization_lib | 0
-rw-r--r--  verif/checker/tosa_result_checker.py | 13
-rw-r--r--  verif/conformance/tosa_main_profile_ops_info.json | 448
-rw-r--r--  verif/generator/tosa_arg_gen.py | 60
-rw-r--r--  verif/generator/tosa_error_if.py | 72
-rw-r--r--  verif/generator/tosa_test_gen.py | 95
-rw-r--r--  verif/generator/tosa_utils.py | 42
24 files changed, 1901 insertions, 61 deletions
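The two formats added here are the usual 8-bit floating-point encodings: FP8E4M3 with 1 sign, 4 exponent and 3 significand bits (NaN but no infinities), and FP8E5M2 with 1 sign, 5 exponent and 2 significand bits (NaN and infinities). A minimal sketch of how the patch models them, mirroring the aliases introduced in type_conversion.cc further down:

    using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, /*nan*/ true, /*denorm*/ true, /*inf*/ false>;
    using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, /*nan*/ true, /*denorm*/ true, /*inf*/ true>;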
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h
index 1b01a0e..3e8bdf5 100644
--- a/reference_model/include/dtype.h
+++ b/reference_model/include/dtype.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -41,6 +41,8 @@ enum TOSA_REF_TYPE : uint32_t
TOSA_REF_TYPE_FP16 = 10,
TOSA_REF_TYPE_BF16 = 11,
TOSA_REF_TYPE_SHAPE = 12,
+ TOSA_REF_TYPE_FP8E4M3 = 13,
+ TOSA_REF_TYPE_FP8E5M2 = 14,
TOSA_REF_TYPE_FP64 = 99, // FP64 is special: add new data types above
};
@@ -74,6 +76,10 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e)
return EnumNameDType(DType_BF16);
case TOSA_REF_TYPE_SHAPE:
return EnumNameDType(DType_SHAPE);
+ case TOSA_REF_TYPE_FP8E4M3:
+ return EnumNameDType(DType_FP8E4M3);
+ case TOSA_REF_TYPE_FP8E5M2:
+ return EnumNameDType(DType_FP8E5M2);
case TOSA_REF_TYPE_FP64:
return "FP64";
default:
@@ -85,7 +91,7 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e)
// return corresponding TOSA_REF_TYPE for DType
inline TOSA_REF_TYPE ConvertDType(const DType dtype)
{
- assert(DType_MAX == DType_SHAPE); // must update whenever DType_MAX changes
+ assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes
if (g_func_config.precise_mode)
{
@@ -95,6 +101,8 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype)
case DType_FP16:
case DType_FP32:
case DType_BF16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2:
return TOSA_REF_TYPE_FP64;
default:
break;
@@ -127,6 +135,10 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype)
return TOSA_REF_TYPE_BF16;
case DType_SHAPE:
return TOSA_REF_TYPE_SHAPE;
+ case DType_FP8E4M3:
+ return TOSA_REF_TYPE_FP8E4M3;
+ case DType_FP8E5M2:
+ return TOSA_REF_TYPE_FP8E5M2;
default:
break;
}
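Illustrative effect of the ConvertDType() changes above: with the reference model's precise mode enabled, the new FP8 types (like the other floating-point types) are evaluated in FP64; otherwise they map to their own reference types.

    // assuming g_func_config.precise_mode == false
    ConvertDType(DType_FP8E4M3);  // -> TOSA_REF_TYPE_FP8E4M3
    ConvertDType(DType_FP8E5M2);  // -> TOSA_REF_TYPE_FP8E5M2
    // with precise_mode enabled, both return TOSA_REF_TYPE_FP64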
diff --git a/reference_model/include/types.h b/reference_model/include/types.h
index 15ee40c..32a8ce1 100644
--- a/reference_model/include/types.h
+++ b/reference_model/include/types.h
@@ -26,19 +26,21 @@ extern "C"
enum tosa_datatype_t
{
- tosa_datatype_bf16_t = 0,
- tosa_datatype_bool_t = 1,
- tosa_datatype_fp16_t = 2,
- tosa_datatype_fp32_t = 3,
- tosa_datatype_int16_t = 4,
- tosa_datatype_int32_t = 5,
- tosa_datatype_int48_t = 6,
- tosa_datatype_int4_t = 7,
- tosa_datatype_int8_t = 8,
- tosa_datatype_uint16_t = 9,
- tosa_datatype_uint8_t = 10,
- tosa_datatype_shape_t = 11,
- tosa_datatype_fp64_t = 99
+ tosa_datatype_bf16_t = 0,
+ tosa_datatype_bool_t = 1,
+ tosa_datatype_fp16_t = 2,
+ tosa_datatype_fp32_t = 3,
+ tosa_datatype_int16_t = 4,
+ tosa_datatype_int32_t = 5,
+ tosa_datatype_int48_t = 6,
+ tosa_datatype_int4_t = 7,
+ tosa_datatype_int8_t = 8,
+ tosa_datatype_uint16_t = 9,
+ tosa_datatype_uint8_t = 10,
+ tosa_datatype_shape_t = 11,
+ tosa_datatype_fp8e4m3_t = 12,
+ tosa_datatype_fp8e5m2_t = 13,
+ tosa_datatype_fp64_t = 99
};
struct tosa_tensor_t
@@ -61,4 +63,4 @@ extern "C"
}
#endif /* __cplusplus */
-#endif // TYPES_H_
\ No newline at end of file
+#endif // TYPES_H_
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
index fb491db..f0d184c 100644
--- a/reference_model/src/arith_util.h
+++ b/reference_model/src/arith_util.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -35,10 +35,8 @@
#include "func_debug.h"
#include "half.hpp"
#include "inttypes.h"
-#include <Eigen/Core>
#include <bitset>
#include <cassert>
-#include <iostream>
#include <limits>
#include <stdint.h>
#include <typeinfo>
@@ -244,7 +242,7 @@ float fpTrunc(float f_in)
// No-op for fp32
break;
default:
- ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-truncated.", EnumNameTOSAREFTYPE(Dtype));
+ ASSERT_MSG(false, "TOSA_REF_TYPE %s should not be float-cast.", EnumNameTOSAREFTYPE(Dtype));
}
return f_in;
}
diff --git a/reference_model/src/float_utils.h b/reference_model/src/float_utils.h
new file mode 100644
index 0000000..b98c89b
--- /dev/null
+++ b/reference_model/src/float_utils.h
@@ -0,0 +1,533 @@
+// Copyright (c) 2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FLOAT_UTILS_H_
+#define FLOAT_UTILS_H_
+
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+#include <type_traits>
+#if defined(__cpp_lib_bit_cast)
+#include <bit>
+#endif // defined(__cpp_lib_bit_cast)
+
+namespace tosa::reference::internal
+{
+
+namespace float_support
+{
+
+struct hidden
+{};
+
+#if defined(__cpp_lib_bit_cast)
+#define BITCAST_CONSTEXPR constexpr inline
+
+constexpr inline int32_t get_bits(const float& f)
+{
+ return std::bit_cast<int32_t>(f);
+}
+constexpr inline float from_bits(const int32_t& i)
+{
+ return std::bit_cast<float>(i);
+}
+
+#else
+#define BITCAST_CONSTEXPR inline
+
+union ufloat32
+{
+ constexpr ufloat32(const float& x)
+ : f(x)
+ {}
+ constexpr ufloat32(const int32_t& x)
+ : i(x)
+ {}
+
+ float f;
+ int32_t i;
+};
+
+inline int32_t get_bits(const float& f)
+{
+ return ufloat32(f).i;
+}
+inline float from_bits(const int32_t& i)
+{
+ return ufloat32(i).f;
+}
+#endif
+
+} // namespace float_support
+
+template <typename storage_t,
+ size_t n_exp_bits,
+ bool has_nan,
+ bool with_denorm,
+ bool with_infinity,
+ std::enable_if_t<(n_exp_bits + 1 < sizeof(storage_t) * 8), bool> = true>
+class float_t
+{
+ storage_t m_data = 0;
+
+public:
+ static constexpr size_t n_exponent_bits = n_exp_bits;
+ static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits;
+ static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1;
+
+ /// \brief Construct a floating point type with the given bit
+ /// representation.
+ static constexpr float_t from_bits(storage_t bits)
+ {
+ return float_t(float_support::hidden(), bits);
+ }
+
+ /// \brief Construct a float from the given sign, exponent and significand
+ /// bits.
+ static constexpr float_t from_bits(bool pm, storage_t e, storage_t s)
+ {
+ storage_t bits = pm ? 1 : 0;
+
+ bits <<= n_exp_bits;
+ bits |= e;
+
+ bits <<= n_significand_bits;
+ if (with_denorm || e)
+ bits |= s;
+
+ return float_t(float_support::hidden(), bits);
+ }
+
+ /// \brief (Hidden) Construct a float type from a given bit pattern
+ constexpr float_t(const float_support::hidden&, storage_t bits)
+ : m_data(bits)
+ {}
+
+ constexpr float_t()
+ : m_data(0)
+ {}
+ constexpr float_t(const float_t& other)
+ : m_data(other.m_data)
+ {}
+
+ /// \brief Cast to a different floating point representation.
+ template <typename other_storage_t,
+ size_t other_n_exp_bits,
+ bool other_has_nan,
+ bool other_has_denorm,
+ bool other_has_infinity>
+ constexpr inline
+ operator float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>() const
+ {
+ using other_float_t =
+ float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>;
+
+ // Shortcut for types which are fundamentally similar (e.g., bf16 ->
+ // fp32)
+ if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) &&
+ has_nan == other_has_nan)
+ {
+ return other_float_t::from_bits(static_cast<other_storage_t>(m_data)
+ << (sizeof(other_storage_t) - sizeof(storage_t)) * 8);
+ }
+
+ // Get initial values for the new floating point type
+ const bool sign_bit = m_data < 0;
+ int64_t new_exponent_bits = 0;
+ uint64_t new_significand = 0;
+
+ if (is_nan() || is_infinity())
+ {
+ new_exponent_bits = (1 << other_n_exp_bits) - 1;
+
+ if (is_nan())
+ {
+ if constexpr (other_has_infinity)
+ {
+ // Copy across the `not_quiet bit`; set the LSB. Don't
+ // attempt to copy across any of the rest of the payload.
+ new_significand =
+ 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits);
+ }
+ else
+ {
+ new_significand = (1ul << other_float_t::n_significand_bits) - 1;
+ }
+ }
+ else if constexpr (!other_has_infinity)
+ {
+ new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
+ }
+ }
+ else if (!is_zero())
+ {
+ const int64_t this_exponent_bits = exponent_bits();
+ {
+ constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias;
+ new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1);
+ }
+ new_significand = this->significand() << (64 - n_significand_bits);
+
+ // Normalise subnormals
+ if (this_exponent_bits == 0)
+ {
+ // Shift the most-significant 1 out of the magnitude to convert
+ // it to a significand. Modify the exponent accordingly.
+ uint8_t shift = __builtin_clzl(new_significand) + 1;
+ new_exponent_bits -= shift;
+ new_significand <<= shift;
+ }
+
+ // Align the significand for the output type
+ uint32_t shift = 64 - other_float_t::n_significand_bits;
+ const bool other_is_subnormal = new_exponent_bits <= 0;
+ if (other_is_subnormal)
+ {
+ shift += 1 - new_exponent_bits;
+ new_exponent_bits = 0;
+ }
+
+ const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1);
+ new_significand = shift == 64 ? 0 : new_significand >> shift;
+
+ // Reinsert the most-significant-one if this is a subnormal in the
+ // output type.
+ new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift);
+
+ // Apply rounding based on the bits shifted out of the significand
+ const uint64_t shift_half = 1ll << (shift - 1);
+ if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1)))
+ {
+ new_significand += 1;
+
+ // Handle the case that the significand overflowed due to
+ // rounding
+ constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1;
+ if (new_significand > max_significand)
+ {
+ new_significand = 0;
+ new_exponent_bits++;
+ }
+ }
+
+ // Saturate to infinity if the exponent is larger than can be
+ // represented in the output type. This can only occur if the size
+ // of the exponent of the new type is not greater than the exponent
+ // of the old type.
+ if constexpr (other_n_exp_bits <= n_exp_bits)
+ {
+ constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1;
+ if (new_exponent_bits >= inf_exp_bits)
+ {
+ new_exponent_bits = inf_exp_bits;
+ new_significand =
+ other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
+ }
+ }
+ }
+
+ return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand);
+ }
+
+ /// \brief Convert from a 32-bit floating point value
+ BITCAST_CONSTEXPR
+ float_t(const float& f)
+ {
+ // If this format exactly represents the binary32 format then get
+ // the bits from the provided float; otherwise get a binary32
+ // representation and then convert to this format.
+ if constexpr (represents_binary32())
+ m_data = float_support::get_bits(f);
+ else
+ m_data = static_cast<float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_infinity>>(
+ static_cast<float_t<int32_t, 8, true, true, true>>(f))
+ .m_data;
+ }
+
+ /// \brief Cast to a 32-bit floating point value
+ BITCAST_CONSTEXPR operator float() const
+ {
+ // If this format exactly represents the binary32 format then return
+ // a float; otherwise get a binary32 representation and then return
+ // a float.
+ if constexpr (represents_binary32())
+ return float_support::from_bits(m_data);
+ else
+ return static_cast<float>(this->operator float_t<int32_t, 8, true, true, true>());
+ }
+
+ /// \brief Return whether this type represents the IEEE754 binary32
+ /// format
+ constexpr static inline bool represents_binary32()
+ {
+ return std::is_same_v<storage_t, int32_t> && n_exp_bits == 8 && has_nan && with_denorm && with_infinity;
+ }
+
+ constexpr auto operator-() const
+ {
+ return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1)));
+ }
+
+ constexpr bool is_subnormal() const
+ {
+ return exponent_bits() == 0 && significand() != 0;
+ }
+
+ constexpr bool is_zero() const
+ {
+ return exponent_bits() == 0 && significand() == 0;
+ }
+
+ constexpr bool is_nan() const
+ {
+ return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) &&
+ ((with_infinity && significand()) ||
+ (!with_infinity && significand() == (1ul << n_significand_bits) - 1));
+ }
+
+ constexpr bool is_infinity() const
+ {
+ return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand());
+ }
+
+ constexpr inline const storage_t& bits() const
+ {
+ return m_data;
+ }
+
+ /// \brief Get the exponent
+ constexpr inline int64_t exponent() const
+ {
+ return std::max<int64_t>(exponent_bits(), 1ul) - exponent_bias;
+ }
+
+ /// \brief Get the bits from the exponent
+ constexpr inline uint64_t exponent_bits() const
+ {
+ constexpr uint64_t mask = (1ul << n_exp_bits) - 1;
+ return (m_data >> n_significand_bits) & mask;
+ }
+
+ constexpr inline uint64_t significand() const
+ {
+ return m_data & ((1ul << n_significand_bits) - 1);
+ }
+
+ constexpr inline bool operator==(const float_t& other) const
+ {
+ return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits());
+ }
+
+ constexpr inline float_t& operator+=(const float_t& rhs)
+ {
+ this->m_data = static_cast<float_t>(static_cast<float>(*this) + static_cast<float>(rhs)).bits();
+ return *this;
+ }
+};
+
+// This should probably be exported so we can use it elsewhere
+#undef BITCAST_CONSTEXPR
+
+namespace float_support
+{
+
+// Pre-C++23 these can't be computed as constexpr, so have to hardcode them
+
+template <int>
+struct digits10; // floor(log10(2) * (digits - 1))
+template <int>
+struct max_digits10; // ceil(log10(2) * digits + 1)
+template <int>
+struct min_exponent10; // floor(log10(2) * min_exponent)
+template <int>
+struct max_exponent10; // floor(log10(2) * max_exponent)
+
+template <>
+struct digits10<8>
+{
+ constexpr static inline int value = 2;
+};
+
+template <>
+struct max_digits10<8>
+{
+ constexpr static inline int value = 4;
+};
+
+template <>
+struct digits10<10>
+{
+ constexpr static inline int value = 2;
+};
+
+template <>
+struct max_digits10<10>
+{
+ constexpr static inline int value = 5;
+};
+
+template <>
+struct digits10<24>
+{
+ constexpr static inline int value = 6;
+};
+
+template <>
+struct max_digits10<24>
+{
+ constexpr static inline int value = 9;
+};
+
+template <>
+struct min_exponent10<-13>
+{
+ constexpr static inline int value = -3;
+};
+
+template <>
+struct max_exponent10<16>
+{
+ constexpr static inline int value = 4;
+};
+
+template <>
+struct min_exponent10<-125>
+{
+ constexpr static inline int value = -37;
+};
+
+template <>
+struct max_exponent10<128>
+{
+ constexpr static inline int value = 38;
+};
+
+template <int d>
+inline constexpr int digits10_v = digits10<d>::value;
+template <int d>
+inline constexpr int max_digits10_v = max_digits10<d>::value;
+
+template <int e>
+inline constexpr int min_exponent10_v = min_exponent10<e>::value;
+
+template <int e>
+inline constexpr int max_exponent10_v = max_exponent10<e>::value;
+
+} // namespace float_support
+
+} // namespace tosa::reference::internal
+
+namespace std
+{
+
+template <typename storage_t, size_t n_exp_bits, bool has_nan, bool has_denorm, bool has_inf>
+struct is_floating_point<tosa::reference::internal::float_t<storage_t, n_exp_bits, has_nan, has_denorm, has_inf>>
+ : std::integral_constant<bool, true>
+{};
+
+template <typename storage_t, size_t n_exp_bits, bool has_nan, bool with_denorm, bool with_inf>
+class numeric_limits<tosa::reference::internal::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>>
+{
+ using this_float_t = tosa::reference::internal::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>;
+
+public:
+ static constexpr bool is_specialized = true;
+
+ static constexpr auto min() noexcept
+ {
+ return this_float_t::from_bits(false, 1, 0);
+ }
+
+ static constexpr auto max() noexcept
+ {
+ return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2,
+ (1 << this_float_t::n_significand_bits) - 1);
+ }
+
+ static constexpr auto lowest() noexcept
+ {
+ return -max();
+ }
+
+ static constexpr int digits = this_float_t::n_significand_bits + 1;
+ static constexpr int digits10 = tosa::reference::internal::float_support::digits10_v<digits>;
+ static constexpr int max_digits10 = tosa::reference::internal::float_support::max_digits10_v<digits>;
+
+ static constexpr bool is_signed = true;
+ static constexpr bool is_integer = false;
+ static constexpr bool is_exact = false;
+ static constexpr int radix = 2;
+
+ static constexpr auto epsilon() noexcept
+ {
+ return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0);
+ }
+
+ static constexpr auto round_error() noexcept
+ {
+ return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0);
+ }
+
+ static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1;
+ static constexpr int min_exponent10 = tosa::reference::internal::float_support::min_exponent10_v<min_exponent>;
+ static constexpr int max_exponent = this_float_t::exponent_bias + 1;
+ static constexpr int max_exponent10 = tosa::reference::internal::float_support::max_exponent10_v<max_exponent>;
+
+ static constexpr bool has_infinity = with_inf;
+ static constexpr bool has_quiet_NaN = has_nan;
+ static constexpr bool has_signaling_NaN = true;
+ static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent;
+ static constexpr bool has_denorm_loss = false;
+
+ static constexpr auto infinity() noexcept
+ {
+ if constexpr (with_inf)
+ {
+ return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0);
+ }
+ else
+ {
+ return this_float_t::from_bits(false, 0, 0);
+ }
+ }
+
+ static constexpr auto quiet_NaN() noexcept
+ {
+ return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1,
+ 1 << (this_float_t::n_significand_bits - 1) | 1);
+ }
+
+ static constexpr auto signaling_NaN() noexcept
+ {
+ return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1);
+ }
+
+ static constexpr auto denorm_min() noexcept
+ {
+ return this_float_t::from_bits(false, 0, 1);
+ }
+
+ static constexpr bool is_iec559 = false;
+ static constexpr bool is_bounded = false;
+ static constexpr bool is_modulo = false;
+
+ static constexpr bool traps = false;
+ static constexpr bool tinyness_before = false;
+ static constexpr float_round_style round_style = round_to_nearest;
+};
+
+} // namespace std
+
+#endif // FLOAT_UTILS_H_
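A minimal usage sketch for the new header (illustrative, not part of the patch); the alias matches the one defined later in type_conversion.cc, and quantize_to_fp8e4m3 is a hypothetical helper name:

    #include "float_utils.h"
    #include <limits>

    using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>;

    // Round-trip a binary32 value through the 8-bit format; the result is the
    // nearest representable fp8e4m3 value, still carried in a 32-bit float.
    float quantize_to_fp8e4m3(float x)
    {
        return static_cast<float>(fp8e4m3(x));
    }

    // numeric_limits is specialized by the header, so generic code can query
    // e.g. std::numeric_limits<fp8e4m3>::max() or std::numeric_limits<fp8e4m3>::denorm_min().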
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index 8b16e97..271b7f5 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -34,6 +34,8 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType,
{ DType::DType_BF16, "BF16" },
{ DType::DType_FP32, "FP32" },
{ DType::DType_SHAPE, "SHAPE" },
+ { DType::DType_FP8E4M3, "FP8E4M3" },
+ { DType::DType_FP8E5M2, "FP8E5M2" },
})
NLOHMANN_JSON_SERIALIZE_ENUM(Op,
@@ -225,6 +227,8 @@ size_t elementSizeFromType(DType type)
case DType::DType_BOOL:
case DType::DType_UINT8:
case DType::DType_INT8:
+ case DType::DType_FP8E4M3:
+ case DType::DType_FP8E5M2:
return 1;
case DType::DType_UINT16:
case DType::DType_INT16:
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index ec9614a..ddf0713 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -759,6 +759,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
@@ -768,6 +770,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BF16);
@@ -776,6 +780,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
@@ -785,6 +791,8 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP64);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E4M3);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
@@ -794,6 +802,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
@@ -803,6 +813,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
@@ -812,6 +824,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
@@ -821,6 +835,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
@@ -830,3 +846,5 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2);
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index 705981c..64001a9 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -105,3 +105,5 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2);
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index af8332e..6d66c07 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -55,6 +55,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2);
break;
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
@@ -64,6 +66,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E4M3, FP16);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16);
break;
case Op_CONV2D:
DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -74,6 +78,9 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
break;
case Op_CONV3D:
DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
@@ -84,6 +91,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
break;
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
@@ -94,6 +103,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
break;
case Op_FFT2D:
DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
@@ -117,6 +128,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E4M3, FP16);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP8E5M2, FP16);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
@@ -125,6 +138,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E4M3);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP8E5M2);
break;
case Op_RFFT2D:
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32);
@@ -139,6 +154,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
break;
// activation_funcs
@@ -409,6 +426,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP8E5M2);
break;
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
@@ -419,6 +438,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP8E5M2);
break;
case Op_DIM:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP16);
@@ -428,6 +449,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpDim, FP8E5M2);
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FP16);
@@ -438,6 +461,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RESHAPE(OpReshape, INT32);
DEF_FACTORY_RESHAPE(OpReshape, BOOL);
DEF_FACTORY_RESHAPE(OpReshape, FP64);
+ DEF_FACTORY_RESHAPE(OpReshape, FP8E4M3);
+ DEF_FACTORY_RESHAPE(OpReshape, FP8E5M2);
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
@@ -448,6 +473,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP8E5M2);
break;
case Op_SLICE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
@@ -458,6 +485,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP8E5M2);
break;
case Op_TILE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
@@ -468,6 +497,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP8E5M2);
break;
case Op_TRANSPOSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
@@ -478,6 +509,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E4M3);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP8E5M2);
break;
// scatter_gather
@@ -489,6 +522,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, BF16);
DEF_FACTORY_ONE_TYPE(OpGather, FP32);
DEF_FACTORY_ONE_TYPE(OpGather, FP64);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP8E4M3);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP8E5M2);
break;
case Op_SCATTER:
DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
@@ -498,6 +533,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpScatter, BF16);
DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
DEF_FACTORY_ONE_TYPE(OpScatter, FP64);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP8E4M3);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP8E5M2);
break;
// image
@@ -524,6 +561,8 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP8E5M2);
break;
// type_conversion
@@ -569,6 +608,18 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2);
break;
case Op_RESCALE:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index bd16ad1..85397ae 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -236,6 +236,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E4M3);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FP8E5M2);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
@@ -244,3 +246,5 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E4M3);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP8E5M2);
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 342d5c2..41e6061 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -88,6 +88,18 @@ struct GetEigenType<TOSA_REF_TYPE_BF16>
using type = float;
};
template <>
+struct GetEigenType<TOSA_REF_TYPE_FP8E4M3>
+{
+ // NOTE: full precision used
+ using type = float;
+};
+template <>
+struct GetEigenType<TOSA_REF_TYPE_FP8E5M2>
+{
+ // NOTE: full precision used
+ using type = float;
+};
+template <>
struct GetEigenType<TOSA_REF_TYPE_INT32>
{
using type = int32_t;
@@ -200,6 +212,16 @@ struct GetNumBits<TOSA_REF_TYPE_FP16>
{
static constexpr int32_t value = 16;
};
+template <>
+struct GetNumBits<TOSA_REF_TYPE_FP8E4M3>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<TOSA_REF_TYPE_FP8E5M2>
+{
+ static constexpr int32_t value = 8;
+};
// Meta function to get quantized min/max in compile time
template <TOSA_REF_TYPE T>
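Note that GetEigenType maps both FP8 types to float ("full precision used"), so FP8 tensors are carried and computed on as 32-bit floats inside the reference model; narrowing to the FP8 value set happens only in the CAST helpers and at (de)serialization. An illustrative check, assuming the TosaReference namespace used elsewhere in ops/:

    #include <type_traits>
    #include "template_types.h"

    static_assert(std::is_same_v<TosaReference::GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type, float>,
                  "fp8 elements are carried as float");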
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index dd66f79..124dc87 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -555,7 +555,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
}
}
if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
- Dtype != TOSA_REF_TYPE_FP64)
+ Dtype != TOSA_REF_TYPE_FP64 && Dtype != TOSA_REF_TYPE_FP8E4M3 && Dtype != TOSA_REF_TYPE_FP8E5M2)
{
try
{
@@ -2155,6 +2155,8 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E4M3);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP8E5M2);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
@@ -2163,6 +2165,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16);
// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -2173,6 +2177,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
@@ -2182,6 +2188,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
@@ -2191,6 +2199,8 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
@@ -2211,6 +2221,8 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E4M3, FP16);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP8E5M2, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
@@ -2218,6 +2230,8 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E4M3);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
@@ -2230,3 +2244,5 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
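As in op_factory.cc above, every FP8 tensor-op instantiation pairs FP8 inputs and weights with an FP16 accumulator/output type. The macros expand (roughly) to explicit instantiations such as:

    // illustrative expansion, not literal macro output
    template class TosaReference::OpConv2d<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16>;
    template class TosaReference::OpAvgPool2d<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16>;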
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 484f768..5dbc7bd 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -15,6 +15,7 @@
#include "type_conversion.h"
#include "arith_util.h"
+#include "float_utils.h"
#include "half.hpp"
#include "quant_util.h"
#include "template_types.h"
@@ -24,6 +25,12 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
+using fp16 = tosa::reference::internal::float_t<int16_t, 5, true, true, true>;
+using bf16 = tosa::reference::internal::float_t<int16_t, 8, true, true, true>;
+using fp32 = tosa::reference::internal::float_t<int32_t, 8, true, true, true>;
+using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>;
+using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>;
+
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_RESCALE, id_)
@@ -527,6 +534,162 @@ CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
}
template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP8E4M3, OutDtype>::CastHelper()
+{
+ // fp8e4m3 data (stored as fp32) converted to integer
+ fcn = [](float in) -> OutEigenType {
+ if (in >= float(OutMax))
+ return OutMax;
+ if (in <= float(OutMin))
+ return OutMin;
+
+ OutEigenType out = std::rint(in);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16>::CastHelper()
+{
+ // fp8e4m3 data (stored as fp32) converted to fp16 (stored as fp32)
+ fcn = [](float in) -> float {
+ half_float::half h = half_float::half(in);
+ float out = half_float::half_cast<half_float::half, float>(h);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_BF16>::CastHelper()
+{
+ // fp8e4m3 data (stored as fp32) converted to bf16 (stored as fp32)
+ fcn = [](float in) -> float { return (float)in; };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>::CastHelper()
+{
+ // fp8e4m3 data (stored as fp32) converted to fp32
+ fcn = [](InEigenType in) -> OutEigenType { return in; };
+}
+
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>::CastHelper()
+{
+ // fp8e5m2 data (stored as fp32) converted to integer
+ fcn = [](float in) -> OutEigenType {
+ if (in >= float(OutMax))
+ return OutMax;
+ if (in <= float(OutMin))
+ return OutMin;
+
+ OutEigenType out = std::rint(in);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16>::CastHelper()
+{
+ // fp8e5m2 data (stored as fp32) converted to fp16 (stored as fp32)
+ fcn = [](float in) -> float {
+ half_float::half h = half_float::half(in);
+ float out = half_float::half_cast<half_float::half, float>(h);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_BF16>::CastHelper()
+{
+ // fp8e5m2 data (stored as fp32) converted to bf16 (stored as fp32)
+ fcn = [](float in) -> float { return (float)in; };
+}
+
+CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>::CastHelper()
+{
+ // fp8e5m2 data (stored as fp32) converted to fp32
+ fcn = [](InEigenType in) -> OutEigenType { return in; };
+}
+
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+ // Integer data converted to fp8e4m3 (stored as fp32)
+ fcn = [](InEigenType in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e4m3>(float(in)));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+ // fp16 data (stored as fp32) converted to fp8e4m3 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+ // bf16 data (stored as fp32) converted to fp8e4m3 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>::CastHelper()
+{
+ // fp32 data converted to fp8e4m3 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e4m3>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+ // Integer data converted to fp8e5m2 (stored as fp32)
+ fcn = [](InEigenType in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e5m2>(float(in)));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+ // fp16 data (stored as fp32) converted to fp8e5m2 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+ // bf16 data (stored as fp32) converted to fp8e5m2 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>::CastHelper()
+{
+ // fp32 data converted to fp8e5m2 (stored as fp32)
+ fcn = [](float in) -> float {
+ auto f = static_cast<fp32>(static_cast<fp8e5m2>(in));
+ float out = static_cast<float>(f);
+ return out;
+ };
+}
+
+template <TOSA_REF_TYPE OutDtype>
CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
{
switch (OutDtype)
@@ -597,6 +760,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BF16, FP8E5M2);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E4M3, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP8E5M2, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, FP8E5M2);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E4M3);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP8E5M2);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
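The FP8 cast helpers keep tensor data in float storage and only narrow the representable value set. A rough per-element sketch of what the new FP32 -> FP8E5M2 helper computes (cast_fp32_to_fp8e5m2 is an illustrative name, not part of the patch):

    #include "float_utils.h"

    using fp32    = tosa::reference::internal::float_t<int32_t, 8, true, true, true>;
    using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>;

    float cast_fp32_to_fp8e5m2(float in)
    {
        // round to the nearest fp8e5m2 value, then widen back to binary32 for storage
        return static_cast<float>(static_cast<fp32>(static_cast<fp8e5m2>(in)));
    }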
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index 98799a0..75f244d 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -277,6 +277,282 @@ private:
};
template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP8E4M3, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP16>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_BF16>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E4M3, TOSA_REF_TYPE_FP32>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP8E5M2, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP16>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_BF16>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP8E5M2, TOSA_REF_TYPE_FP32>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_FP8E4M3>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E4M3>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E4M3>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E4M3>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E4M3>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_FP8E5M2>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP8E5M2>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP8E5M2>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <>
+class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP8E5M2>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP8E5M2>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <TOSA_REF_TYPE OutDtype>
class CastHelper<TOSA_REF_TYPE_FP64, OutDtype>
{
public:
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index fae0b30..33a9b94 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -595,6 +595,22 @@ int SubgraphTraverser::allocateTensor(std::string name)
}
}
break;
+ case DType_FP8E4M3:
+ case DType_FP8E5M2: {
+ std::vector<float> fp32_data;
+ TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+            // Each float is expected to hold a valid fp8 value
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
+ }
+ break;
case DType_FP32: {
std::vector<float> fp32_data;
TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index f9ec937..27f21f3 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -115,12 +115,15 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
assert(dtype == ConvertDType(serialization_dtype));
// if dtype is FP64, serialization_dtype must be one of FP32, FP16, BF16
assert(dtype != TOSA_REF_TYPE_FP64 || serialization_dtype == DType_FP32 || serialization_dtype == DType_FP16 ||
- serialization_dtype == DType_BF16);
+ serialization_dtype == DType_BF16 || serialization_dtype == DType_FP8E4M3 ||
+ serialization_dtype == DType_FP8E5M2);
switch (serialization_dtype)
{
case DType_FP32:
case DType_BF16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2:
f32databuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(f32databuf);
@@ -208,6 +211,14 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
return 1;
}
break;
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
+ if (setTensorValueFloat(elements, f32databuf))
+ {
+ free(f32databuf);
+ return 1;
+ }
+ break;
case TOSA_REF_TYPE_FP32:
if (setTensorValueFloat(elements, f32databuf))
{
@@ -276,6 +287,23 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
return 1;
}
break;
+ case DType_FP8E4M3:
+ case DType_FP8E5M2:
+            // FP8E4M3/FP8E5M2 -> FP64
+ f64databuf = (double*)calloc(sizeof(double), elements);
+ ASSERT_MEM(f64databuf);
+ for (uint32_t i = 0; i < elements; i++)
+ {
+                // TODO: ASSERT_MSG(checkValidFloat8(f32databuf[i]), "Input float value not a valid float8 value.");
+ f64databuf[i] = static_cast<double>(f32databuf[i]);
+ }
+ if (setTensorValueDouble(elements, f64databuf))
+ {
+ free(f32databuf);
+ free(f64databuf);
+ return 1;
+ }
+ break;
case DType_FP32:
// FP32 -> FP64
f64databuf = (double*)calloc(sizeof(double), elements);
@@ -349,6 +377,8 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
free(f32databuf);
break;
case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
f32databuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(f32databuf);
f16databuf = (half_float::half*)calloc(sizeof(half_float::half), elements);
@@ -631,6 +661,8 @@ int TosaReference::Tensor::readfromVector(const ArrayProxy<float> vals)
// continue with setting float vals in the tensor
case TOSA_REF_TYPE_FP16:
case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
if (vals.size() != elements)
{
WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 2c3be7f..1659a2f 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -863,6 +863,8 @@ public:
case TOSA_REF_TYPE_FP32:
case TOSA_REF_TYPE_FP16:
case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP8E5M2:
switch (rank)
{
case 0:
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index 14bc6f1..4ae245b 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -36,6 +36,8 @@ NLOHMANN_JSON_SERIALIZE_ENUM(DType,
{ DType::DType_FP16, "FP16" },
{ DType::DType_BF16, "BF16" },
{ DType::DType_FP32, "FP32" },
+ { DType::DType_FP8E4M3, "FP8E4M3" },
+ { DType::DType_FP8E5M2, "FP8E5M2" },
})
} // namespace tosa
@@ -177,12 +179,13 @@ std::string positionToString(const std::vector<int32_t>& pos)
DType mapToDType(tosa_datatype_t dataType)
{
static std::map<tosa_datatype_t, DType> typeMap = {
- { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 },
- { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 },
- { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 },
- { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 },
- { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 },
- { tosa_datatype_shape_t, DType_SHAPE },
+ { tosa_datatype_bool_t, DType_BOOL }, { tosa_datatype_int4_t, DType_INT4 },
+ { tosa_datatype_int8_t, DType_INT8 }, { tosa_datatype_uint16_t, DType_UINT16 },
+ { tosa_datatype_int16_t, DType_INT16 }, { tosa_datatype_int32_t, DType_INT32 },
+ { tosa_datatype_int48_t, DType_INT48 }, { tosa_datatype_fp16_t, DType_FP16 },
+ { tosa_datatype_bf16_t, DType_BF16 }, { tosa_datatype_fp32_t, DType_FP32 },
+ { tosa_datatype_shape_t, DType_SHAPE }, { tosa_datatype_fp8e4m3_t, DType_FP8E4M3 },
+ { tosa_datatype_fp8e5m2_t, DType_FP8E5M2 },
};
if (typeMap.count(dataType))
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 8137a4369acefa4c01f08db95a2b1b290e8dd70
+Subproject a029f1f02707f40f6990df53fd4f56684490d58
diff --git a/verif/checker/tosa_result_checker.py b/verif/checker/tosa_result_checker.py
index 212c809..4d6d345 100644
--- a/verif/checker/tosa_result_checker.py
+++ b/verif/checker/tosa_result_checker.py
@@ -13,6 +13,7 @@ from checker.color_print import print_color
from checker.verifier import VerifierError
from checker.verifier import VerifierLibrary
from generator.tosa_utils import float32_is_valid_bfloat16
+from generator.tosa_utils import float32_is_valid_float8
from schemavalidation.schemavalidation import TestDescSchemaValidator
@@ -195,6 +196,18 @@ def test_check(
"reference_result: {ref_res_is_bf16}; test_result: {test_res_is_bf16}"
)
return (TestResult.INCORRECT_FORMAT, 0.0, msg)
+ if "fp8e4m3" in misc_checks or "fp8e5m2" in misc_checks:
+ # Ensure floats are valid float8 values
+ test_res_is_fp8 = all([float32_is_valid_float8(f) for f in test_result.flat])
+ ref_res_is_fp8 = all(
+ [float32_is_valid_float8(f) for f in reference_result.flat]
+ )
+ if not (test_res_is_fp8 and ref_res_is_fp8):
+                msg = (
+                    "All output values must be valid float8. "
+                    f"reference_result: {ref_res_is_fp8}; test_result: {test_res_is_fp8}"
+                )
+ return (TestResult.INCORRECT_FLOAT, 0.0, msg)
# for quantized test, allow +-(quantize_tolerance) error
if reference_result.dtype in (
diff --git a/verif/conformance/tosa_main_profile_ops_info.json b/verif/conformance/tosa_main_profile_ops_info.json
index 7792417..7559c62 100644
--- a/verif/conformance/tosa_main_profile_ops_info.json
+++ b/verif/conformance/tosa_main_profile_ops_info.json
@@ -185,6 +185,30 @@
"2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -233,6 +257,24 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -315,6 +357,30 @@
"2,65538,1,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -527,6 +593,30 @@
"2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -592,6 +682,30 @@
"1,2,1,65529"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -647,6 +761,24 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -722,6 +854,28 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--target-shape",
+ "1,7,18,5,4",
+ "--target-shape",
+ "1,6,12,17,3",
+ "--tensor-dim-range",
+ "1,4",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -787,6 +941,24 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -840,6 +1012,30 @@
"3"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1183,6 +1379,30 @@
"5000,1,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1505,6 +1725,30 @@
"1,65538,3"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1551,6 +1795,24 @@
"--allow-pooling-and-conv-oversizes"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1699,6 +1961,30 @@
"1,1,65539,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1889,6 +2175,30 @@
"2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -1935,6 +2245,30 @@
"1,65535,1,2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -2046,6 +2380,24 @@
"2989,6,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -2091,6 +2443,30 @@
"1,65543,2,1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -2161,6 +2537,30 @@
"1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -2214,6 +2614,30 @@
"1"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-shape",
+ "10,24,9,13",
+ "--target-shape",
+ "8,14,20,5",
+ "--tensor-dim-range",
+ "1,16",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
@@ -3111,6 +3535,30 @@
"2"
]
]
+ },
+ "float8": {
+ "from_version" : "v0.100.0",
+ "no_negative_tests": "true",
+ "generator_args": [
+ [
+ "--target-dtype",
+ "fp8e4m3",
+ "--target-dtype",
+ "fp8e5m2",
+ "--fp-values-range",
+ "-max,max",
+ "--tensor-dim-range",
+ "32,64",
+ "--target-rank",
+ "1",
+ "--target-rank",
+ "2",
+ "--target-rank",
+ "3",
+ "--num-rand-permutations",
+ "2"
+ ]
+ ]
}
},
"selection": {
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 7ec0cfe..d0b9eb9 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -641,6 +641,8 @@ class TosaTensorValuesGen:
DType.FP32: (1 << 128) - (1 << (127 - 23)),
DType.FP16: (1 << 16) - (1 << (15 - 10)),
DType.BF16: (1 << 128) - (1 << (127 - 7)),
+ DType.FP8E4M3: 448,
+ DType.FP8E5M2: 57344,
}
# Default lowest normal values for random numbers
@@ -648,6 +650,8 @@ class TosaTensorValuesGen:
DType.FP32: np.exp2(-126),
DType.FP16: np.exp2(-14),
DType.BF16: np.exp2(-126),
+ DType.FP8E4M3: np.exp2(-9),
+ DType.FP8E5M2: np.exp2(-16),
}
@staticmethod
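The FP8 boundary constants above follow from the OCP 8-bit float formats; a quick illustrative check:

import numpy as np

# Largest finite magnitudes, (1 + m_max / 2**M) * 2**E_max:
assert (1 + 6 / 8) * 2.0**8 == 448        # e4m3: mantissa 110 is the last finite pattern
assert (1 + 3 / 4) * 2.0**15 == 57344     # e5m2: exponent field 11110, mantissa 11

# The low boundaries added above, 2**-9 and 2**-16, equal 2**(1 - bias - M),
# i.e. the smallest positive (subnormal) e4m3/e5m2 magnitudes; the smallest
# normal values would be 2**-6 and 2**-14 respectively.
assert np.exp2(1 - 7 - 3) == np.exp2(-9)
assert np.exp2(1 - 15 - 2) == np.exp2(-16)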
@@ -715,6 +719,8 @@ class TosaTensorValuesGen:
DType.FP16,
DType.FP32,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
):
# Change from inclusive to exclusive range
data_range = (data_range[0], data_range[1] + 1)
@@ -1734,7 +1740,13 @@ class TosaArgGen:
and "data_gen" in testGen.TOSA_OP_LIST[opName]
and gtu.dtypeIsSupportedByCompliance(dtype)
):
- if dtype in [DType.FP16, DType.FP32, DType.BF16]:
+ if dtype in [
+ DType.FP16,
+ DType.FP32,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["fp"]
else:
dataGenTypesList = testGen.TOSA_OP_LIST[opName]["data_gen"]["int"]
@@ -2140,6 +2152,8 @@ class TosaArgGen:
accum_dtypes = [DType.FP32]
elif dtype == DType.FP32:
accum_dtypes = [DType.FP32]
+ elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+ accum_dtypes = [DType.FP16]
elif error_name is None:
assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
@@ -2350,7 +2364,13 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+ elif dtype in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
@@ -2468,6 +2488,8 @@ class TosaArgGen:
accum_dtypes = [DType.FP16, DType.FP32]
elif dtype == DType.BF16 or dtype == DType.FP32:
accum_dtypes = [DType.FP32]
+ elif dtype == DType.FP8E4M3 or dtype == DType.FP8E5M2:
+ accum_dtypes = [DType.FP16]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
else:
@@ -2646,11 +2668,35 @@ class TosaArgGen:
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
elif inDtype == DType.BF16:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
elif inDtype == DType.FP32:
- dtypeList = [DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.BF16]
+ dtypeList = [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ elif inDtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ dtypeList = [DType.FP16, DType.BF16, DType.FP32]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
@@ -3232,6 +3278,10 @@ class TosaArgGen:
outputDTypeList = [DType.BF16]
elif dtype == DType.FP32:
outputDTypeList = [DType.FP32]
+ elif dtype == DType.FP8E4M3:
+ outputDTypeList = [DType.FP8E4M3]
+ elif dtype == DType.FP8E5M2:
+ outputDTypeList = [DType.FP8E5M2]
elif error_name == ErrorIf.WrongInputType:
# If an incorrect input type is used then we set a 'correct'
# output type to avoid other errors
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index 9a88acb..7a4d0d6 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -325,12 +325,32 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP32]:
+        if input_dtype in [DType.BOOL]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT48,
+ DType.FP32,
+ DType.FP16,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ elif input_dtype in [DType.FP32]:
outputDType = [DType.BOOL, DType.INT48, DType.FP32]
elif input_dtype in [DType.FP16, DType.BF16]:
outputDType = [DType.BOOL, DType.INT48]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
+ elif input_dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ outputDType = [
+ DType.BOOL,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ ]
else:
assert False, f"input_dtype ({input_dtype}) not supported"
return outputDType
@@ -476,13 +496,23 @@ class TosaErrorValidator:
)
or (input_dtype == DType.BF16 and output_dtype != DType.FP32)
or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
+ or (input_dtype == DType.FP8E4M3 and output_dtype != DType.FP16)
+ or (input_dtype == DType.FP8E5M2 and output_dtype != DType.FP16)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
input_dtype
- in [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+ in [
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
and output_dtype != DType.INT32
):
error_result = True
@@ -555,12 +585,26 @@ class TosaErrorValidator:
or (
input_dtype == DType.FP16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.BF16
and output_dtype
- not in [DType.INT8, DType.INT16, DType.INT32, DType.FP32]
+ not in [
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
)
or (
input_dtype == DType.FP32
@@ -571,6 +615,17 @@ class TosaErrorValidator:
DType.INT32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ]
+ )
+ or (
+ input_dtype in [DType.FP8E4M3, DType.FP8E5M2]
+ and output_dtype
+ not in [
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
]
)
):
@@ -597,6 +652,10 @@ class TosaErrorValidator:
and output_dtype != DType.FP32
or input_dtype == DType.FP32
and output_dtype != DType.FP32
+ or input_dtype == DType.FP8E4M3
+ and output_dtype != DType.FP16
+ or input_dtype == DType.FP8E5M2
+ and output_dtype != DType.FP16
):
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
@@ -2615,6 +2674,11 @@ class TosaErrorValidator:
DType.FP32,
):
error_result = True
+ elif (
+ input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+ and accum_dtype != DType.FP16
+ ):
+ error_result = True
info_dict = {
"error_name": error_name,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 4ead982..bc931dc 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -76,7 +76,7 @@ class TosaTestGen:
return tuple(sorted(vals))
self.random_float_range = {}
- for dtype in (DType.FP32, DType.FP16, DType.BF16):
+ for dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
self.random_float_range[dtype] = convertFPRange(
args.tensor_fp_value_range,
TosaTensorValuesGen.TVG_FLOAT_HIGH_VALUE[dtype],
@@ -152,7 +152,7 @@ class TosaTestGen:
# Returns dtype value range boundaries (low, high)
# The high boundary is excluded in the range
# unless high_inclusive is True
- if dtype in (DType.FP32, DType.FP16, DType.BF16):
+ if dtype in (DType.FP32, DType.FP16, DType.BF16, DType.FP8E4M3, DType.FP8E5M2):
return self.random_float_range[dtype]
elif dtype == DType.BOOL:
rng = (0, 2)
@@ -197,7 +197,13 @@ class TosaTestGen:
return np.uint8(self.rng.integers(low=low, high=high, size=shape))
elif dtype in (DType.INT48, DType.SHAPE):
return np.int64(self.rng.integers(low=low, high=high, size=shape))
- elif dtype in (DType.FP16, DType.BF16, DType.FP32):
+ elif dtype in (
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ ):
f_tensor = self.rng.uniform(low=low, high=high, size=shape)
if dtype == DType.FP16:
@@ -207,6 +213,10 @@ class TosaTestGen:
if dtype == DType.BF16:
# Floor the last 16 bits of each f32 value
return np.float32(gtu.vect_f32_to_bf16(f32_tensor))
+ elif dtype == DType.FP8E4M3:
+ return np.float32(gtu.vect_f32_to_fp8e4m3(f32_tensor))
+ elif dtype == DType.FP8E5M2:
+ return np.float32(gtu.vect_f32_to_fp8e5m2(f32_tensor))
else:
return f32_tensor
else:
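A minimal sketch of the FP8 random-data path added above, reusing the vectorized packing helper that this patch adds to tosa_utils.py (value range and seed chosen only for illustration):

import numpy as np
from generator.tosa_utils import vect_f32_to_fp8e4m3

# Draw uniform float32 values, then pack each into the fp8e4m3 bit layout
# (still carried in a float32 container), mirroring the generator method above.
rng = np.random.default_rng(0)
f32_tensor = np.float32(rng.uniform(low=-448.0, high=448.0, size=(2, 3)))
fp8_tensor = np.float32(vect_f32_to_fp8e4m3(f32_tensor))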
@@ -266,6 +276,12 @@ class TosaTestGen:
elif dtype == DType.BF16:
rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
return gtu.vect_f32_to_bf16(rand_f32)
+ elif dtype == DType.FP8E4M3:
+ rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e4m3(rand_f32)
+ elif dtype == DType.FP8E5M2:
+ rand_f32 = np.float32(self.rng.uniform(low=low, high=high))
+ return gtu.vect_f32_to_fp8e5m2(rand_f32)
elif dtype == DType.BOOL:
return self.rng.choice([False, True])
elif dtype == DType.INT48 or dtype == DType.SHAPE:
@@ -1408,8 +1424,11 @@ class TosaTestGen:
max_val = max_val.astype(np.float32)
attr.ClampAttribute(self.ser.builder, 0, 0, min_val, max_val)
- else:
+ elif a.dtype in (DType.INT8, DType.INT16):
attr.ClampAttribute(self.ser.builder, min_val, max_val, 0, 0)
+ else:
+            # Avoid an internal error when an incorrect input type is supplied
+ attr.ClampAttribute(self.ser.builder, 0, 0, 0, 0)
self.ser.addOperator(op["op"], input_list, output_list, attr)
@@ -3190,7 +3209,13 @@ class TosaTestGen:
]
TYPE_FI16 = [DType.FP32, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.BF16, DType.FP32]
+ TYPE_NARROW_INT_FP = [
+ DType.INT8,
+ DType.INT16,
+ DType.FP16,
+ DType.BF16,
+ DType.FP32,
+ ]
# List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
@@ -3201,6 +3226,8 @@ class TosaTestGen:
[DType.FP16, DType.FP16, DType.FP32],
[DType.BF16, DType.BF16, DType.FP32],
[DType.FP32, DType.FP32, DType.FP32],
+ [DType.FP8E4M3, DType.FP8E4M3, DType.FP16],
+ [DType.FP8E5M2, DType.FP8E5M2, DType.FP16],
]
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
@@ -3217,7 +3244,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
@@ -3244,7 +3271,7 @@ class TosaTestGen:
TosaArgGen.agPooling,
),
"qgen": TosaQuantGen.qgUnary,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
@@ -3402,7 +3429,7 @@ class TosaTestGen:
TosaArgGen.agMatMul,
),
"qgen": TosaQuantGen.qgMatmul,
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evInputZeroPointNotZero,
TosaErrorValidator.evWrongRank,
@@ -3425,7 +3452,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agPooling,
),
- "types": TYPE_NARROW_INT_FP,
+ "types": TYPE_NARROW_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"invalid_test_validators": (TosaInvalidValidator.ivHeightWidthInvalid,),
"error_if_validators": (
TosaErrorValidator.evKernelSmallerOne,
@@ -4389,7 +4416,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgConcat,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -4413,7 +4440,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgPad,
TosaArgGen.agPad,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evPadSmallerZero,
@@ -4437,7 +4464,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -4456,7 +4483,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReshape,
TosaArgGen.agReshape,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evTensorSizeInputOutputMismatch,
TosaErrorValidator.evWrongInputType,
@@ -4477,7 +4504,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agAxis,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evAxisSmallerZero,
TosaErrorValidator.evAxisLargerRank,
@@ -4500,7 +4527,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgSlice,
TosaArgGen.agSlice,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
# TODO Turn off these error categories for now as the reference
# model cannot allocate memory space for empty tensor. We probably
@@ -4532,7 +4559,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgTile,
TosaArgGen.agTile,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -4555,7 +4582,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agTranspose,
),
- "types": TYPE_FIB,
+ "types": TYPE_FIB + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evIndexOutsideBounds,
TosaErrorValidator.evIndexUsedTwice,
@@ -4581,7 +4608,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgLazyGenDefault,
TosaArgGen.agNone,
),
- "types": TYPE_FIB + [DType.INT48],
+ "types": TYPE_FIB + [DType.INT48, DType.FP8E4M3, DType.FP8E5M2],
"data_gen": {
"fp": (gtu.DataGenType.PSEUDO_RANDOM,),
},
@@ -4618,6 +4645,8 @@ class TosaTestGen:
DType.FP16,
DType.BF16,
DType.FP32,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -4640,7 +4669,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgScatter,
TosaArgGen.agNone,
),
- "types": TYPE_INT_FP,
+ "types": TYPE_INT_FP + [DType.FP8E4M3, DType.FP8E5M2],
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -4709,6 +4738,8 @@ class TosaTestGen:
DType.INT16,
DType.INT32,
DType.BOOL,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
@@ -5141,6 +5172,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
@@ -5194,6 +5227,8 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
excludes = [DType.FP16, DType.FP32]
+            elif ifm.dtype in [DType.FP8E4M3, DType.FP8E5M2]:
+ excludes = [DType.FP16]
else:
excludes = [out_dtype]
wrong_dtypes = list(gtu.usableDTypes(excludes=excludes))
@@ -5344,6 +5379,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -5383,6 +5420,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
elif a.dtype == DType.INT16:
incorrect_types = (
@@ -5393,6 +5432,20 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
+ )
+ elif a.dtype == DType.FP8E4M3 or a.dtype == DType.FP8E5M2:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
elif (
a.dtype == DType.FP32 or a.dtype == DType.FP16 or a.dtype == DType.BF16
@@ -5403,6 +5456,8 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
)
out_dtype = rng.choice(a=incorrect_types)
elif error_name == ErrorIf.WrongInputType:
@@ -5669,6 +5724,8 @@ class OutputShaper:
DType.FP32,
DType.FP16,
DType.BF16,
+ DType.FP8E4M3,
+ DType.FP8E5M2,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 76e7388..31a0ff0 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -27,6 +27,8 @@ DTYPE_ATTRIBUTES = {
DType.FP16: {"str": "f16", "width": 16, "json": "FP16"},
DType.BF16: {"str": "bf16", "width": 16, "json": "BF16"},
DType.FP32: {"str": "f32", "width": 32, "json": "FP32"},
+ DType.FP8E4M3: {"str": "f8e4m3", "width": 8, "json": "FP8E4M3"},
+ DType.FP8E5M2: {"str": "f8e5m2", "width": 8, "json": "FP8E5M2"},
}
@@ -186,6 +188,16 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT32,
DType.INT48,
)
+ elif input_dtype == DType.FP8E4M3 or input_dtype == DType.FP8E5M2:
+ incorrect_types = (
+ DType.INT4,
+ DType.INT8,
+ DType.INT16,
+ DType.INT32,
+ DType.INT48,
+ DType.FP32,
+ DType.BF16,
+ )
else:
# Assume all types but the input type are incorrect
incorrect_types = list(usableDTypes(excludes=(input_dtype,)))
@@ -209,6 +221,12 @@ def float32_is_valid_bfloat16(f):
return f32_bits[16:] == "0" * 16
+def float32_is_valid_float8(f):
+ """Return True if float value is valid float8."""
+ f32_bits = get_float32_bitstring(f)
+ return f32_bits[8:] == "0" * 24
+
+
def get_float32_bitstring(f):
"""Return a big-endian string of bits representing a 32 bit float."""
f32_bits_as_int = struct.unpack(">L", struct.pack(">f", f))[0]
@@ -232,6 +250,30 @@ def float32_to_bfloat16(f):
return struct.unpack("@f", fp_bytes)[0] # native byteorder
+def float32_to_fp8e4m3(f):
+ """Turns fp32 value into fp8e4m3"""
+ f32_bits = get_float32_bitstring(f)
+ fp8_bits = f32_bits[0] + f32_bits[1:5] + f32_bits[9:12] + "0" * 24
+ fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0] # native byteorder
+
+
+def float32_to_fp8e5m2(f):
+ """Turns fp32 value into fp8e5m2"""
+ f32_bits = get_float32_bitstring(f)
+ fp8_bits = f32_bits[0] + f32_bits[1:6] + f32_bits[9:11] + "0" * 24
+ fp_bytes = int(fp8_bits, 2).to_bytes(4, byteorder=sys.byteorder)
+ return struct.unpack("@f", fp_bytes)[0]
+
+
vect_f32_to_bf16 = np.vectorize(
float32_to_bfloat16, otypes=(np.float32,)
) # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e4m3 = np.vectorize(
+ float32_to_fp8e4m3, otypes=(np.float32,)
+) # NumPy vectorize: applies function to vector faster than looping
+
+vect_f32_to_fp8e5m2 = np.vectorize(
+ float32_to_fp8e5m2, otypes=(np.float32,)
+) # NumPy vectorize: applies function to vector faster than looping
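A small usage sketch of the new tosa_utils helpers (illustrative only; the import path matches how the checker imports them above):

import numpy as np
from generator.tosa_utils import (
    float32_is_valid_float8,
    float32_to_fp8e4m3,
    vect_f32_to_fp8e4m3,
)

# float32_to_fp8e4m3() keeps sign + 4 exponent + 3 mantissa bits in the top
# byte of the float32 pattern and zeroes the remaining 24 bits, so its output
# always satisfies the float8 validity check used by tosa_result_checker.py.
y = float32_to_fp8e4m3(1.0)
assert float32_is_valid_float8(y)

# The vectorized variant applies the same packing element-wise.
arr = np.float32(np.random.uniform(-1.0, 1.0, size=4))
assert all(float32_is_valid_float8(v) for v in vect_f32_to_fp8e4m3(arr))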