aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src')
-rw-r--r--reference_model/src/generate/generate_entry.cc5
-rw-r--r--reference_model/src/generate/generate_full_range.cc59
-rw-r--r--reference_model/src/generate/generate_full_range.h34
-rw-r--r--reference_model/src/generate/generate_utils.cc21
-rw-r--r--reference_model/src/generate/generate_utils.h15
-rw-r--r--reference_model/src/verify/verify_abs_error.cc13
-rw-r--r--reference_model/src/verify/verify_utils.cc25
7 files changed, 154 insertions, 18 deletions
diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc
index 91b2fc7..6f797b3 100644
--- a/reference_model/src/generate/generate_entry.cc
+++ b/reference_model/src/generate/generate_entry.cc
@@ -16,6 +16,7 @@
#include "generate_dot_product.h"
#include "generate_fixed_data.h"
+#include "generate_full_range.h"
#include "generate_pseudo_random.h"
#include "generate_utils.h"
@@ -41,6 +42,10 @@ bool generate(const GenerateConfig& cfg, void* data, size_t size)
return generateFixedData(cfg, data, size);
break;
}
+ case GeneratorType::FullRange: {
+ return generateFullRange(cfg, data, size);
+ break;
+ }
default: {
WARNING("[Generator] Unsupported generation mode.");
break;
diff --git a/reference_model/src/generate/generate_full_range.cc b/reference_model/src/generate/generate_full_range.cc
new file mode 100644
index 0000000..d2a89da
--- /dev/null
+++ b/reference_model/src/generate/generate_full_range.cc
@@ -0,0 +1,59 @@
+// Copyright (c) 2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "generate_full_range.h"
+#include "half.hpp"
+
+namespace
+{
+
+template <typename DataType>
+bool generate(const TosaReference::GenerateConfig& cfg, DataType* data, size_t size)
+{
+ const TosaReference::FullRangeInfo& frinfo = cfg.fullRangeInfo;
+ DataType value = frinfo.startVal;
+
+ const auto T = TosaReference::numElementsFromShape(cfg.shape);
+ for (auto t = 0; t < T; ++t)
+ {
+ data[t] = value;
+ value++;
+ }
+ return true;
+}
+} // namespace
+
+namespace TosaReference
+{
+bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size)
+{
+ // Check we support the operator
+ if (cfg.opType == Op::Op_UNKNOWN)
+ {
+ WARNING("[Generator][PR] Unknown operator.");
+ return false;
+ }
+
+ switch (cfg.dataType)
+ {
+ case DType::DType_FP16: {
+ uint16_t* outData = reinterpret_cast<uint16_t*>(data);
+ return generate(cfg, outData, size);
+ }
+ default:
+ WARNING("[Generator][PR] Unsupported type.");
+ return false;
+ }
+}
+} // namespace TosaReference \ No newline at end of file
diff --git a/reference_model/src/generate/generate_full_range.h b/reference_model/src/generate/generate_full_range.h
new file mode 100644
index 0000000..df24160
--- /dev/null
+++ b/reference_model/src/generate/generate_full_range.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2024, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GENERATE_FULL_RANGE_H_
+#define GENERATE_FULL_RANGE_H_
+
+#include "generate_utils.h"
+
+namespace TosaReference
+{
+
+/// \brief Perform full range generation
+///
+/// \param cfg Generator related meta-data
+/// \param data Buffer to generate the data to
+/// \param size Size of the buffer
+///
+/// \return True on successful generation
+bool generateFullRange(const GenerateConfig& cfg, void* data, size_t size);
+
+}; // namespace TosaReference
+
+#endif // GENERATE_FULL_RANGE_H_ \ No newline at end of file
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index d0d0194..f31b443 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -105,9 +105,9 @@ NLOHMANN_JSON_SERIALIZE_ENUM(GeneratorType,
{ GeneratorType::Unknown, "UNKNOWN" },
{ GeneratorType::PseudoRandom, "PSEUDO_RANDOM" },
{ GeneratorType::DotProduct, "DOT_PRODUCT" },
- { GeneratorType::OpFullRange, "OP_FULL_RANGE" },
- { GeneratorType::OpBoundary, "OP_BOUNDARY" },
- { GeneratorType::OpSpecial, "OP_SPECIAL" },
+ { GeneratorType::FullRange, "FULL_RANGE" },
+ { GeneratorType::Boundary, "BOUNDARY" },
+ { GeneratorType::Special, "SPECIAL" },
{ GeneratorType::FixedData, "FIXED_DATA" },
})
@@ -151,6 +151,14 @@ void from_json(const nlohmann::json& j, FixedDataInfo& fixedDataInfo)
j.at("data").get_to(fixedDataInfo.data);
}
+void from_json(const nlohmann::json& j, FullRangeInfo& fullRangeInfo)
+{
+ if (j.contains("start_val"))
+ {
+ j.at("start_val").get_to(fullRangeInfo.startVal);
+ }
+}
+
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("data_type").get_to(cfg.dataType);
@@ -186,6 +194,13 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("fixed_data_info").get_to(cfg.fixedDataInfo);
}
+
+ //Set up defaults for fullRangeInfo
+ cfg.fullRangeInfo.startVal = 0;
+ if (j.contains("full_range_info"))
+ {
+ j.at("full_range_info").get_to(cfg.fullRangeInfo);
+ }
}
std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
index 697b404..8ce9b0e 100644
--- a/reference_model/src/generate/generate_utils.h
+++ b/reference_model/src/generate/generate_utils.h
@@ -31,9 +31,9 @@ enum class GeneratorType
Unknown,
PseudoRandom,
DotProduct,
- OpFullRange,
- OpBoundary,
- OpSpecial,
+ FullRange,
+ Boundary,
+ Special,
FixedData,
};
@@ -74,6 +74,14 @@ struct FixedDataInfo
std::vector<int32_t> data;
};
+/// \brief Op specific generator meta-data
+struct FullRangeInfo
+{
+ FullRangeInfo() = default;
+
+ uint16_t startVal;
+};
+
/// \brief Generator configuration
struct GenerateConfig
{
@@ -86,6 +94,7 @@ struct GenerateConfig
DotProductInfo dotProductInfo;
PseudoRandomInfo pseudoRandomInfo;
FixedDataInfo fixedDataInfo;
+ FullRangeInfo fullRangeInfo;
};
/// \brief Parse the generator config when given in JSON form
diff --git a/reference_model/src/verify/verify_abs_error.cc b/reference_model/src/verify/verify_abs_error.cc
index 125045e..64f86a3 100644
--- a/reference_model/src/verify/verify_abs_error.cc
+++ b/reference_model/src/verify/verify_abs_error.cc
@@ -30,12 +30,17 @@ double calcErrorBound(double referenceValue, double boundsValue, const void* cfg
{
const auto cfg = reinterpret_cast<const AbsErrorVerifyInfo*>(cfgPtr);
- double valBound = std::abs(referenceValue) * boundsValue;
- if (cfg->lowerBound > 0)
+ double errorBound = 0.0;
+ if (std::isfinite(referenceValue) && std::abs(referenceValue) != 0.0)
{
- valBound = std::max(cfg->lowerBound, valBound);
+ double valBound = std::abs(referenceValue) * boundsValue;
+ if (cfg->lowerBound > 0)
+ {
+ valBound = std::max(cfg->lowerBound, valBound);
+ }
+ errorBound = exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound;
}
- return exp2(-AccPrecision<OutType>::normal_frac / cfg->normalDivisor) * valBound;
+ return errorBound;
}
} // namespace
diff --git a/reference_model/src/verify/verify_utils.cc b/reference_model/src/verify/verify_utils.cc
index 50a98e5..d4657b3 100644
--- a/reference_model/src/verify/verify_utils.cc
+++ b/reference_model/src/verify/verify_utils.cc
@@ -356,21 +356,23 @@ bool validateData(const double* referenceData,
TOSA_REF_REQUIRE(calcErrorBound != nullptr, "Missing error bound function validation");
std::string warning, worstWarning;
- double difference, worstDifference = 0.0;
- size_t worstPosition;
- bool compliant = true;
+ double worstDifference = 0.0;
+ // Set to invalid index
+ size_t worstIndex = T;
+ bool compliant = true;
for (size_t i = 0; i < T; ++i)
{
- double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i];
- double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr);
- bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning);
+ double difference = 0.0;
+ double boundVal = (boundsData == nullptr) ? 0.0 : boundsData[i];
+ double errBound = calcErrorBound(referenceData[i], boundVal, cfgPtr);
+ bool valid = tosaCheckFloatBound(implementationData[i], referenceData[i], errBound, difference, warning);
if (!valid)
{
compliant = false;
if (std::isnan(difference) || std::abs(difference) > std::abs(worstDifference))
{
- worstPosition = i;
+ worstIndex = i;
worstDifference = difference;
worstWarning.assign(warning);
if (std::isnan(difference))
@@ -379,11 +381,18 @@ bool validateData(const double* referenceData,
break;
}
}
+ else if (std::abs(difference) == 0.0)
+ {
+ auto pos = indexToPosition(i, shape);
+ WARNING("[Verifier][%s] Invalid error bound, no difference found. Location: %s", modeStr.c_str(),
+ positionToString(pos).c_str());
+ return false;
+ }
}
}
if (!compliant)
{
- auto pos = indexToPosition(worstPosition, shape);
+ auto pos = indexToPosition(worstIndex, shape);
WARNING("[Verifier][%s] Largest deviance at location %s: %s", modeStr.c_str(), positionToString(pos).c_str(),
worstWarning.c_str());
}