diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-10-26 13:53:14 +0100 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-11-02 23:22:09 +0000 |
commit | a4d907e8686791dd84ed987d0d79325c4d908b73 (patch) | |
tree | 9748ef39183b7548a9ff50d457920eace3a6fdec /reference_model/src/generate/generate_pseudo_random.cc | |
parent | d1a08ce27ef8d0f6cf77e1b864610aade06edc5c (diff) | |
download | reference_model-a4d907e8686791dd84ed987d0d79325c4d908b73.tar.gz |
Main compliance testing support for MUL
Update verify ULP mode to allow fractions (e.g. 0.5).
Update pseudo generator to accept ranges.
Fix up pseudo random distribution based on ranges.
Change-Id: I9168c5f7d37722678c0f1f9e906953c8cec367b1
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Diffstat (limited to 'reference_model/src/generate/generate_pseudo_random.cc')
-rw-r--r-- | reference_model/src/generate/generate_pseudo_random.cc | 62 |
1 files changed, 51 insertions, 11 deletions
diff --git a/reference_model/src/generate/generate_pseudo_random.cc b/reference_model/src/generate/generate_pseudo_random.cc index 858a4b2..f234796 100644 --- a/reference_model/src/generate/generate_pseudo_random.cc +++ b/reference_model/src/generate/generate_pseudo_random.cc @@ -40,40 +40,76 @@ public: constexpr auto min = std::numeric_limits<FP>::lowest() / 2; constexpr auto max = std::numeric_limits<FP>::max() / 2; static_assert(max <= std::numeric_limits<FP>::max() + min); - _unidis = std::uniform_real_distribution<FP>(min, max); - // Piecewise Constant distribution - const std::array<double, 7> intervals{ min, min + 1000, -1000.0, 0.0, 1000.0, max - 1000, max }; - const std::array<double, 7> weights{ 1.0, 0.1, 1.0, 2.0, 1.0, 0.1, 1.0 }; - _pwcdis = std::piecewise_constant_distribution<FP>(intervals.begin(), intervals.end(), weights.begin()); + setDistribution(min, max); } - FP getRandomUniformFloat() + PseudoRandomGeneratorFloat(uint64_t seed, FP min, FP max) + : _gen(seed) { - return _unidis(_gen); + setDistribution(min, max); } - FP getRandomPWCFloat() + FP getRandomFloat() { - return _pwcdis(_gen); + if (_useUniform) + return _unidis(_gen); + else + return _pwcdis(_gen); } private: + void setDistribution(FP min, FP max) + { + _unidis = std::uniform_real_distribution<FP>(min, max); + + // Piecewise Constant distribution for larger ranges + double range = std::abs(max - min); + double mid; + if (max == -min) + mid = 0.f; + else + mid = (range / 2) + min; + double segment = std::min<double>(1000.0, range / 5); + + const std::array<double, 7> intervals{ + min, min + segment, mid - segment, mid, mid + segment, max - segment, max + }; + const std::array<double, 7> weights{ 1.0, 0.1, 1.0, 2.0, 1.0, 0.1, 1.0 }; + _pwcdis = std::piecewise_constant_distribution<FP>(intervals.begin(), intervals.end(), weights.begin()); + + // Uniform distribution works well on smaller ranges + _useUniform = (range < 2000.0); + } + std::mt19937 _gen; std::uniform_real_distribution<FP> _unidis; std::piecewise_constant_distribution<FP> _pwcdis; + bool _useUniform; }; bool generateFP32(const TosaReference::GenerateConfig& cfg, void* data, size_t size) { const TosaReference::PseudoRandomInfo& prinfo = cfg.pseudoRandomInfo; - PseudoRandomGeneratorFloat<float> generator(prinfo.rngSeed); + + PseudoRandomGeneratorFloat<float>* generator; + + if (prinfo.range.size() == 2) + { + const float min = std::stof(prinfo.range[0]); + const float max = std::stof(prinfo.range[1]); + generator = new PseudoRandomGeneratorFloat<float>(prinfo.rngSeed, min, max); + } + else + { + generator = new PseudoRandomGeneratorFloat<float>(prinfo.rngSeed); + } float* a = reinterpret_cast<float*>(data); const auto T = TosaReference::numElementsFromShape(cfg.shape); for (auto t = 0; t < T; ++t) { - a[t] = generator.getRandomPWCFloat(); + a[t] = generator->getRandomFloat(); } return true; } @@ -90,6 +126,10 @@ bool generatePseudoRandom(const GenerateConfig& cfg, void* data, size_t size) WARNING("[Generator][PR] Unknown operator."); return false; } + if (cfg.pseudoRandomInfo.range.size() != 0 || cfg.pseudoRandomInfo.range.size() != 2) + { + WARNING("[Generator][PR] Invalid range."); + } switch (cfg.dataType) { |