diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-10-04 14:17:26 +0100 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2023-10-05 14:13:36 +0100 |
commit | 59b307d5b1090680d1918745ee54c8466df4861d (patch) | |
tree | d11d4e0a6f9dbfc55a632803ac86db5656499fed | |
parent | b20b0c9cb4c85bb9a3c901d5acaf421d84656850 (diff) | |
download | reference_model-59b307d5b1090680d1918745ee54c8466df4861d.tar.gz |
Expand TOSA MI generator support for MATMUL
Fixed PrimitiveGenerator starting point and added test sets 1-5.
Fixed verify_test reduce_product missing data_type.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Iaf080ce5c1adb5819f70d1a285d04baa36016092
-rw-r--r-- | reference_model/src/generate/generate_dot_product_states.cc | 231 | ||||
-rw-r--r-- | reference_model/src/generate/generate_utils.cc | 2 | ||||
-rw-r--r-- | reference_model/test/generate_tests.cpp | 195 | ||||
-rw-r--r-- | reference_model/test/verify_tests.cpp | 1 |
4 files changed, 398 insertions, 31 deletions
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc index cd9ffba..d3eeb6d 100644 --- a/reference_model/src/generate/generate_dot_product_states.cc +++ b/reference_model/src/generate/generate_dot_product_states.cc @@ -15,6 +15,7 @@ #include "generate_dot_product.h" #include "generate_utils.h" +#include <cmath> #include <cstdint> namespace @@ -23,6 +24,7 @@ namespace // Input index global variables inline constexpr uint32_t P0 = 0; inline constexpr uint32_t P1 = 1; +inline constexpr uint32_t P2 = 2; // Unused helper function template <typename... Args> @@ -48,10 +50,12 @@ public: [[nodiscard]] float operator()() { - _r = _r * _m + 1; float sign = (_r >> 31) == 0 ? +1 : -1; float pseudo = sign * (float)(_r & 0x7FFFFFFF) / (float)(0x7FFFFFFF); + + // Move index and calculate r value for the next index ++_index; + _r = _r * _m + 1; return pseudo; } @@ -69,35 +73,210 @@ private: }; //----------------------------------------------------------------------------// -// State generators +// State generators - equivalent to tosa_mi_data() in the TOSA specification +// +// Each call to the generator returns the next generated value with an +// auto incrementing index //----------------------------------------------------------------------------// -// S0 generator +// Test set 0 generator +// The aim of this generator is to check that sum of products with zero gives zero result. class GeneratorS0 : public TosaReference::IDotProductGenerator { public: GeneratorS0(uint32_t p) : _p(p) - , _s0(0) // set_data(2*S) - , _s1(1) // set_data(2*S+1) + , _set_data0(2 * 0) + , _set_data1(2 * 0 + 1) {} float operator()(uint32_t k) override { unused(k); - const float s0 = _s0(); - const float s1 = _s1(); + const float s0 = _set_data0(); + const float s1 = _set_data1(); if (_p == P0) return s0 < 0.f ? 0.f : s1; - else + else if (_p == P1) return s0 < 0.f ? s1 : 0.f; + else + return 0.f; + } + +private: + uint32_t _p; + PrimitiveGenerator _set_data0; + PrimitiveGenerator _set_data1; +}; + +// Test set 1 generator +// The aim of this test set is to check values with large exponents. +class GeneratorS1 : public TosaReference::IDotProductGenerator +{ +public: + GeneratorS1(uint32_t p, uint32_t KS, float B) + : _p(p) + , _KS(KS) + , _B(B) + , _set_data(3 * 1 + p) + {} + float operator()(uint32_t k) override + { + unused(k); + const float s = _set_data(); + float v = 0.75f + 0.25f * s; + if (_p != P2) + return (_B / std::sqrt(_KS + 1)) * v; + else + return (_B * _B / (_KS + 1)) * v; + } + +private: + uint32_t _p; + uint32_t _KS; + float _B; + PrimitiveGenerator _set_data; +}; + +// Test set 2 generator +// The aim of this test set is to check rounding error when accumulating small values +// onto a large value. In this case the small values are of similar magnitude. If the +// implementation changes the order of the sum, then the test data must also be reordered +// so that the largest values occur first in the sum. +class GeneratorS2 : public TosaReference::IDotProductGenerator +{ +public: + GeneratorS2(uint32_t p, uint32_t KS) + : _p(p) + , _KS(KS) + , _set_data(2 * 2 + p) + {} + float operator()(uint32_t k) override + { + const float s = _set_data(); + if (_p != P2) + return k == 0 ? 1.f : s / std::sqrt(_KS); + else + return 0.f; } private: uint32_t _p; - PrimitiveGenerator _s0; - PrimitiveGenerator _s1; + uint32_t _KS; + PrimitiveGenerator _set_data; }; +// Test set 3 generator +// The aim of this test set is to check rounding error when accumulating small values +// onto a large value. In this case the small values are of varying magnitude. If the +// implementation changes the order of the sum, then the test data must also be reordered +// so that the largest values occur first in the sum. +class GeneratorS3 : public TosaReference::IDotProductGenerator +{ +public: + GeneratorS3(uint32_t p) + : _p(p) + , _set_data(2 * 3 + p) + {} + float operator()(uint32_t k) override + { + const float s0 = _set_data(); + const float s1 = _set_data(); + if (_p != P2) + return k == 0 ? 16.f : std::exp(2 * s0) * s1; + else + return 0.f; + } + +private: + uint32_t _p; + PrimitiveGenerator _set_data; +}; + +// Test set 4 generator +// The aim of this test set is to check a mixture of zero and non-zero products. +class GeneratorS4 : public TosaReference::IDotProductGenerator +{ +public: + GeneratorS4(uint32_t p, uint32_t KS, float B) + : _p(p) + , _KS(KS) + , _B(B) + , _set_data0(2 * 4 + 0) + , _set_data1(2 * 4 + 1) + {} + float operator()(uint32_t k) override + { + const float s0 = _set_data0(); + const float s1 = _set_data1(); + if (_p == P0) + return (k == _KS / 2) ? +0.5f : s0 < 0 ? 0.f : (_B / std::sqrt(_KS)) * s1; + else if (_p == P1) + return (k == _KS / 2) ? -0.5f : s0 < 0 ? (_B / std::sqrt(_KS)) * s1 : 0.f; + else + return 0.f; + } + +private: + uint32_t _p; + uint32_t _KS; + float _B; + PrimitiveGenerator _set_data0; + PrimitiveGenerator _set_data1; +}; + +// Test set 5 generator +// The aim of this test set is to check signed inputs of large range. +class GeneratorS5 : public TosaReference::IDotProductGenerator +{ +public: + GeneratorS5(uint32_t p, uint32_t KS, float B) + : _p(p) + , _KS(KS) + , _B(B) + , _set_data(3 * 5 + p) + {} + float operator()(uint32_t k) override + { + unused(k); + const float s = _set_data(); + if (_p != P2) + return (_B / std::sqrt(_KS + 1)) * s; + else + return (_B * _B / (_KS + 1)) * s; + } + +private: + uint32_t _p; + uint32_t _KS; + float _B; + PrimitiveGenerator _set_data; +}; + +float getBoundParameter(const DType& dataType, const DType& accType) +{ + // Work out the bounds parameter value B for the given data and accumulator types + // Returns value > 0.f on success + float B = 0.f; + if (dataType == DType::DType_FP16) + { + if (accType == DType::DType_FP16) + B = 255.875f; // (1<<8) - (1/8); + else if (accType == DType::DType_FP32) + B = 65504.f; // (1<<16) - (1<<5); + } + else if (dataType == DType::DType_BF16) + { + if (accType == DType::DType_FP32) + B = 18374686479671623680.f; // (1<<64) - (1<<56) + } + else if (dataType == DType::DType_FP32) + { + if (accType == DType::DType_FP32) + B = 18446742974197923840.f; // (1<<64) - (1<<40) + } + return B; +} + } // namespace namespace TosaReference @@ -105,15 +284,35 @@ namespace TosaReference std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg) { + // Generators can only support 3 inputs + if (cfg.inputPos > 2) + return nullptr; + const DotProductInfo& dpinfo = cfg.dotProductInfo; - switch (dpinfo.s) + + float B = getBoundParameter(cfg.dataType, dpinfo.accType); + if (B > 0.f) { - case 0: - return std::make_unique<GeneratorS0>(cfg.inputPos); - default: - return nullptr; + // Create the generator + switch (dpinfo.s) + { + case 0: + return std::make_unique<GeneratorS0>(cfg.inputPos); + case 1: + return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B); + case 2: + return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks); + case 3: + return std::make_unique<GeneratorS3>(cfg.inputPos); + case 4: + return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B); + case 5: + return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B); + default: + return nullptr; + } } return nullptr; } -} // namespace TosaReference
\ No newline at end of file +} // namespace TosaReference diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index c52f051..c32d0fb 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -110,6 +110,8 @@ std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char* int64_t numElementsFromShape(const std::vector<int32_t>& shape) { + // Rank 0 shapes have no entries and so this will return 1 + // Other ranked shapes will return the product of their dimensions return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int64_t>()); } diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp index 88dc979..bce7092 100644 --- a/reference_model/test/generate_tests.cpp +++ b/reference_model/test/generate_tests.cpp @@ -16,9 +16,46 @@ #include <doctest.h> #include <array> +#include <iostream> #include <string> #include <vector> +namespace +{ +template <typename T> +void debug_vec_print(const std::vector<T>& vec) +{ + std::cout << "vector: "; + for (auto v = vec.begin(); v != vec.end(); ++v) + { + T f = *v; + std::cout << std::dec << f << " [" << std::hex << *(uint32_t*)&f << "] "; + } + std::cout << std::dec << '\n'; +} + +void update_json_template(std::string& str, const std::string& set) +{ + std::string find = "_SET_"; + auto pos = str.find(find); + while (pos != std::string::npos) + { + str.replace(pos, find.length(), set); + pos = str.find(find); + } +} + +template <typename T> +void check_output(const std::vector<T>& results, const std::vector<uint32_t>& expected) +{ + for (size_t idx = 0; idx < expected.size(); ++idx) + { + REQUIRE_MESSAGE(expected[idx] == *(uint32_t*)&results[idx], "index: ", idx); + } +} + +} // namespace + TEST_SUITE_BEGIN("generate"); TEST_CASE("negative - api") @@ -34,7 +71,7 @@ TEST_CASE("negative - api") "op" : "MATMUL", "dot_product_info": { "s": 0, - "ks": 10, + "ks": 8, "acc_type": "FP32" } } @@ -81,35 +118,163 @@ TEST_CASE("negative - api") TEST_CASE("positive - dot product") { - std::string json_cfg = R"({ + std::string template_json_cfg = R"({ "tensors" : { "in1" : { "generator": "DOT_PRODUCT", "data_type": "FP32", "input_type": "VARIABLE", - "shape" : [ 4, 8, 8 ], + "shape" : [ 4, 8, 2 ], "input_pos": 0, "op" : "MATMUL", "dot_product_info": { - "s": 0, - "ks": 10, + "s": _SET_, + "ks": 2, + "acc_type": "FP32" + } + }, + "in2" : { + "generator": "DOT_PRODUCT", + "data_type": "FP32", + "input_type": "VARIABLE", + "shape" : [ 4, 2, 5 ], + "input_pos": 1, + "op" : "MATMUL", + "dot_product_info": { + "s": _SET_, + "ks": 2, "acc_type": "FP32" } } + } })"; - const std::string tosaName = "in1"; - const size_t tosaElements = 4 * 8 * 8; - const size_t tosaSize = tosaElements * 4; + const std::string tosaNameP0 = "in1"; + const size_t tosaElementsP0 = 4 * 8 * 2; + const std::string tosaNameP1 = "in2"; + const size_t tosaElementsP1 = 4 * 2 * 5; - SUBCASE("matmul") + SUBCASE("matmul, set 0, param 0") { - std::vector<float> buffer(tosaElements); - REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize)); - REQUIRE(buffer[0] == (float)-0.950864); - REQUIRE(buffer[1] == 0.f); + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "0"); + + std::vector<uint32_t> expected = { 0xbf665aa4, 0xbf736bd3, 0x0 }; + std::vector<float> buffer(tosaElementsP0); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4)); + check_output<float>(buffer, expected); } -} + SUBCASE("matmul, set 0, param 1") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "0"); -TEST_SUITE_END(); // generate
\ No newline at end of file + std::vector<uint32_t> expected = { 0x0, 0x0, 0x3f34f2dd }; + std::vector<float> buffer(tosaElementsP1); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 1, param 0") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "1"); + + std::vector<uint32_t> expected = { 0x5e97f1b0, 0x5ea6a18e, 0x5eb811af }; + std::vector<float> buffer(tosaElementsP0); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 1, param 1") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "1"); + + std::vector<uint32_t> expected = { 0x5f128bb1, 0x5ef54579, 0x5ebd65b8 }; + std::vector<float> buffer(tosaElementsP1); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 2, param 0") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "2"); + + std::vector<uint32_t> expected = { 0x3f800000, 0x3e66ed53, 0x3f800000 }; + std::vector<float> buffer(tosaElementsP0); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 2, param 1") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "2"); + + std::vector<uint32_t> expected = { 0x3f800000, 0x3f800000, 0x3f800000 }; + std::vector<float> buffer(tosaElementsP1); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 3, param 0") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "3"); + + // NOTE: Python test script produced 0xbf256686 - so off by 1 + std::vector<uint32_t> expected = { 0x41800000, 0xbf256685, 0x41800000 }; + std::vector<float> buffer(tosaElementsP0); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 3, param 1") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "3"); + + std::vector<uint32_t> expected = { 0x41800000, 0x41800000, 0x41800000 }; + std::vector<float> buffer(tosaElementsP1); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 4, param 0") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "4"); + + std::vector<uint32_t> expected = { 0x0, 0x3f000000, 0x5f14e80c }; + std::vector<float> buffer(tosaElementsP0); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 4, param 1") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "4"); + + std::vector<uint32_t> expected = { 0x5d5d0db2, 0xdf2c82a8, 0x0 }; + std::vector<float> buffer(tosaElementsP1); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 5, param 0") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "5"); + + std::vector<uint32_t> expected = { 0x5df6c4b3, 0x5e6b4088, 0x5ed0fe71 }; + std::vector<float> buffer(tosaElementsP0); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4)); + check_output<float>(buffer, expected); + } + SUBCASE("matmul, set 5, param 1") + { + std::string json_cfg = template_json_cfg; + update_json_template(json_cfg, "5"); + + std::vector<uint32_t> expected = { 0xde086d85, 0x5e630878, 0x5eba5c7b }; + std::vector<float> buffer(tosaElementsP1); + REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4)); + check_output<float>(buffer, expected); + } +} +TEST_SUITE_END(); // generate diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp index b75ddec..f36efbf 100644 --- a/reference_model/test/verify_tests.cpp +++ b/reference_model/test/verify_tests.cpp @@ -264,6 +264,7 @@ TEST_CASE("positive - reduce product") "tensors" : { "out1" : { "mode": "REDUCE_PRODUCT", + "data_type": "FP32", "reduce_product_info": { "m": 23, "n": 8 |