aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src')
-rw-r--r--reference_model/src/generate/generate_dot_product.cc5
-rw-r--r--reference_model/src/generate/generate_dot_product.h2
-rw-r--r--reference_model/src/generate/generate_entry.cc5
-rw-r--r--reference_model/src/generate/generate_pseudo_random.cc103
-rw-r--r--reference_model/src/generate/generate_pseudo_random.h34
-rw-r--r--reference_model/src/generate/generate_utils.cc11
-rw-r--r--reference_model/src/generate/generate_utils.h10
-rw-r--r--reference_model/src/verify/verify_exact.cc19
8 files changed, 182 insertions, 7 deletions
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc
index 1d2325f..cbfac4b 100644
--- a/reference_model/src/generate/generate_dot_product.cc
+++ b/reference_model/src/generate/generate_dot_product.cc
@@ -56,6 +56,11 @@ bool generateMatMul(const TosaReference::GenerateConfig& cfg,
void* data,
size_t size)
{
+ if (cfg.dataType != DType::DType_FP32)
+ {
+ WARNING("[Generator][DP][MatMul] Only supports FP32.");
+ return false;
+ }
if (cfg.shape.size() != 3)
{
WARNING("[Generator][DP][MatMul] Tensor shape expected 3 dimensions.");
diff --git a/reference_model/src/generate/generate_dot_product.h b/reference_model/src/generate/generate_dot_product.h
index 236f577..cd9d4ba 100644
--- a/reference_model/src/generate/generate_dot_product.h
+++ b/reference_model/src/generate/generate_dot_product.h
@@ -37,7 +37,7 @@ std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConf
///
/// \param cfg Generator related meta-data
/// \param data Buffer to generate the data to
-/// \param size Size of the buffet
+/// \param size Size of the buffer
///
/// \return True on successful generation
bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size);
diff --git a/reference_model/src/generate/generate_entry.cc b/reference_model/src/generate/generate_entry.cc
index e7a0044..741cd79 100644
--- a/reference_model/src/generate/generate_entry.cc
+++ b/reference_model/src/generate/generate_entry.cc
@@ -15,6 +15,7 @@
#include "generate.h"
#include "generate_dot_product.h"
+#include "generate_pseudo_random.h"
#include "generate_utils.h"
#include "func_debug.h"
@@ -31,6 +32,10 @@ bool generate(const GenerateConfig& cfg, void* data, size_t size)
return generateDotProduct(cfg, data, size);
break;
}
+ case GeneratorType::PseudoRandom: {
+ return generatePseudoRandom(cfg, data, size);
+ break;
+ }
default: {
WARNING("[Generator] Unsupported generation mode.");
break;
diff --git a/reference_model/src/generate/generate_pseudo_random.cc b/reference_model/src/generate/generate_pseudo_random.cc
new file mode 100644
index 0000000..858a4b2
--- /dev/null
+++ b/reference_model/src/generate/generate_pseudo_random.cc
@@ -0,0 +1,103 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "generate.h"
+#include "generate_utils.h"
+
+#include <array>
+#include <iterator>
+#include <limits>
+#include <numeric>
+#include <random>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+namespace
+{
+
+// Random generator
+template <typename FP>
+class PseudoRandomGeneratorFloat
+{
+public:
+ PseudoRandomGeneratorFloat(uint64_t seed)
+ : _gen(seed)
+ {
+ // Uniform real distribution generates real values in the range [a, b]
+ // and requires that b - a <= std::numeric_limits<FP>::max() so here
+ // we choose some arbitrary values that satisfy that condition.
+ constexpr auto min = std::numeric_limits<FP>::lowest() / 2;
+ constexpr auto max = std::numeric_limits<FP>::max() / 2;
+ static_assert(max <= std::numeric_limits<FP>::max() + min);
+ _unidis = std::uniform_real_distribution<FP>(min, max);
+
+ // Piecewise Constant distribution
+ const std::array<double, 7> intervals{ min, min + 1000, -1000.0, 0.0, 1000.0, max - 1000, max };
+ const std::array<double, 7> weights{ 1.0, 0.1, 1.0, 2.0, 1.0, 0.1, 1.0 };
+ _pwcdis = std::piecewise_constant_distribution<FP>(intervals.begin(), intervals.end(), weights.begin());
+ }
+
+ FP getRandomUniformFloat()
+ {
+ return _unidis(_gen);
+ }
+
+ FP getRandomPWCFloat()
+ {
+ return _pwcdis(_gen);
+ }
+
+private:
+ std::mt19937 _gen;
+ std::uniform_real_distribution<FP> _unidis;
+ std::piecewise_constant_distribution<FP> _pwcdis;
+};
+
+bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t size)
+{
+ const TosaReference::PseudoRandomInfo& prinfo = cfg.pseudoRandomInfo;
+ PseudoRandomGeneratorFloat<float> generator(prinfo.rngSeed);
+
+ float* a = reinterpret_cast<float*>(data);
+ const auto T = TosaReference::numElementsFromShape(cfg.shape);
+ for (auto t = 0; t < T; ++t)
+ {
+ a[t] = generator.getRandomPWCFloat();
+ }
+ return true;
+}
+
+} // namespace
+
+namespace TosaReference
+{
+bool generatePseudoRandom(const GenerateConfig& cfg, void* data, size_t size)
+{
+ // Check we support the operator
+ if (cfg.opType == Op::Op_UNKNOWN)
+ {
+ WARNING("[Generator][PR] Unknown operator.");
+ return false;
+ }
+
+ switch (cfg.dataType)
+ {
+ case DType::DType_FP32:
+ return generateFP32(cfg, data, size);
+ default:
+ WARNING("[Generator][PR] Unsupported type.");
+ return false;
+ }
+}
+} // namespace TosaReference
diff --git a/reference_model/src/generate/generate_pseudo_random.h b/reference_model/src/generate/generate_pseudo_random.h
new file mode 100644
index 0000000..6796d20
--- /dev/null
+++ b/reference_model/src/generate/generate_pseudo_random.h
@@ -0,0 +1,34 @@
+// Copyright (c) 2023, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GENERATE_PSEUDO_RANDOM_H_
+#define GENERATE_PSEUDO_RANDOM_H_
+
+#include "generate_utils.h"
+
+namespace TosaReference
+{
+
+/// \brief Perform pseudo random based generation
+///
+/// \param cfg Generator related meta-data
+/// \param data Buffer to generate the data to
+/// \param size Size of the buffer
+///
+/// \return True on successful generation
+bool generatePseudoRandom(const GenerateConfig& cfg, void* data, size_t size);
+
+}; // namespace TosaReference
+
+#endif // GENERATE_PSEUDO_RANDOM_H_
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index da16632..bcbf9d7 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -39,6 +39,8 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op,
{
{ Op::Op_UNKNOWN, "UNKNOWN" },
{ Op::Op_MATMUL, "MATMUL" },
+ { Op::Op_MAX_POOL2D, "MAX_POOL2D" },
+ { Op::Op_PAD, "PAD" },
})
} // namespace tosa
@@ -78,6 +80,11 @@ void from_json(const nlohmann::json& j, DotProductInfo& dotProductInfo)
}
}
+void from_json(const nlohmann::json& j, PseudoRandomInfo& pseudoRandomInfo)
+{
+ j.at("rng_seed").get_to(pseudoRandomInfo.rngSeed);
+}
+
void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("data_type").get_to(cfg.dataType);
@@ -90,6 +97,10 @@ void from_json(const nlohmann::json& j, GenerateConfig& cfg)
{
j.at("dot_product_info").get_to(cfg.dotProductInfo);
}
+ if (j.contains("pseudo_random_info"))
+ {
+ j.at("pseudo_random_info").get_to(cfg.pseudoRandomInfo);
+ }
}
std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* tensorName)
diff --git a/reference_model/src/generate/generate_utils.h b/reference_model/src/generate/generate_utils.h
index e8e67bb..0239e98 100644
--- a/reference_model/src/generate/generate_utils.h
+++ b/reference_model/src/generate/generate_utils.h
@@ -55,6 +55,15 @@ struct DotProductInfo
std::array<int32_t, 2> kernel;
};
+/// \brief Pseudo random generator meta-data
+struct PseudoRandomInfo
+{
+ PseudoRandomInfo() = default;
+
+ int64_t rngSeed;
+ // TODO: Add range support
+};
+
/// \brief Generator configuration
struct GenerateConfig
{
@@ -65,6 +74,7 @@ struct GenerateConfig
int32_t inputPos;
tosa::Op opType;
DotProductInfo dotProductInfo;
+ PseudoRandomInfo pseudoRandomInfo;
};
/// \brief Parse the generator config when given in JSON form
diff --git a/reference_model/src/verify/verify_exact.cc b/reference_model/src/verify/verify_exact.cc
index 4d6c72f..36b4ec9 100644
--- a/reference_model/src/verify/verify_exact.cc
+++ b/reference_model/src/verify/verify_exact.cc
@@ -16,6 +16,14 @@
#include "verifiers.h"
#include <cmath>
+namespace
+{
+bool exact_fp32(const double& referenceValue, const float& implementationValue)
+{
+ return std::isnan(referenceValue) ? std::isnan(implementationValue) : (referenceValue == implementationValue);
+}
+} // namespace
+
namespace TosaReference
{
@@ -33,15 +41,14 @@ bool verifyExact(const CTensor* referenceTensor, const CTensor* implementationTe
switch (implementationTensor->data_type)
{
case tosa_datatype_fp32_t: {
- const auto* refData = reinterpret_cast<const float*>(referenceTensor->data);
+ TOSA_REF_REQUIRE(referenceTensor->data_type == tosa_datatype_fp64_t, "[E] Reference tensor is not fp64");
+ const auto* refData = reinterpret_cast<const double*>(referenceTensor->data);
TOSA_REF_REQUIRE(refData != nullptr, "[E] Missing data for reference");
const auto* impData = reinterpret_cast<const float*>(implementationTensor->data);
TOSA_REF_REQUIRE(impData != nullptr, "[E] Missing data for implementation");
- return std::equal(refData, std::next(refData, elementCount), impData, std::next(impData, elementCount),
- [](const auto& referenceValue, const auto& implementationValue) {
- return std::isnan(referenceValue) ? std::isnan(implementationValue)
- : (referenceValue == implementationValue);
- });
+ auto result = std::equal(refData, std::next(refData, elementCount), impData,
+ std::next(impData, elementCount), exact_fp32);
+ return result;
}
default:
WARNING("[Verifier][E] Data-type not supported.");