aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_pseudo_random.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate/generate_pseudo_random.cc')
-rw-r--r--reference_model/src/generate/generate_pseudo_random.cc132
1 files changed, 131 insertions, 1 deletions
diff --git a/reference_model/src/generate/generate_pseudo_random.cc b/reference_model/src/generate/generate_pseudo_random.cc
index b62c38f..865483b 100644
--- a/reference_model/src/generate/generate_pseudo_random.cc
+++ b/reference_model/src/generate/generate_pseudo_random.cc
@@ -15,6 +15,7 @@
#include "generate_utils.h"
#include "half.hpp"
+#include <algorithm>
#include <array>
#include <iterator>
#include <limits>
@@ -27,7 +28,7 @@
namespace
{
-// Random generator
+// Random FP generator
template <typename FP>
class PseudoRandomGeneratorFloat
{
@@ -127,6 +128,123 @@ bool generateFP(const TosaReference::GenerateConfig& cfg, DataType* data, size_t
return true;
}
+// Random INT generator
+template <typename INT>
+class PseudoRandomGeneratorInteger
+{
+public:
+ PseudoRandomGeneratorInteger(uint64_t seed)
+ : _gen(seed)
+ {
+ constexpr auto min = std::numeric_limits<INT>::min();
+ constexpr auto max = std::numeric_limits<INT>::max();
+
+ setDistribution(min, max);
+ }
+
+ PseudoRandomGeneratorInteger(uint64_t seed, INT min, INT max)
+ : _gen(seed)
+ {
+ setDistribution(min, max);
+ }
+
+ INT getRandomInteger()
+ {
+ return _unidis(_gen);
+ }
+
+ INT getRandomInteger(INT min, INT max)
+ {
+ typename std::uniform_int_distribution<INT>::param_type range(min, max);
+ return _unidis(_gen, range);
+ }
+
+private:
+ void setDistribution(INT min, INT max)
+ {
+ _unidis = std::uniform_int_distribution<INT>(min, max);
+ }
+
+ std::mt19937 _gen;
+ std::uniform_int_distribution<INT> _unidis;
+};
+
+template <typename DataType>
+bool shuffleINTbyRow(const TosaReference::GenerateConfig& cfg, DataType* data, size_t size)
+{
+ const TosaReference::PseudoRandomInfo& prinfo = cfg.pseudoRandomInfo;
+ PseudoRandomGeneratorInteger<DataType>* generator;
+
+ if (cfg.shape.size() != 2)
+ {
+ WARNING("[Generator][PR][INT] Shuffle only supports 2 dimensional tensors.");
+ return false;
+ }
+ if (prinfo.range.size() != 2)
+ {
+ WARNING("[Generator][PR][INT] Cannot create un-ranged shuffle data.");
+ return false;
+ }
+
+ const int32_t min = std::stoi(prinfo.range[0]);
+ const int32_t max = std::stoi(prinfo.range[1]);
+ generator = new PseudoRandomGeneratorInteger<DataType>(prinfo.rngSeed, min, max);
+
+ // Work out inclusive range
+ const auto range = std::abs(max - min) + 1;
+ const auto N = cfg.shape[0]; // Number of rows
+ const auto W = cfg.shape[1]; // Width of rows
+ if (W > range)
+ {
+ WARNING("[Generator][PR][INT] Cannot fill data size %d with given shuffle range %d.", W, range);
+ return false;
+ }
+
+ std::vector<DataType> numbers(range);
+ for (int n = 0; n < N; ++n)
+ {
+ // Fill in the numbers in range
+ std::iota(numbers.begin(), numbers.end(), min);
+
+ // Perform random shuffling
+ for (auto num = numbers.begin(); num < numbers.end(); ++num)
+ {
+ std::swap(*num, numbers[generator->getRandomInteger()]);
+ }
+ // Copy amount of data required
+ for (auto w = 0; w < W; ++w)
+ {
+ data[(n * W) + w] = numbers[w];
+ }
+ }
+ return true;
+}
+
+template <typename DataType>
+bool generateINT(const TosaReference::GenerateConfig& cfg, DataType* data, size_t size)
+{
+ const TosaReference::PseudoRandomInfo& prinfo = cfg.pseudoRandomInfo;
+ PseudoRandomGeneratorInteger<DataType>* generator;
+
+ const auto T = TosaReference::numElementsFromShape(cfg.shape);
+
+ if (prinfo.range.size() == 2)
+ {
+ const int32_t min = std::stoi(prinfo.range[0]);
+ const int32_t max = std::stoi(prinfo.range[1]);
+ generator = new PseudoRandomGeneratorInteger<DataType>(prinfo.rngSeed, min, max);
+ }
+ else
+ {
+ generator = new PseudoRandomGeneratorInteger<DataType>(prinfo.rngSeed);
+ }
+
+ for (auto t = 0; t < T; ++t)
+ {
+ data[t] = generator->getRandomInteger();
+ }
+ return true;
+}
} // namespace
namespace TosaReference
@@ -155,6 +273,18 @@ bool generatePseudoRandom(const GenerateConfig& cfg, void* data, size_t size)
half_float::half* outData = reinterpret_cast<half_float::half*>(data);
return generateFP(cfg, outData, size);
}
+ case DType::DType_INT32: {
+ int32_t* outData = reinterpret_cast<int32_t*>(data);
+ if (cfg.opType == Op::Op_SCATTER && cfg.inputPos == 1)
+ {
+ // Indices for SCATTER must not repeat - perform data shuffle
+ return shuffleINTbyRow(cfg, outData, size);
+ }
+ else
+ {
+ return generateINT(cfg, outData, size);
+ }
+ }
default:
WARNING("[Generator][PR] Unsupported type.");
return false;