aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2024-04-05 01:19:31 +0000
committerTai Ly <tai.ly@arm.com>2024-04-15 14:28:29 +0000
commit5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd (patch)
treed9dddba756207cee68b948d434502801be93d6c4
parent6dc755bf141726a7582ad1a844f97cb3f50c9b21 (diff)
downloadreference_model-5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd.tar.gz
[ref model] fix const/pad/clamp attribute serialization
This changes to use native type serialization and deserialization for pad_const, clamp min_val/max_val and const data attribute values whereby fp16 values are stored as 2 bytes each, fp8 values are stored in 1 byte each, etc. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: Ia95d320fe8c546ce1d1ccc035d6e9bcaadcc9ca3
-rw-r--r--reference_model/include/dtype.h45
-rw-r--r--reference_model/src/float_utils.h533
-rw-r--r--reference_model/src/ops/activation_funcs.cc96
-rw-r--r--reference_model/src/ops/activation_funcs.h1
-rw-r--r--reference_model/src/ops/data_layout.cc67
-rw-r--r--reference_model/src/ops/ewise_unary.cc7
-rw-r--r--reference_model/src/ops/type_conversion.cc10
-rw-r--r--reference_model/src/subgraph_traverser.cc20
m---------thirdparty/serialization_lib0
-rw-r--r--verif/generator/tosa_test_gen.py35
10 files changed, 188 insertions, 626 deletions
diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h
index a283f39..3463af9 100644
--- a/reference_model/include/dtype.h
+++ b/reference_model/include/dtype.h
@@ -89,26 +89,8 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e)
}
// return corresponding TOSA_REF_TYPE for DType
-inline TOSA_REF_TYPE ConvertDType(const DType dtype)
+inline TOSA_REF_TYPE DType2RefType(const DType dtype)
{
- assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes
-
- if (g_func_config.precise_mode)
- {
- // in precise mode, convert all floating DType to TOSA_REF_TYPE_FP64
- switch (dtype)
- {
- case DType_FP16:
- case DType_FP32:
- case DType_BF16:
- case DType_FP8E4M3:
- case DType_FP8E5M2:
- return TOSA_REF_TYPE_FP64;
- default:
- break;
- }
- }
-
switch (dtype)
{
case DType_BOOL:
@@ -145,6 +127,31 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype)
return TOSA_REF_TYPE_UNKNOWN;
}
+// return corresponding TOSA_REF_TYPE for DType
+// if precise_mode, convert all floating dtype to FP64
+inline TOSA_REF_TYPE ConvertDType(const DType dtype)
+{
+ assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes
+
+ if (g_func_config.precise_mode)
+ {
+ // in precise mode, convert all floating DType to TOSA_REF_TYPE_FP64
+ switch (dtype)
+ {
+ case DType_FP16:
+ case DType_FP32:
+ case DType_BF16:
+ case DType_FP8E4M3:
+ case DType_FP8E5M2:
+ return TOSA_REF_TYPE_FP64;
+ default:
+ break;
+ }
+ }
+
+ return DType2RefType(dtype);
+}
+
template <TOSA_REF_TYPE Dtype>
bool IsSignedInt()
{
diff --git a/reference_model/src/float_utils.h b/reference_model/src/float_utils.h
deleted file mode 100644
index b98c89b..0000000
--- a/reference_model/src/float_utils.h
+++ /dev/null
@@ -1,533 +0,0 @@
-// Copyright (c) 2024, ARM Limited.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifndef FLOAT_UTILS_H_
-#define FLOAT_UTILS_H_
-
-#include <algorithm>
-#include <cstdint>
-#include <limits>
-#include <type_traits>
-#if defined(__cpp_lib_bit_cast)
-#include <bit>
-#endif // defined(__cpp_lib_bit_cast)
-
-namespace tosa::reference::internal
-{
-
-namespace float_support
-{
-
-struct hidden
-{};
-
-#if defined(__cpp_lib_bit_cast)
-#define BITCAST_CONSTEXPR constexpr inline
-
-constexpr inline int32_t get_bits(const float& f)
-{
- return std::bit_cast<int32_t>(f);
-}
-constexpr inline float from_bits(const int32_t& i)
-{
- return std::bit_cast<float>(i);
-}
-
-#else
-#define BITCAST_CONSTEXPR inline
-
-union ufloat32
-{
- constexpr ufloat32(const float& x)
- : f(x)
- {}
- constexpr ufloat32(const int32_t& x)
- : i(x)
- {}
-
- float f;
- int32_t i;
-};
-
-inline int32_t get_bits(const float& f)
-{
- return ufloat32(f).i;
-}
-inline float from_bits(const int32_t& i)
-{
- return ufloat32(i).f;
-}
-#endif
-
-} // namespace float_support
-
-template <typename storage_t,
- size_t n_exp_bits,
- bool has_nan,
- bool with_denorm,
- bool with_infinity,
- std::enable_if_t<(n_exp_bits + 1 < sizeof(storage_t) * 8), bool> = true>
-class float_t
-{
- storage_t m_data = 0;
-
-public:
- static constexpr size_t n_exponent_bits = n_exp_bits;
- static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits;
- static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1;
-
- /// \brief Construct a floating point type with the given bit
- /// representation.
- static constexpr float_t from_bits(storage_t bits)
- {
- return float_t(float_support::hidden(), bits);
- }
-
- /// \brief Construct a float from the given sign, exponent and significand
- /// bits.
- static constexpr float_t from_bits(bool pm, storage_t e, storage_t s)
- {
- storage_t bits = pm ? 1 : 0;
-
- bits <<= n_exp_bits;
- bits |= e;
-
- bits <<= n_significand_bits;
- if (with_denorm || e)
- bits |= s;
-
- return float_t(float_support::hidden(), bits);
- }
-
- /// \brief (Hidden) Construct a float type from a given bit pattern
- constexpr float_t(const float_support::hidden&, storage_t bits)
- : m_data(bits)
- {}
-
- constexpr float_t()
- : m_data(0)
- {}
- constexpr float_t(const float_t& other)
- : m_data(other.m_data)
- {}
-
- /// \brief Cast to a different floating point representation.
- template <typename other_storage_t,
- size_t other_n_exp_bits,
- bool other_has_nan,
- bool other_has_denorm,
- bool other_has_infinity>
- constexpr inline
- operator float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>() const
- {
- using other_float_t =
- float_t<other_storage_t, other_n_exp_bits, other_has_nan, other_has_denorm, other_has_infinity>;
-
- // Shortcut for types which are fundamentally similar (e.g., bf16 ->
- // fp32)
- if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) &&
- has_nan == other_has_nan)
- {
- return other_float_t::from_bits(static_cast<other_storage_t>(m_data)
- << (sizeof(other_storage_t) - sizeof(storage_t)) * 8);
- }
-
- // Get initial values for the new floating point type
- const bool sign_bit = m_data < 0;
- int64_t new_exponent_bits = 0;
- uint64_t new_significand = 0;
-
- if (is_nan() || is_infinity())
- {
- new_exponent_bits = (1 << other_n_exp_bits) - 1;
-
- if (is_nan())
- {
- if constexpr (other_has_infinity)
- {
- // Copy across the `not_quiet bit`; set the LSB. Don't
- // attempt to copy across any of the rest of the payload.
- new_significand =
- 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits);
- }
- else
- {
- new_significand = (1ul << other_float_t::n_significand_bits) - 1;
- }
- }
- else if constexpr (!other_has_infinity)
- {
- new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
- }
- }
- else if (!is_zero())
- {
- const int64_t this_exponent_bits = exponent_bits();
- {
- constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias;
- new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1);
- }
- new_significand = this->significand() << (64 - n_significand_bits);
-
- // Normalise subnormals
- if (this_exponent_bits == 0)
- {
- // Shift the most-significant 1 out of the magnitude to convert
- // it to a significand. Modify the exponent accordingly.
- uint8_t shift = __builtin_clzl(new_significand) + 1;
- new_exponent_bits -= shift;
- new_significand <<= shift;
- }
-
- // Align the significand for the output type
- uint32_t shift = 64 - other_float_t::n_significand_bits;
- const bool other_is_subnormal = new_exponent_bits <= 0;
- if (other_is_subnormal)
- {
- shift += 1 - new_exponent_bits;
- new_exponent_bits = 0;
- }
-
- const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1);
- new_significand = shift == 64 ? 0 : new_significand >> shift;
-
- // Reinsert the most-significant-one if this is a subnormal in the
- // output type.
- new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift);
-
- // Apply rounding based on the bits shifted out of the significand
- const uint64_t shift_half = 1ll << (shift - 1);
- if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1)))
- {
- new_significand += 1;
-
- // Handle the case that the significand overflowed due to
- // rounding
- constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1;
- if (new_significand > max_significand)
- {
- new_significand = 0;
- new_exponent_bits++;
- }
- }
-
- // Saturate to infinity if the exponent is larger than can be
- // represented in the output type. This can only occur if the size
- // of the exponent of the new type is not greater than the exponent
- // of the old type.
- if constexpr (other_n_exp_bits <= n_exp_bits)
- {
- constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1;
- if (new_exponent_bits >= inf_exp_bits)
- {
- new_exponent_bits = inf_exp_bits;
- new_significand =
- other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1);
- }
- }
- }
-
- return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand);
- }
-
- /// \brief Convert from a 32-bit floating point value
- BITCAST_CONSTEXPR
- float_t(const float& f)
- {
- // If this format exactly represents the binary32 format then get
- // the bits from the provided float; otherwise get a binary32
- // representation and then convert to this format.
- if constexpr (represents_binary32())
- m_data = float_support::get_bits(f);
- else
- m_data = static_cast<float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_infinity>>(
- static_cast<float_t<int32_t, 8, true, true, true>>(f))
- .m_data;
- }
-
- /// \brief Cast to a 32-bit floating point value
- BITCAST_CONSTEXPR operator float() const
- {
- // If this format exactly represents the binary32 format then return
- // a float; otherwise get a binary32 representation and then return
- // a float.
- if constexpr (represents_binary32())
- return float_support::from_bits(m_data);
- else
- return static_cast<float>(this->operator float_t<int32_t, 8, true, true, true>());
- }
-
- /// \brief Return whether this type represents the IEEE754 binary32
- /// format
- constexpr static inline bool represents_binary32()
- {
- return std::is_same_v<storage_t, int32_t> && n_exp_bits == 8 && has_nan && with_denorm && with_infinity;
- }
-
- constexpr auto operator-() const
- {
- return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1)));
- }
-
- constexpr bool is_subnormal() const
- {
- return exponent_bits() == 0 && significand() != 0;
- }
-
- constexpr bool is_zero() const
- {
- return exponent_bits() == 0 && significand() == 0;
- }
-
- constexpr bool is_nan() const
- {
- return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) &&
- ((with_infinity && significand()) ||
- (!with_infinity && significand() == (1ul << n_significand_bits) - 1));
- }
-
- constexpr bool is_infinity() const
- {
- return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand());
- }
-
- constexpr inline const storage_t& bits() const
- {
- return m_data;
- }
-
- /// \brief Get the exponent
- constexpr inline int64_t exponent() const
- {
- return std::max<int64_t>(exponent_bits(), 1ul) - exponent_bias;
- }
-
- /// \brief Get the bits from the exponent
- constexpr inline uint64_t exponent_bits() const
- {
- constexpr uint64_t mask = (1ul << n_exp_bits) - 1;
- return (m_data >> n_significand_bits) & mask;
- }
-
- constexpr inline uint64_t significand() const
- {
- return m_data & ((1ul << n_significand_bits) - 1);
- }
-
- constexpr inline bool operator==(const float_t& other) const
- {
- return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits());
- }
-
- constexpr inline float_t& operator+=(const float_t& rhs)
- {
- this->m_data = static_cast<float_t>(static_cast<float>(*this) + static_cast<float>(rhs)).bits();
- return *this;
- }
-};
-
-// This should probably be exported so we can use it elsewhere
-#undef BITCAST_CONSTEXPR
-
-namespace float_support
-{
-
-// Pre-C++23 these can't be computed as constexpr, so have to hardcode them
-
-template <int>
-struct digits10; // floor(log10(2) * (digits - 1)
-template <int>
-struct max_digits10; // ceil(log10(2) * digits + 1)
-template <int>
-struct min_exponent10; // floor(log10(2) * min_exponent)
-template <int>
-struct max_exponent10; // floor(log10(2) * max_exponent)
-
-template <>
-struct digits10<8>
-{
- constexpr static inline int value = 2;
-};
-
-template <>
-struct max_digits10<8>
-{
- constexpr static inline int value = 4;
-};
-
-template <>
-struct digits10<10>
-{
- constexpr static inline int value = 2;
-};
-
-template <>
-struct max_digits10<10>
-{
- constexpr static inline int value = 5;
-};
-
-template <>
-struct digits10<24>
-{
- constexpr static inline int value = 6;
-};
-
-template <>
-struct max_digits10<24>
-{
- constexpr static inline int value = 9;
-};
-
-template <>
-struct min_exponent10<-13>
-{
- constexpr static inline int value = -3;
-};
-
-template <>
-struct max_exponent10<16>
-{
- constexpr static inline int value = 4;
-};
-
-template <>
-struct min_exponent10<-125>
-{
- constexpr static inline int value = -37;
-};
-
-template <>
-struct max_exponent10<128>
-{
- constexpr static inline int value = 38;
-};
-
-template <int d>
-inline constexpr int digits10_v = digits10<d>::value;
-template <int d>
-inline constexpr int max_digits10_v = max_digits10<d>::value;
-
-template <int e>
-inline constexpr int min_exponent10_v = min_exponent10<e>::value;
-
-template <int e>
-inline constexpr int max_exponent10_v = max_exponent10<e>::value;
-
-} // namespace float_support
-
-} // namespace tosa::reference::internal
-
-namespace std
-{
-
-template <typename storage_t, size_t n_exp_bits, bool has_nan, bool has_denorm, bool has_inf>
-struct is_floating_point<tosa::reference::internal::float_t<storage_t, n_exp_bits, has_nan, has_denorm, has_inf>>
- : std::integral_constant<bool, true>
-{};
-
-template <typename storage_t, size_t n_exp_bits, bool has_nan, bool with_denorm, bool with_inf>
-class numeric_limits<tosa::reference::internal::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>>
-{
- using this_float_t = tosa::reference::internal::float_t<storage_t, n_exp_bits, has_nan, with_denorm, with_inf>;
-
-public:
- static constexpr bool is_specialized = true;
-
- static constexpr auto min() noexcept
- {
- return this_float_t::from_bits(false, 1, 0);
- }
-
- static constexpr auto max() noexcept
- {
- return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2,
- (1 << this_float_t::n_significand_bits) - 1);
- }
-
- static constexpr auto lowest() noexcept
- {
- return -max();
- }
-
- static constexpr int digits = this_float_t::n_significand_bits + 1;
- static constexpr int digits10 = tosa::reference::internal::float_support::digits10_v<digits>;
- static constexpr int max_digits10 = tosa::reference::internal::float_support::max_digits10_v<digits>;
-
- static constexpr bool is_signed = true;
- static constexpr bool is_integer = false;
- static constexpr bool is_exact = false;
- static constexpr int radix = 2;
-
- static constexpr auto epsilon() noexcept
- {
- return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0);
- }
-
- static constexpr auto round_error() noexcept
- {
- return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0);
- }
-
- static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1;
- static constexpr int min_exponent10 = tosa::reference::internal::float_support::min_exponent10_v<min_exponent>;
- static constexpr int max_exponent = this_float_t::exponent_bias + 1;
- static constexpr int max_exponent10 = tosa::reference::internal::float_support::max_exponent10_v<max_exponent>;
-
- static constexpr bool has_infinity = with_inf;
- static constexpr bool has_quiet_NaN = has_nan;
- static constexpr bool has_signaling_NaN = true;
- static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent;
- static constexpr bool has_denorm_loss = false;
-
- static constexpr auto infinity() noexcept
- {
- if constexpr (with_inf)
- {
- return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0);
- }
- else
- {
- return this_float_t::from_bits(false, 0, 0);
- }
- }
-
- static constexpr auto quiet_NaN() noexcept
- {
- return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1,
- 1 << (this_float_t::n_significand_bits - 1) | 1);
- }
-
- static constexpr auto signaling_NaN() noexcept
- {
- return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1);
- }
-
- static constexpr auto denorm_min() noexcept
- {
- return this_float_t::from_bits(false, 0, 1);
- }
-
- static constexpr bool is_iec559 = false;
- static constexpr bool is_bounded = false;
- static constexpr bool is_modulo = false;
-
- static constexpr bool traps = false;
- static constexpr bool tinyness_before = false;
- static constexpr float_round_style round_style = round_to_nearest;
-};
-
-} // namespace std
-
-#endif // _FLOAT_UTILS_H_
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index de7d8be..fc2a9ac 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -31,53 +31,79 @@ int OpClamp<Rank, Dtype>::register_fcn()
auto tosa_level = g_func_config.tosa_level;
LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK");
- switch (Dtype)
+ ASSERT_MSG(!(static_cast<GraphNode*>(this))->getOutputs().empty(),
+ "Must call register_fcn after tensors are linked to nodes");
+
+ InEigenType min, max;
+
+ // need to use input tensor's serializationDtype to deserialize min/max values
+ // because Dtype may be FP64 in precise_mode
+ auto serializationDtype = (static_cast<GraphNode*>(this))->getInputs()[0]->getSerializationDtype();
+ switch (DType2RefType(serializationDtype))
{
- case TOSA_REF_TYPE_FP16:
- case TOSA_REF_TYPE_BF16:
- case TOSA_REF_TYPE_FP32: {
+ case TOSA_REF_TYPE_FP16: {
+ std::vector<half_float::half> min_float_data, max_float_data;
+ TosaSerializationHandler::ConvertU8toF16(attribute->min_val(), /* size = */ 1, min_float_data);
+ TosaSerializationHandler::ConvertU8toF16(attribute->max_val(), /* size = */ 1, max_float_data);
+ min = (InEigenType)min_float_data[0];
+ max = (InEigenType)max_float_data[0];
+ }
+ break;
+ case TOSA_REF_TYPE_BF16: {
std::vector<float> min_float_data, max_float_data;
- TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data);
- TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data);
- InEigenType min = (InEigenType)min_float_data[0];
- InEigenType max = (InEigenType)max_float_data[0];
- ERROR_IF(max < min, "OpClamp: max smaller than min");
-
- this->fcn = [min, max](InEigenType a) -> OutEigenType {
- return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a);
- };
+ TosaSerializationHandler::ConvertU8toBF16(attribute->min_val(), /* size = */ 1, min_float_data);
+ TosaSerializationHandler::ConvertU8toBF16(attribute->max_val(), /* size = */ 1, max_float_data);
+ min = (InEigenType)min_float_data[0];
+ max = (InEigenType)max_float_data[0];
}
break;
- case TOSA_REF_TYPE_FP64: {
+ case TOSA_REF_TYPE_FP32: {
std::vector<float> min_float_data, max_float_data;
TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data);
TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data);
- InEigenType min = (InEigenType)min_float_data[0];
- InEigenType max = (InEigenType)max_float_data[0];
- ERROR_IF(max < min, "OpClamp: max smaller than min");
-
- this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); };
+ min = (InEigenType)min_float_data[0];
+ max = (InEigenType)max_float_data[0];
}
break;
case TOSA_REF_TYPE_INT8: {
- std::vector<int32_t> min_int_data, max_int_data;
- TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data);
- TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data);
- int8_t min = (int8_t)min_int_data[0];
- int8_t max = (int8_t)max_int_data[0];
-
- ERROR_IF(max < min, "OpClamp: max smaller than min");
- this->fcn = [min, max](int8_t a) -> int8_t { return a <= min ? min : a >= max ? max : a; };
+ std::vector<int8_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI8(attribute->min_val(), /* size = */ 1, min_int_data);
+ TosaSerializationHandler::ConvertU8toI8(attribute->max_val(), /* size = */ 1, max_int_data);
+ min = (InEigenType)min_int_data[0];
+ max = (InEigenType)max_int_data[0];
+ }
+ break;
+ case TOSA_REF_TYPE_INT16: {
+ std::vector<int16_t> min_int_data, max_int_data;
+ TosaSerializationHandler::ConvertU8toI16(attribute->min_val(), /* size = */ 1, min_int_data);
+ TosaSerializationHandler::ConvertU8toI16(attribute->max_val(), /* size = */ 1, max_int_data);
+ min = (InEigenType)min_int_data[0];
+ max = (InEigenType)max_int_data[0];
}
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ ERROR_IF(max < min, "OpClamp: max smaller than min");
+
+ // evaluation function is still based on Dtype
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32: {
+ // apply fpTrunc<Dtype> after min/max
+ this->fcn = [min, max](InEigenType a) -> OutEigenType {
+ return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a);
+ };
+ }
+ break;
+ case TOSA_REF_TYPE_FP64:
+ case TOSA_REF_TYPE_INT8:
case TOSA_REF_TYPE_INT16: {
- std::vector<int32_t> min_int_data, max_int_data;
- TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data);
- TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data);
- int16_t min = (int16_t)min_int_data[0];
- int16_t max = (int16_t)max_int_data[0];
-
- ERROR_IF(max < min, "OpClamp: max smaller than min");
- this->fcn = [min, max](int16_t a) -> int16_t { return a <= min ? min : a >= max ? max : a; };
+ // simply min/max
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); };
}
break;
default:
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
index 1696668..055642a 100644
--- a/reference_model/src/ops/activation_funcs.h
+++ b/reference_model/src/ops/activation_funcs.h
@@ -32,7 +32,6 @@ public:
: UnaryNode<Rank, Dtype>(sgt_, Op_CLAMP, id_)
{
INIT_ATTRIBUTE(Clamp);
- register_fcn();
}
virtual ~OpClamp();
static constexpr int32_t QMin = GetQMin<Dtype>::value;
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index e264284..6664ec3 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -171,11 +171,31 @@ int OpPad<Rank, Dtype>::eval()
{
InEigenType pad_value = 0;
- switch (Dtype)
- {
- case TOSA_REF_TYPE_BOOL:
- case TOSA_REF_TYPE_INT8:
- case TOSA_REF_TYPE_INT16:
+ // need to use input tensor's serializationDtype to deserialize pad_const
+ // because Dtype may be FP64 in precise_mode
+ switch (DType2RefType(inputs[0]->getSerializationDtype()))
+ {
+ case TOSA_REF_TYPE_BOOL: {
+ std::vector<bool> bool_data;
+ TosaSerializationHandler::ConvertU8toBool(attribute->pad_const(),
+ /* size = */ 1, bool_data);
+ pad_value = (InEigenType)bool_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_INT8: {
+ std::vector<int8_t> int8_data;
+ TosaSerializationHandler::ConvertU8toI8(attribute->pad_const(),
+ /* size = */ 1, int8_data);
+ pad_value = (InEigenType)int8_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_INT16: {
+ std::vector<int16_t> int16_data;
+ TosaSerializationHandler::ConvertU8toI16(attribute->pad_const(),
+ /* size = */ 1, int16_data);
+ pad_value = (InEigenType)int16_data[0];
+ break;
+ }
case TOSA_REF_TYPE_INT32: {
std::vector<int32_t> int32_data;
TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(),
@@ -183,15 +203,38 @@ int OpPad<Rank, Dtype>::eval()
pad_value = (InEigenType)int32_data[0];
break;
}
- case TOSA_REF_TYPE_FP16:
- case TOSA_REF_TYPE_BF16:
- case TOSA_REF_TYPE_FP32:
- case TOSA_REF_TYPE_FP64:
- case TOSA_REF_TYPE_FP8E4M3:
+ case TOSA_REF_TYPE_FP16: {
+ std::vector<half_float::half> f16_data;
+ TosaSerializationHandler::ConvertU8toF16(attribute->pad_const(),
+ /* size = */ 1, f16_data);
+ pad_value = (InEigenType)f16_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_BF16: {
+ std::vector<float> f32_data;
+ TosaSerializationHandler::ConvertU8toBF16(attribute->pad_const(),
+ /* size = */ 1, f32_data);
+ pad_value = (InEigenType)f32_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_FP32: {
+ std::vector<float> f32_data;
+ TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
+ /* size = */ 1, f32_data);
+ pad_value = (InEigenType)f32_data[0];
+ break;
+ }
+ case TOSA_REF_TYPE_FP8E4M3: {
+ std::vector<float> f32_data;
+ TosaSerializationHandler::ConvertU8toFP8E4M3(attribute->pad_const(),
+ /* size = */ 1, f32_data);
+ pad_value = (InEigenType)f32_data[0];
+ break;
+ }
case TOSA_REF_TYPE_FP8E5M2: {
std::vector<float> float_data;
- TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(),
- /* size = */ 1, float_data);
+ TosaSerializationHandler::ConvertU8toFP8E5M2(attribute->pad_const(),
+ /* size = */ 1, float_data);
pad_value = (InEigenType)float_data[0];
break;
}
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index dd9ea5a..310a174 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -66,6 +66,13 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes()
template <int Rank, TOSA_REF_TYPE Dtype>
int UnaryNode<Rank, Dtype>::eval()
{
+ // call register_fcn() here to ensure inputs/outputs have been connected
+ // to the node by the time register_fcn() is called for Clamp Operator
+ if (register_fcn())
+ {
+ return 1;
+ }
+
this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);
return GraphNode::eval();
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 7bca697..40e6c64 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -25,11 +25,11 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-using fp16 = tosa::reference::internal::float_t<int16_t, 5, true, true, true>;
-using bf16 = tosa::reference::internal::float_t<int16_t, 8, true, true, true>;
-using fp32 = tosa::reference::internal::float_t<int32_t, 8, true, true, true>;
-using fp8e4m3 = tosa::reference::internal::float_t<int8_t, 4, true, true, false>;
-using fp8e5m2 = tosa::reference::internal::float_t<int8_t, 5, true, true, true>;
+using fp16 = tosa::float_t<int16_t, 5, true, true, true>;
+using bf16 = tosa::float_t<int16_t, 8, true, true, true>;
+using fp32 = tosa::float_t<int32_t, 8, true, true, true>;
+using fp8e4m3 = tosa::float_t<int8_t, 4, true, true, false>;
+using fp8e5m2 = tosa::float_t<int8_t, 5, true, true, true>;
template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index 52b1806..6aa0a45 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -580,7 +580,7 @@ int SubgraphTraverser::allocateTensor(std::string name)
break;
case DType_BF16: {
std::vector<float> fp32_data;
- TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
+ TosaSerializationHandler::ConvertU8toBF16(ts->GetData(), tensor->getElementCount(), fp32_data);
// Ensure valid bfloat16 stored in each float
for (auto f : fp32_data)
ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f);
@@ -595,11 +595,23 @@ int SubgraphTraverser::allocateTensor(std::string name)
}
}
break;
- case DType_FP8E4M3:
+ case DType_FP8E4M3: {
+ std::vector<float> fp32_data;
+ TosaSerializationHandler::ConvertU8toFP8E4M3(ts->GetData(), tensor->getElementCount(), fp32_data);
+ if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
+ {
+ std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
+ tensor->setTensorValueDouble(f64_data.size(), f64_data.data());
+ }
+ else
+ {
+ tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data());
+ }
+ }
+ break;
case DType_FP8E5M2: {
std::vector<float> fp32_data;
- TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
- // Ensure valid fp8 stored in each float
+ TosaSerializationHandler::ConvertU8toFP8E5M2(ts->GetData(), tensor->getElementCount(), fp32_data);
if (tensor->getDtype() == TOSA_REF_TYPE_FP64)
{
std::vector<double> f64_data(fp32_data.begin(), fp32_data.end());
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 8f9e2842ce7d25645233ad4f6fa406be982346a
+Subproject 57d781883142db8a45fe98ac1a1dfacc49cba78
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index c5ac0f9..38ab3f4 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -3,7 +3,6 @@
import json
import logging
import os
-import struct
from copy import deepcopy
from datetime import datetime
from pathlib import Path
@@ -1390,20 +1389,14 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- if a.dtype in (DType.BF16, DType.FP16, DType.FP32):
- if a.dtype == DType.FP16:
- # Non-tensor fp16 ops take fp16 values as fp32 in reference_model
- min_val = min_val.astype(np.float32)
- max_val = max_val.astype(np.float32)
- min_val_as_bytes = struct.pack("<f", min_val)
- max_val_as_bytes = struct.pack("<f", max_val)
- elif a.dtype in (DType.INT8, DType.INT16):
- min_val_as_bytes = struct.pack("<i", min_val)
- max_val_as_bytes = struct.pack("<i", max_val)
- else:
- # to avoid internal error for incorrect input types
- min_val_as_bytes = struct.pack("<i", 0)
- max_val_as_bytes = struct.pack("<i", 0)
+ min_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(a.dtype, [min_val])
+ max_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(a.dtype, [max_val])
+
+ # align to 8 bytes
+ while (len(min_val_as_bytes) % 8) != 0:
+ min_val_as_bytes.append(0)
+ while (len(max_val_as_bytes) % 8) != 0:
+ max_val_as_bytes.append(0)
attr.ClampAttribute(self.ser.builder, min_val_as_bytes, max_val_as_bytes)
@@ -1550,9 +1543,17 @@ class TosaTestGen:
# get pad_const_val_as_bytes from either pad_const_float or pad_const_int
if gtu.dtypeIsFloat(a.dtype):
- pad_const_val_as_bytes = struct.pack("<f", pad_const_float)
+ pad_const_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(
+ a.dtype, [pad_const_float]
+ )
else:
- pad_const_val_as_bytes = struct.pack("<i", pad_const_int)
+ pad_const_val_as_bytes = ts.TosaSerializer.convertDataToUint8Vec(
+ a.dtype, [pad_const_int]
+ )
+
+ # align to 8 bytes
+ while (len(pad_const_val_as_bytes) % 8) != 0:
+ pad_const_val_as_bytes.append(0)
attr = ts.TosaSerializerAttribute()
attr.PadAttribute(self.ser.builder, pad_const_val_as_bytes)