diff options
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/generate/generate_entry.cc | 5 | ||||
-rw-r--r-- | reference_model/src/generate/generate_full_range.cc | 59 | ||||
-rw-r--r-- | reference_model/src/generate/generate_full_range.h | 34 | ||||
-rw-r--r-- | reference_model/src/generate/generate_utils.cc | 21 | ||||
-rw-r--r-- | reference_model/src/generate/generate_utils.h | 15 | ||||
-rw-r--r-- | reference_model/src/verify/verify_abs_error.cc | 13 | ||||
-rw-r--r-- | reference_model/src/verify/verify_utils.cc | 25 |
7 files changed, 154 insertions, 18 deletions
diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc index 91b2fc7..6f797b3 100644 --- a/reference_model/src/generate/generate_entry.cc +++ b/reference_model/src/generate/generate_entry.cc @@ -16,6 +16,7 @@ #include "generate_dot_product.h" #include "generate_fixed_data.h" +#include "generate_full_range.h" #include "generate_pseudo_random.h" #include "generate_utils.h" @@ -41,6 +42,10 @@ bool generate(const GenerateConfig& cfg, void* data, size_t size) return generateFixedData(cfg, data, size); break; } + case GeneratorType::FullRange: { + return generateFullRange(cfg, data, size); + break; + } default: { WARNING("[Generator] Unsupported generation mode."); break; diff --git a/reference_model/src/generate/generate_full_range.cc b/reference_model/src/generate/generate_full_range.cc new file mode 100644 index 0000000..d2a89da --- /dev/null +++ b/reference_model/src/generate/generate_full_range.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "generate_full_range.h" +#include "half.hpp" + +namespace +{ + +template <typename DataType> +bool generate(const TosaReference::GenerateConfig& cfg, DataType* data, size_t size) +{ + const TosaReference::FullRangeInfo& frinfo = cfg.fullRangeInfo; + DataType value = frinfo.startVal; + + const auto T = TosaReference::numElementsFromShape(cfg.shape); + for (auto t = 0; t < T; ++t) + { + data[t] = value; + value++; + } + return true; +} +} // namespace + +namespace TosaReference +{ +bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size) +{ + // Check we support the operator + if (cfg.opType == Op::Op_UNKNOWN) + { + WARNING("[Generator][PR] Unknown operator."); + return false; + } + + switch (cfg.dataType) + { + case DType::DType_FP16: { + uint16_t* outData = reinterpret_cast<uint16_t*>(data); + return generate(cfg, outData, size); + } + default: + WARNING("[Generator][PR] Unsupported type."); + return false; + } +} +} // namespace TosaReference
\ No newline at end of file diff --git a/reference_model/src/generate/generate_full_range.h b/reference_model/src/generate/generate_full_range.h new file mode 100644 index 0000000..df24160 --- /dev/null +++ b/reference_model/src/generate/generate_full_range.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GENERATE_FULL_RANGE_H_ +#define GENERATE_FULL_RANGE_H_ + +#include "generate_utils.h" + +namespace TosaReference +{ + +/// \brief Perform full range generation +/// +/// \param cfg Generator related meta-data +/// \param data Buffer to generate the data to +/// \param size Size of the buffer +/// +/// \return True on successful generation +bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size); + +}; // namespace TosaReference + +#endif // GENERATE_FULL_RANGE_H_
\ No newline at end of file diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index d0d0194..f31b443 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -105,9 +105,9 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType, { GeneratorType::Unknown, "UNKNOWN" }, { GeneratorType::PseudoRandom, "PSEUDO_RANDOM" }, { GeneratorType::DotProduct, "DOT_PRODUCT" }, - { GeneratorType::OpFullRange, "OP_FULL_RANGE" }, - { GeneratorType::OpBoundary, "OP_BOUNDARY" }, - { GeneratorType::OpSpecial, "OP_SPECIAL" }, + { GeneratorType::FullRange, "FULL_RANGE" }, + { GeneratorType::Boundary, "BOUNDARY" }, + { GeneratorType::Special, "SPECIAL" }, { GeneratorType::FixedData, "FIXED_DATA" }, }) @@ -151,6 +151,14 @@ void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo) j.at("data").get_to(fixedDataInfo.data); } +void from_json(const nlohmann::json& j, FullRangeInfo& fullRangeInfo) +{ + if (j.contains("start_val")) + { + j.at("start_val").get_to(fullRangeInfo.startVal); + } +} + void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("data_type").get_to(cfg.dataType); @@ -186,6 +194,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg) { j.at("fixed_data_info").get_to(cfg.fixedDataInfo); } + + //Set up defaults for fullRangeInfo + cfg.fullRangeInfo.startVal = 0; + if (j.contains("full_range_info")) + { + j.at("full_range_info").get_to(cfg.fullRangeInfo); + } } std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName) diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h index 697b404..8ce9b0e 100644 --- a/reference_model/src/generate/generate_utils.h +++ b/reference_model/src/generate/generate_utils.h @@ -31,9 +31,9 @@ enum class GeneratorType Unknown, PseudoRandom, DotProduct, - OpFullRange, - OpBoundary, - OpSpecial, + FullRange, + Boundary, + Special, FixedData, }; @@ -74,6 +74,14 @@ struct FixedDataInfo std::vector<int32_t> data; }; +/// \brief Op specific generator meta-data +struct FullRangeInfo +{ + FullRangeInfo() = default; + + uint16_t startVal; +}; + /// \brief Generator configuration struct GenerateConfig { @@ -86,6 +94,7 @@ struct GenerateConfig DotProductInfo dotProductInfo; PseudoRandomInfo pseudoRandomInfo; FixedDataInfo fixedDataInfo; + FullRangeInfo fullRangeInfo; }; /// \brief Parse the generator config when given in JSON form diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc index 125045e..64f86a3 100644 --- a/reference_model/src/verify/verify_abs_error.cc +++ b/reference_model/src/verify/verify_abs_error.cc @@ -30,12 +30,17 @@ double calcErrorBound(double referenceValue, double boundsValue, const void* cfg { const auto cfg = reinterpret_cast<const AbsErrorVerifyInfo*>(cfgPtr); - double valBound = std::abs(referenceValue) * boundsValue; - if (cfg->lowerBound > 0) + double errorBound = 0.0; + if (std::isfinite(referenceValue) && std::abs(referenceValue) != 0.0) { - valBound = std::max(cfg->lowerBound, valBound); + double valBound = std::abs(referenceValue) * boundsValue; + if (cfg->lowerBound > 0) + { + valBound = std::max(cfg->lowerBound, valBound); + } + errorBound = exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound; } - return exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound; + return errorBound; } } // namespace diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc index 50a98e5..d4657b3 100644 --- a/reference_model/src/verify/verify_utils.cc +++ b/reference_model/src/verify/verify_utils.cc @@ -356,21 +356,23 @@ bool validateData(const double* referenceData, TOSA_REF_REQUIRE(calcErrorBound != nullptr, "Missing error bound function validation"); std::string warning, worstWarning; - double difference, worstDifference = 0.0; - size_t worstPosition; - bool compliant = true; + double worstDifference = 0.0; + // Set to invalid index + size_t worstIndex = T; + bool compliant = true; for (size_t i = 0; i < T; ++i) { - double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i]; - double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr); - bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning); + double difference = 0.0; + double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i]; + double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr); + bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning); if (!valid) { compliant = false; if (std::isnan(difference) || std::abs(difference) > std::abs(worstDifference)) { - worstPosition = i; + worstIndex = i; worstDifference = difference; worstWarning.assign(warning); if (std::isnan(difference)) @@ -379,11 +381,18 @@ bool validateData(const double* referenceData, break; } } + else if (std::abs(difference) == 0.0) + { + auto pos = indexToPosition(i, shape); + WARNING("[Verifier][%s] Invalid error bound, no difference found. Location: %s", modeStr.c_str(), + positionToString(pos).c_str()); + return false; + } } } if (!compliant) { - auto pos = indexToPosition(worstPosition, shape); + auto pos = indexToPosition(worstIndex, shape); WARNING("[Verifier][%s] Largest deviance at location %s: %s", modeStr.c_str(), positionToString(pos).c_str(), worstWarning.c_str()); } |