From 59b307d5b1090680d1918745ee54c8466df4861d Mon Sep 17 00:00:00 2001
From: Jeremy Johnson
Date: Wed, 4 Oct 2023 14:17:26 +0100
Subject: Expand TOSA MI generator support for MATMUL

Fixed PrimitiveGenerator starting point and added test sets 1-5.
Fixed verify_test reduce_product missing data_type.

Signed-off-by: Jeremy Johnson
Change-Id: Iaf080ce5c1adb5819f70d1a285d04baa36016092
---
 .../src/generate/generate_dot_product_states.cc | 231 +++++++++++++++++++--
 reference_model/src/generate/generate_utils.cc  |   2 +
 2 files changed, 217 insertions(+), 16 deletions(-)

(limited to 'reference_model/src/generate')

diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc
index cd9ffba..d3eeb6d 100644
--- a/reference_model/src/generate/generate_dot_product_states.cc
+++ b/reference_model/src/generate/generate_dot_product_states.cc
@@ -15,6 +15,7 @@
 #include "generate_dot_product.h"
 #include "generate_utils.h"
 
+#include
 #include
 
 namespace
@@ -23,6 +24,7 @@ namespace
 // Input index global variables
 inline constexpr uint32_t P0 = 0;
 inline constexpr uint32_t P1 = 1;
+inline constexpr uint32_t P2 = 2;
 
 // Unused helper function
 template
@@ -48,10 +50,12 @@ public:
 
     [[nodiscard]] float operator()()
     {
-        _r = _r * _m + 1;
         float sign = (_r >> 31) == 0 ? +1 : -1;
         float pseudo = sign * (float)(_r & 0x7FFFFFFF) / (float)(0x7FFFFFFF);
+
+        // Move index and calculate r value for the next index
         ++_index;
+        _r = _r * _m + 1;
 
         return pseudo;
     }
@@ -69,35 +73,210 @@ private:
 };
 
 //----------------------------------------------------------------------------//
-// State generators
+// State generators - equivalent to tosa_mi_data() in the TOSA specification
+//
+// Each call to the generator returns the next generated value with an
+// auto incrementing index
 //----------------------------------------------------------------------------//
 
-// S0 generator
+// Test set 0 generator
+// The aim of this generator is to check that sum of products with zero gives zero result.
 class GeneratorS0 : public TosaReference::IDotProductGenerator
 {
 public:
     GeneratorS0(uint32_t p)
         : _p(p)
-        , _s0(0)    // set_data(2*S)
-        , _s1(1)    // set_data(2*S+1)
+        , _set_data0(2 * 0)
+        , _set_data1(2 * 0 + 1)
     {}
     float operator()(uint32_t k) override
     {
         unused(k);
-        const float s0 = _s0();
-        const float s1 = _s1();
+        const float s0 = _set_data0();
+        const float s1 = _set_data1();
         if (_p == P0)
             return s0 < 0.f ? 0.f : s1;
-        else
+        else if (_p == P1)
             return s0 < 0.f ? s1 : 0.f;
+        else
+            return 0.f;
+    }
+
+private:
+    uint32_t _p;
+    PrimitiveGenerator _set_data0;
+    PrimitiveGenerator _set_data1;
+};
+
+// Test set 1 generator
+// The aim of this test set is to check values with large exponents.
+class GeneratorS1 : public TosaReference::IDotProductGenerator
+{
+public:
+    GeneratorS1(uint32_t p, uint32_t KS, float B)
+        : _p(p)
+        , _KS(KS)
+        , _B(B)
+        , _set_data(3 * 1 + p)
+    {}
+    float operator()(uint32_t k) override
+    {
+        unused(k);
+        const float s = _set_data();
+        float v = 0.75f + 0.25f * s;
+        if (_p != P2)
+            return (_B / std::sqrt(_KS + 1)) * v;
+        else
+            return (_B * _B / (_KS + 1)) * v;
+    }
+
+private:
+    uint32_t _p;
+    uint32_t _KS;
+    float _B;
+    PrimitiveGenerator _set_data;
+};
+
+// Test set 2 generator
+// The aim of this test set is to check rounding error when accumulating small values
+// onto a large value. In this case the small values are of similar magnitude. If the
+// implementation changes the order of the sum, then the test data must also be reordered
+// so that the largest values occur first in the sum.
+class GeneratorS2 : public TosaReference::IDotProductGenerator
+{
+public:
+    GeneratorS2(uint32_t p, uint32_t KS)
+        : _p(p)
+        , _KS(KS)
+        , _set_data(2 * 2 + p)
+    {}
+    float operator()(uint32_t k) override
+    {
+        const float s = _set_data();
+        if (_p != P2)
+            return k == 0 ? 1.f : s / std::sqrt(_KS);
+        else
+            return 0.f;
     }
 
 private:
     uint32_t _p;
-    PrimitiveGenerator _s0;
-    PrimitiveGenerator _s1;
+    uint32_t _KS;
+    PrimitiveGenerator _set_data;
 };
 
