From 5d0e9c7f3748e80d6f14a3eeaef858eeb912e1fd Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Fri, 5 Apr 2024 01:19:31 +0000 Subject: [ref model] fix const/pad/clamp attribute serialization This changes to use native type serialization and deserialization for pad_const, clamp min_val/max_val and const data attribute values whereby fp16 values are stored as 2 bytes each, fp8 values are stored in 1 byte each, etc. Signed-off-by: Tai Ly Change-Id: Ia95d320fe8c546ce1d1ccc035d6e9bcaadcc9ca3 --- reference_model/include/dtype.h | 45 ++- reference_model/src/float_utils.h | 533 ---------------------------- reference_model/src/ops/activation_funcs.cc | 96 +++-- reference_model/src/ops/activation_funcs.h | 1 - reference_model/src/ops/data_layout.cc | 67 +++- reference_model/src/ops/ewise_unary.cc | 7 + reference_model/src/ops/type_conversion.cc | 10 +- reference_model/src/subgraph_traverser.cc | 20 +- thirdparty/serialization_lib | 2 +- verif/generator/tosa_test_gen.py | 35 +- 10 files changed, 189 insertions(+), 627 deletions(-) delete mode 100644 reference_model/src/float_utils.h diff --git a/reference_model/include/dtype.h b/reference_model/include/dtype.h index a283f39..3463af9 100644 --- a/reference_model/include/dtype.h +++ b/reference_model/include/dtype.h @@ -89,26 +89,8 @@ inline const char* EnumNameTOSAREFTYPE(TOSA_REF_TYPE e) } // return corresponding TOSA_REF_TYPE for DType -inline TOSA_REF_TYPE ConvertDType(const DType dtype) +inline TOSA_REF_TYPE DType2RefType(const DType dtype) { - assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes - - if (g_func_config.precise_mode) - { - // in precise mode, convert all floating DType to TOSA_REF_TYPE_FP64 - switch (dtype) - { - case DType_FP16: - case DType_FP32: - case DType_BF16: - case DType_FP8E4M3: - case DType_FP8E5M2: - return TOSA_REF_TYPE_FP64; - default: - break; - } - } - switch (dtype) { case DType_BOOL: @@ -145,6 +127,31 @@ inline TOSA_REF_TYPE ConvertDType(const DType dtype) return TOSA_REF_TYPE_UNKNOWN; } +// return corresponding TOSA_REF_TYPE for DType +// if precise_mode, convert all floating dtype to FP64 +inline TOSA_REF_TYPE ConvertDType(const DType dtype) +{ + assert(DType_MAX == DType_FP8E5M2); // must update whenever DType_MAX changes + + if (g_func_config.precise_mode) + { + // in precise mode, convert all floating DType to TOSA_REF_TYPE_FP64 + switch (dtype) + { + case DType_FP16: + case DType_FP32: + case DType_BF16: + case DType_FP8E4M3: + case DType_FP8E5M2: + return TOSA_REF_TYPE_FP64; + default: + break; + } + } + + return DType2RefType(dtype); +} + template bool IsSignedInt() { diff --git a/reference_model/src/float_utils.h b/reference_model/src/float_utils.h deleted file mode 100644 index b98c89b..0000000 --- a/reference_model/src/float_utils.h +++ /dev/null @@ -1,533 +0,0 @@ -// Copyright (c) 2024, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef FLOAT_UTILS_H_ -#define FLOAT_UTILS_H_ - -#include -#include -#include -#include -#if defined(__cpp_lib_bit_cast) -#include -#endif // defined(__cpp_lib_bit_cast) - -namespace tosa::reference::internal -{ - -namespace float_support -{ - -struct hidden -{}; - -#if defined(__cpp_lib_bit_cast) -#define BITCAST_CONSTEXPR constexpr inline - -constexpr inline int32_t get_bits(const float& f) -{ - return std::bit_cast(f); -} -constexpr inline float from_bits(const int32_t& i) -{ - return std::bit_cast(i); -} - -#else -#define BITCAST_CONSTEXPR inline - -union ufloat32 -{ - constexpr ufloat32(const float& x) - : f(x) - {} - constexpr ufloat32(const int32_t& x) - : i(x) - {} - - float f; - int32_t i; -}; - -inline int32_t get_bits(const float& f) -{ - return ufloat32(f).i; -} -inline float from_bits(const int32_t& i) -{ - return ufloat32(i).f; -} -#endif - -} // namespace float_support - -template = true> -class float_t -{ - storage_t m_data = 0; - -public: - static constexpr size_t n_exponent_bits = n_exp_bits; - static constexpr size_t n_significand_bits = sizeof(storage_t) * 8 - 1 - n_exp_bits; - static constexpr int64_t exponent_bias = (1 << (n_exp_bits - 1)) - 1; - - /// \brief Construct a floating point type with the given bit - /// representation. - static constexpr float_t from_bits(storage_t bits) - { - return float_t(float_support::hidden(), bits); - } - - /// \brief Construct a float from the given sign, exponent and significand - /// bits. - static constexpr float_t from_bits(bool pm, storage_t e, storage_t s) - { - storage_t bits = pm ? 1 : 0; - - bits <<= n_exp_bits; - bits |= e; - - bits <<= n_significand_bits; - if (with_denorm || e) - bits |= s; - - return float_t(float_support::hidden(), bits); - } - - /// \brief (Hidden) Construct a float type from a given bit pattern - constexpr float_t(const float_support::hidden&, storage_t bits) - : m_data(bits) - {} - - constexpr float_t() - : m_data(0) - {} - constexpr float_t(const float_t& other) - : m_data(other.m_data) - {} - - /// \brief Cast to a different floating point representation. - template - constexpr inline - operator float_t() const - { - using other_float_t = - float_t; - - // Shortcut for types which are fundamentally similar (e.g., bf16 -> - // fp32) - if constexpr (n_exp_bits == other_n_exp_bits && sizeof(other_storage_t) >= sizeof(storage_t) && - has_nan == other_has_nan) - { - return other_float_t::from_bits(static_cast(m_data) - << (sizeof(other_storage_t) - sizeof(storage_t)) * 8); - } - - // Get initial values for the new floating point type - const bool sign_bit = m_data < 0; - int64_t new_exponent_bits = 0; - uint64_t new_significand = 0; - - if (is_nan() || is_infinity()) - { - new_exponent_bits = (1 << other_n_exp_bits) - 1; - - if (is_nan()) - { - if constexpr (other_has_infinity) - { - // Copy across the `not_quiet bit`; set the LSB. Don't - // attempt to copy across any of the rest of the payload. - new_significand = - 0x1 | (((significand() >> (n_significand_bits - 1)) & 1) << other_float_t::n_significand_bits); - } - else - { - new_significand = (1ul << other_float_t::n_significand_bits) - 1; - } - } - else if constexpr (!other_has_infinity) - { - new_significand = (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); - } - } - else if (!is_zero()) - { - const int64_t this_exponent_bits = exponent_bits(); - { - constexpr int64_t exponent_rebias = other_float_t::exponent_bias - exponent_bias; - new_exponent_bits = std::max(this_exponent_bits + exponent_rebias, exponent_rebias + 1); - } - new_significand = this->significand() << (64 - n_significand_bits); - - // Normalise subnormals - if (this_exponent_bits == 0) - { - // Shift the most-significant 1 out of the magnitude to convert - // it to a significand. Modify the exponent accordingly. - uint8_t shift = __builtin_clzl(new_significand) + 1; - new_exponent_bits -= shift; - new_significand <<= shift; - } - - // Align the significand for the output type - uint32_t shift = 64 - other_float_t::n_significand_bits; - const bool other_is_subnormal = new_exponent_bits <= 0; - if (other_is_subnormal) - { - shift += 1 - new_exponent_bits; - new_exponent_bits = 0; - } - - const uint64_t shift_out = shift == 64 ? new_significand : new_significand & ((1ll << shift) - 1); - new_significand = shift == 64 ? 0 : new_significand >> shift; - - // Reinsert the most-significant-one if this is a subnormal in the - // output type. - new_significand |= (other_is_subnormal ? 1ll : 0) << (64 - shift); - - // Apply rounding based on the bits shifted out of the significand - const uint64_t shift_half = 1ll << (shift - 1); - if (shift_out > shift_half || (shift_out == shift_half && (new_significand & 1))) - { - new_significand += 1; - - // Handle the case that the significand overflowed due to - // rounding - constexpr uint64_t max_significand = (1ll << other_float_t::n_significand_bits) - 1; - if (new_significand > max_significand) - { - new_significand = 0; - new_exponent_bits++; - } - } - - // Saturate to infinity if the exponent is larger than can be - // represented in the output type. This can only occur if the size - // of the exponent of the new type is not greater than the exponent - // of the old type. - if constexpr (other_n_exp_bits <= n_exp_bits) - { - constexpr int64_t inf_exp_bits = (1ll << other_n_exp_bits) - 1; - if (new_exponent_bits >= inf_exp_bits) - { - new_exponent_bits = inf_exp_bits; - new_significand = - other_has_infinity ? 0 : (1ul << other_float_t::n_significand_bits) - (other_has_nan ? 2 : 1); - } - } - } - - return other_float_t::from_bits(sign_bit, new_exponent_bits, new_significand); - } - - /// \brief Convert from a 32-bit floating point value - BITCAST_CONSTEXPR - float_t(const float& f) - { - // If this format exactly represents the binary32 format then get - // the bits from the provided float; otherwise get a binary32 - // representation and then convert to this format. - if constexpr (represents_binary32()) - m_data = float_support::get_bits(f); - else - m_data = static_cast>( - static_cast>(f)) - .m_data; - } - - /// \brief Cast to a 32-bit floating point value - BITCAST_CONSTEXPR operator float() const - { - // If this format exactly represents the binary32 format then return - // a float; otherwise get a binary32 representation and then return - // a float. - if constexpr (represents_binary32()) - return float_support::from_bits(m_data); - else - return static_cast(this->operator float_t()); - } - - /// \brief Return whether this type represents the IEEE754 binary32 - /// format - constexpr static inline bool represents_binary32() - { - return std::is_same_v && n_exp_bits == 8 && has_nan && with_denorm && with_infinity; - } - - constexpr auto operator-() const - { - return from_bits(m_data ^ (1ll << (sizeof(storage_t) * 8 - 1))); - } - - constexpr bool is_subnormal() const - { - return exponent_bits() == 0 && significand() != 0; - } - - constexpr bool is_zero() const - { - return exponent_bits() == 0 && significand() == 0; - } - - constexpr bool is_nan() const - { - return has_nan && (exponent_bits() == (1ul << n_exponent_bits) - 1) && - ((with_infinity && significand()) || - (!with_infinity && significand() == (1ul << n_significand_bits) - 1)); - } - - constexpr bool is_infinity() const - { - return with_infinity && ((exponent_bits() == (1ul << n_exponent_bits) - 1) && !significand()); - } - - constexpr inline const storage_t& bits() const - { - return m_data; - } - - /// \brief Get the exponent - constexpr inline int64_t exponent() const - { - return std::max(exponent_bits(), 1ul) - exponent_bias; - } - - /// \brief Get the bits from the exponent - constexpr inline uint64_t exponent_bits() const - { - constexpr uint64_t mask = (1ul << n_exp_bits) - 1; - return (m_data >> n_significand_bits) & mask; - } - - constexpr inline uint64_t significand() const - { - return m_data & ((1ul << n_significand_bits) - 1); - } - - constexpr inline bool operator==(const float_t& other) const - { - return !is_nan() && !other.is_nan() && ((is_zero() && other.is_zero()) || bits() == other.bits()); - } - - constexpr inline float_t& operator+=(const float_t& rhs) - { - this->m_data = static_cast(static_cast(*this) + static_cast(rhs)).bits(); - return *this; - } -}; - -// This should probably be exported so we can use it elsewhere -#undef BITCAST_CONSTEXPR - -namespace float_support -{ - -// Pre-C++23 these can't be computed as constexpr, so have to hardcode them - -template -struct digits10; // floor(log10(2) * (digits - 1) -template -struct max_digits10; // ceil(log10(2) * digits + 1) -template -struct min_exponent10; // floor(log10(2) * min_exponent) -template -struct max_exponent10; // floor(log10(2) * max_exponent) - -template <> -struct digits10<8> -{ - constexpr static inline int value = 2; -}; - -template <> -struct max_digits10<8> -{ - constexpr static inline int value = 4; -}; - -template <> -struct digits10<10> -{ - constexpr static inline int value = 2; -}; - -template <> -struct max_digits10<10> -{ - constexpr static inline int value = 5; -}; - -template <> -struct digits10<24> -{ - constexpr static inline int value = 6; -}; - -template <> -struct max_digits10<24> -{ - constexpr static inline int value = 9; -}; - -template <> -struct min_exponent10<-13> -{ - constexpr static inline int value = -3; -}; - -template <> -struct max_exponent10<16> -{ - constexpr static inline int value = 4; -}; - -template <> -struct min_exponent10<-125> -{ - constexpr static inline int value = -37; -}; - -template <> -struct max_exponent10<128> -{ - constexpr static inline int value = 38; -}; - -template -inline constexpr int digits10_v = digits10::value; -template -inline constexpr int max_digits10_v = max_digits10::value; - -template -inline constexpr int min_exponent10_v = min_exponent10::value; - -template -inline constexpr int max_exponent10_v = max_exponent10::value; - -} // namespace float_support - -} // namespace tosa::reference::internal - -namespace std -{ - -template -struct is_floating_point> - : std::integral_constant -{}; - -template -class numeric_limits> -{ - using this_float_t = tosa::reference::internal::float_t; - -public: - static constexpr bool is_specialized = true; - - static constexpr auto min() noexcept - { - return this_float_t::from_bits(false, 1, 0); - } - - static constexpr auto max() noexcept - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 2, - (1 << this_float_t::n_significand_bits) - 1); - } - - static constexpr auto lowest() noexcept - { - return -max(); - } - - static constexpr int digits = this_float_t::n_significand_bits + 1; - static constexpr int digits10 = tosa::reference::internal::float_support::digits10_v; - static constexpr int max_digits10 = tosa::reference::internal::float_support::max_digits10_v; - - static constexpr bool is_signed = true; - static constexpr bool is_integer = false; - static constexpr bool is_exact = false; - static constexpr int radix = 2; - - static constexpr auto epsilon() noexcept - { - return this_float_t::from_bits(false, this_float_t::exponent_bias - this_float_t::n_significand_bits, 0); - } - - static constexpr auto round_error() noexcept - { - return this_float_t::from_bits(0, this_float_t::exponent_bias - 1, 0); - } - - static constexpr int min_exponent = (1 - this_float_t::exponent_bias) + 1; - static constexpr int min_exponent10 = tosa::reference::internal::float_support::min_exponent10_v; - static constexpr int max_exponent = this_float_t::exponent_bias + 1; - static constexpr int max_exponent10 = tosa::reference::internal::float_support::max_exponent10_v; - - static constexpr bool has_infinity = with_inf; - static constexpr bool has_quiet_NaN = has_nan; - static constexpr bool has_signaling_NaN = true; - static constexpr float_denorm_style has_denorm = with_denorm ? denorm_present : denorm_absent; - static constexpr bool has_denorm_loss = false; - - static constexpr auto infinity() noexcept - { - if constexpr (with_inf) - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 0); - } - else - { - return this_float_t::from_bits(false, 0, 0); - } - } - - static constexpr auto quiet_NaN() noexcept - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, - 1 << (this_float_t::n_significand_bits - 1) | 1); - } - - static constexpr auto signaling_NaN() noexcept - { - return this_float_t::from_bits(false, (1 << this_float_t::n_exponent_bits) - 1, 1); - } - - static constexpr auto denorm_min() noexcept - { - return this_float_t::from_bits(false, 0, 1); - } - - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = false; - static constexpr bool is_modulo = false; - - static constexpr bool traps = false; - static constexpr bool tinyness_before = false; - static constexpr float_round_style round_style = round_to_nearest; -}; - -} // namespace std - -#endif // _FLOAT_UTILS_H_ diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index de7d8be..fc2a9ac 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -31,53 +31,79 @@ int OpClamp::register_fcn() auto tosa_level = g_func_config.tosa_level; LEVEL_CHECK(Rank <= tosa_level.MAX_RANK, "Rank should be smaller than or equal to MAX_RANK"); - switch (Dtype) + ASSERT_MSG(!(static_cast(this))->getOutputs().empty(), + "Must call register_fcn after tensors are linked to nodes"); + + InEigenType min, max; + + // need to use input tensor's serializationDtype to deserialize min/max values + // because Dtype may be FP64 in precise_mode + auto serializationDtype = (static_cast(this))->getInputs()[0]->getSerializationDtype(); + switch (DType2RefType(serializationDtype)) { - case TOSA_REF_TYPE_FP16: - case TOSA_REF_TYPE_BF16: - case TOSA_REF_TYPE_FP32: { + case TOSA_REF_TYPE_FP16: { + std::vector min_float_data, max_float_data; + TosaSerializationHandler::ConvertU8toF16(attribute->min_val(), /* size = */ 1, min_float_data); + TosaSerializationHandler::ConvertU8toF16(attribute->max_val(), /* size = */ 1, max_float_data); + min = (InEigenType)min_float_data[0]; + max = (InEigenType)max_float_data[0]; + } + break; + case TOSA_REF_TYPE_BF16: { std::vector min_float_data, max_float_data; - TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data); - TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data); - InEigenType min = (InEigenType)min_float_data[0]; - InEigenType max = (InEigenType)max_float_data[0]; - ERROR_IF(max < min, "OpClamp: max smaller than min"); - - this->fcn = [min, max](InEigenType a) -> OutEigenType { - return fpTrunc(a <= min ? min : a >= max ? max : a); - }; + TosaSerializationHandler::ConvertU8toBF16(attribute->min_val(), /* size = */ 1, min_float_data); + TosaSerializationHandler::ConvertU8toBF16(attribute->max_val(), /* size = */ 1, max_float_data); + min = (InEigenType)min_float_data[0]; + max = (InEigenType)max_float_data[0]; } break; - case TOSA_REF_TYPE_FP64: { + case TOSA_REF_TYPE_FP32: { std::vector min_float_data, max_float_data; TosaSerializationHandler::ConvertU8toF32(attribute->min_val(), /* size = */ 1, min_float_data); TosaSerializationHandler::ConvertU8toF32(attribute->max_val(), /* size = */ 1, max_float_data); - InEigenType min = (InEigenType)min_float_data[0]; - InEigenType max = (InEigenType)max_float_data[0]; - ERROR_IF(max < min, "OpClamp: max smaller than min"); - - this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); }; + min = (InEigenType)min_float_data[0]; + max = (InEigenType)max_float_data[0]; } break; case TOSA_REF_TYPE_INT8: { - std::vector min_int_data, max_int_data; - TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data); - TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data); - int8_t min = (int8_t)min_int_data[0]; - int8_t max = (int8_t)max_int_data[0]; - - ERROR_IF(max < min, "OpClamp: max smaller than min"); - this->fcn = [min, max](int8_t a) -> int8_t { return a <= min ? min : a >= max ? max : a; }; + std::vector min_int_data, max_int_data; + TosaSerializationHandler::ConvertU8toI8(attribute->min_val(), /* size = */ 1, min_int_data); + TosaSerializationHandler::ConvertU8toI8(attribute->max_val(), /* size = */ 1, max_int_data); + min = (InEigenType)min_int_data[0]; + max = (InEigenType)max_int_data[0]; + } + break; + case TOSA_REF_TYPE_INT16: { + std::vector min_int_data, max_int_data; + TosaSerializationHandler::ConvertU8toI16(attribute->min_val(), /* size = */ 1, min_int_data); + TosaSerializationHandler::ConvertU8toI16(attribute->max_val(), /* size = */ 1, max_int_data); + min = (InEigenType)min_int_data[0]; + max = (InEigenType)max_int_data[0]; } + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + ERROR_IF(max < min, "OpClamp: max smaller than min"); + + // evaluation function is still based on Dtype + switch (Dtype) + { + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: { + // apply fpTrunc after min/max + this->fcn = [min, max](InEigenType a) -> OutEigenType { + return fpTrunc(a <= min ? min : a >= max ? max : a); + }; + } + break; + case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_INT8: case TOSA_REF_TYPE_INT16: { - std::vector min_int_data, max_int_data; - TosaSerializationHandler::ConvertU8toI32(attribute->min_val(), /* size = */ 1, min_int_data); - TosaSerializationHandler::ConvertU8toI32(attribute->max_val(), /* size = */ 1, max_int_data); - int16_t min = (int16_t)min_int_data[0]; - int16_t max = (int16_t)max_int_data[0]; - - ERROR_IF(max < min, "OpClamp: max smaller than min"); - this->fcn = [min, max](int16_t a) -> int16_t { return a <= min ? min : a >= max ? max : a; }; + // simply min/max + this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); }; } break; default: diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h index 1696668..055642a 100644 --- a/reference_model/src/ops/activation_funcs.h +++ b/reference_model/src/ops/activation_funcs.h @@ -32,7 +32,6 @@ public: : UnaryNode(sgt_, Op_CLAMP, id_) { INIT_ATTRIBUTE(Clamp); - register_fcn(); } virtual ~OpClamp(); static constexpr int32_t QMin = GetQMin::value; diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index e264284..6664ec3 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -171,11 +171,31 @@ int OpPad::eval() { InEigenType pad_value = 0; - switch (Dtype) - { - case TOSA_REF_TYPE_BOOL: - case TOSA_REF_TYPE_INT8: - case TOSA_REF_TYPE_INT16: + // need to use input tensor's serializationDtype to deserialize pad_const + // because Dtype may be FP64 in precise_mode + switch (DType2RefType(inputs[0]->getSerializationDtype())) + { + case TOSA_REF_TYPE_BOOL: { + std::vector bool_data; + TosaSerializationHandler::ConvertU8toBool(attribute->pad_const(), + /* size = */ 1, bool_data); + pad_value = (InEigenType)bool_data[0]; + break; + } + case TOSA_REF_TYPE_INT8: { + std::vector int8_data; + TosaSerializationHandler::ConvertU8toI8(attribute->pad_const(), + /* size = */ 1, int8_data); + pad_value = (InEigenType)int8_data[0]; + break; + } + case TOSA_REF_TYPE_INT16: { + std::vector int16_data; + TosaSerializationHandler::ConvertU8toI16(attribute->pad_const(), + /* size = */ 1, int16_data); + pad_value = (InEigenType)int16_data[0]; + break; + } case TOSA_REF_TYPE_INT32: { std::vector int32_data; TosaSerializationHandler::ConvertU8toI32(attribute->pad_const(), @@ -183,15 +203,38 @@ int OpPad::eval() pad_value = (InEigenType)int32_data[0]; break; } - case TOSA_REF_TYPE_FP16: - case TOSA_REF_TYPE_BF16: - case TOSA_REF_TYPE_FP32: - case TOSA_REF_TYPE_FP64: - case TOSA_REF_TYPE_FP8E4M3: + case TOSA_REF_TYPE_FP16: { + std::vector f16_data; + TosaSerializationHandler::ConvertU8toF16(attribute->pad_const(), + /* size = */ 1, f16_data); + pad_value = (InEigenType)f16_data[0]; + break; + } + case TOSA_REF_TYPE_BF16: { + std::vector f32_data; + TosaSerializationHandler::ConvertU8toBF16(attribute->pad_const(), + /* size = */ 1, f32_data); + pad_value = (InEigenType)f32_data[0]; + break; + } + case TOSA_REF_TYPE_FP32: { + std::vector f32_data; + TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(), + /* size = */ 1, f32_data); + pad_value = (InEigenType)f32_data[0]; + break; + } + case TOSA_REF_TYPE_FP8E4M3: { + std::vector f32_data; + TosaSerializationHandler::ConvertU8toFP8E4M3(attribute->pad_const(), + /* size = */ 1, f32_data); + pad_value = (InEigenType)f32_data[0]; + break; + } case TOSA_REF_TYPE_FP8E5M2: { std::vector float_data; - TosaSerializationHandler::ConvertU8toF32(attribute->pad_const(), - /* size = */ 1, float_data); + TosaSerializationHandler::ConvertU8toFP8E5M2(attribute->pad_const(), + /* size = */ 1, float_data); pad_value = (InEigenType)float_data[0]; break; } diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index dd9ea5a..310a174 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -66,6 +66,13 @@ int UnaryNode::checkTensorAttributes() template int UnaryNode::eval() { + // call register_fcn() here to ensure inputs/outputs have been connected + // to the node by the time register_fcn() is called for Clamp Operator + if (register_fcn()) + { + return 1; + } + this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn); return GraphNode::eval(); diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 7bca697..40e6c64 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -25,11 +25,11 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -using fp16 = tosa::reference::internal::float_t; -using bf16 = tosa::reference::internal::float_t; -using fp32 = tosa::reference::internal::float_t; -using fp8e4m3 = tosa::reference::internal::float_t; -using fp8e5m2 = tosa::reference::internal::float_t; +using fp16 = tosa::float_t; +using bf16 = tosa::float_t; +using fp32 = tosa::float_t; +using fp8e4m3 = tosa::float_t; +using fp8e5m2 = tosa::float_t; template OpRescale::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc index 52b1806..6aa0a45 100644 --- a/reference_model/src/subgraph_traverser.cc +++ b/reference_model/src/subgraph_traverser.cc @@ -580,7 +580,7 @@ int SubgraphTraverser::allocateTensor(std::string name) break; case DType_BF16: { std::vector fp32_data; - TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); + TosaSerializationHandler::ConvertU8toBF16(ts->GetData(), tensor->getElementCount(), fp32_data); // Ensure valid bfloat16 stored in each float for (auto f : fp32_data) ASSERT_MSG(checkValidBFloat(f), "Float value %f not valid bfloat16", f); @@ -595,11 +595,23 @@ int SubgraphTraverser::allocateTensor(std::string name) } } break; - case DType_FP8E4M3: + case DType_FP8E4M3: { + std::vector fp32_data; + TosaSerializationHandler::ConvertU8toFP8E4M3(ts->GetData(), tensor->getElementCount(), fp32_data); + if (tensor->getDtype() == TOSA_REF_TYPE_FP64) + { + std::vector f64_data(fp32_data.begin(), fp32_data.end()); + tensor->setTensorValueDouble(f64_data.size(), f64_data.data()); + } + else + { + tensor->setTensorValueFloat(fp32_data.size(), fp32_data.data()); + } + } + break; case DType_FP8E5M2: { std::vector fp32_data; - TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data); - // Ensure valid fp8 stored in each float + TosaSerializationHandler::ConvertU8toFP8E5M2(ts->GetData(), tensor->getElementCount(), fp32_data); if (tensor->getDtype() == TOSA_REF_TYPE_FP64) { std::vector f64_data(fp32_data.begin(), fp32_data.end()); diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index 8f9e284..57d7818 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit 8f9e2842ce7d25645233ad4f6fa406be982346ae +Subproject commit 57d781883142db8a45fe98ac1a1dfacc49cba78a diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index c5ac0f9..38ab3f4 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -3,7 +3,6 @@ import json import logging import os -import struct from copy import deepcopy from datetime import datetime from pathlib import Path @@ -1390,20 +1389,14 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - if a.dtype in (DType.BF16, DType.FP16, DType.FP32): - if a.dtype == DType.FP16: - # Non-tensor fp16 ops take fp16 values as fp32 in reference_model - min_val = min_val.astype(np.float32) - max_val = max_val.astype(np.float32) - min_val_as_bytes = struct.pack("