+// Test set 3 generator
+// The aim of this test set is to check rounding error when accumulating small values
+// onto a large value. In this case the small values are of varying magnitude. If the
+// implementation changes the order of the sum, then the test data must also be reordered
+// so that the largest values occur first in the sum.
+class GeneratorS3 : public TosaReference::IDotProductGenerator
+{
+public:
+    GeneratorS3(uint32_t p)
+        : _p(p)
+        , _set_data(2 * 3 + p)
+    {}
+    float operator()(uint32_t k) override
+    {
+        const float s0 = _set_data();
+        const float s1 = _set_data();
+        if (_p != P2)
+            return k == 0 ? 16.f : std::exp(2 * s0) * s1;
+        else
+            return 0.f;
+    }
+
+private:
+    uint32_t _p;
+    PrimitiveGenerator _set_data;
+};
+
+// Test set 4 generator
+// The aim of this test set is to check a mixture of zero and non-zero products.
+class GeneratorS4 : public TosaReference::IDotProductGenerator
+{
+public:
+    GeneratorS4(uint32_t p, uint32_t KS, float B)
+        : _p(p)
+        , _KS(KS)
+        , _B(B)
+        , _set_data0(2 * 4 + 0)
+        , _set_data1(2 * 4 + 1)
+    {}
+    float operator()(uint32_t k) override
+    {
+        const float s0 = _set_data0();
+        const float s1 = _set_data1();
+        if (_p == P0)
+            return (k == _KS / 2) ? +0.5f : s0 < 0 ? 0.f : (_B / std::sqrt(_KS)) * s1;
+        else if (_p == P1)
+            return (k == _KS / 2) ? -0.5f : s0 < 0 ? (_B / std::sqrt(_KS)) * s1 : 0.f;
+        else
+            return 0.f;
+    }
+
+private:
+    uint32_t _p;
+    uint32_t _KS;
+    float _B;
+    PrimitiveGenerator _set_data0;
+    PrimitiveGenerator _set_data1;
+};
+
+// Test set 5 generator
+// The aim of this test set is to check signed inputs of large range.
+class GeneratorS5 : public TosaReference::IDotProductGenerator
+{
+public:
+    GeneratorS5(uint32_t p, uint32_t KS, float B)
+        : _p(p)
+        , _KS(KS)
+        , _B(B)
+        , _set_data(3 * 5 + p)
+    {}
+    float operator()(uint32_t k) override
+    {
+        unused(k);
+        const float s = _set_data();
+        if (_p != P2)
+            return (_B / std::sqrt(_KS + 1)) * s;
+        else
+            return (_B * _B / (_KS + 1)) * s;
+    }
+
+private:
+    uint32_t _p;
+    uint32_t _KS;
+    float _B;
+    PrimitiveGenerator _set_data;
+};
+
+float getBoundParameter(const DType& dataType, const DType& accType)
+{
+    // Work out the bounds parameter value B for the given data and accumulator types
+    // Returns value > 0.f on success
+    float B = 0.f;
+    if (dataType == DType::DType_FP16)
+    {
+        if (accType == DType::DType_FP16)
+            B = 255.875f; // (1<<8) - (1/8);
+        else if (accType == DType::DType_FP32)
+            B = 65504.f; // (1<<16) - (1<<5);
+    }
+    else if (dataType == DType::DType_BF16)
+    {
+        if (accType == DType::DType_FP32)
+            B = 18374686479671623680.f; // (1<<64) - (1<<56)
+    }
+    else if (dataType == DType::DType_FP32)
+    {
+        if (accType == DType::DType_FP32)
+            B = 18446742974197923840.f; // (1<<64) - (1<<40)
+    }
+    return B;
+}
+
 } // namespace
 
 namespace TosaReference
@@ -105,15 +284,35 @@ namespace TosaReference
 
 std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg)
 {
+    // Generators can only support 3 inputs
+    if (cfg.inputPos > 2)
+        return nullptr;
+
     const DotProductInfo& dpinfo = cfg.dotProductInfo;
-    switch (dpinfo.s)
+
+    float B = getBoundParameter(cfg.dataType, dpinfo.accType);
+    if (B > 0.f)
     {
-        case 0:
-            return std::make_unique<GeneratorS0>(cfg.inputPos);
-        default:
-            return nullptr;
+        // Create the generator
+        switch (dpinfo.s)
+        {
+            case 0:
+                return std::make_unique<GeneratorS0>(cfg.inputPos);
+            case 1:
+                return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B);
+            case 2:
+                return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks);
+            case 3:
+                return std::make_unique<GeneratorS3>(cfg.inputPos);
+            case 4:
+                return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B);
+            case 5:
+                return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B);
+            default:
+                return nullptr;
+        }
     }
     return nullptr;
 }
 
-} // namespace TosaReference
\ No newline at end of file
+} // namespace TosaReference
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index c52f051..c32d0fb 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -110,6 +110,8 @@ std::optional parseGenerateConfig(const char* json, const char*
 
 int64_t numElementsFromShape(const std::vector& shape)
 {
+    // Rank 0 shapes have no entries and so this will return 1
+    // Other ranked shapes will return the product of their dimensions
     return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies());
 }
 
-- 
cgit v1.2.1