aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-04 14:17:26 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2023-10-05 14:13:36 +0100
commit59b307d5b1090680d1918745ee54c8466df4861d (patch)
treed11d4e0a6f9dbfc55a632803ac86db5656499fed
parentb20b0c9cb4c85bb9a3c901d5acaf421d84656850 (diff)
downloadreference_model-59b307d5b1090680d1918745ee54c8466df4861d.tar.gz
Expand TOSA MI generator support for MATMUL
Fixed PrimitiveGenerator starting point and added test sets 1-5. Fixed verify_test reduce_product missing data_type. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Iaf080ce5c1adb5819f70d1a285d04baa36016092
-rw-r--r--reference_model/src/generate/generate_dot_product_states.cc231
-rw-r--r--reference_model/src/generate/generate_utils.cc2
-rw-r--r--reference_model/test/generate_tests.cpp195
-rw-r--r--reference_model/test/verify_tests.cpp1
4 files changed, 398 insertions, 31 deletions
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc
index cd9ffba..d3eeb6d 100644
--- a/reference_model/src/generate/generate_dot_product_states.cc
+++ b/reference_model/src/generate/generate_dot_product_states.cc
@@ -15,6 +15,7 @@
#include "generate_dot_product.h"
#include "generate_utils.h"
+#include <cmath>
#include <cstdint>
namespace
@@ -23,6 +24,7 @@ namespace
// Input index global variables
inline constexpr uint32_t P0 = 0;
inline constexpr uint32_t P1 = 1;
+inline constexpr uint32_t P2 = 2;
// Unused helper function
template <typename... Args>
@@ -48,10 +50,12 @@ public:
[[nodiscard]] float operator()()
{
- _r = _r * _m + 1;
float sign = (_r >> 31) == 0 ? +1 : -1;
float pseudo = sign * (float)(_r & 0x7FFFFFFF) / (float)(0x7FFFFFFF);
+
+ // Move index and calculate r value for the next index
++_index;
+ _r = _r * _m + 1;
return pseudo;
}
@@ -69,35 +73,210 @@ private:
};
//----------------------------------------------------------------------------//
-// State generators
+// State generators - equivalent to tosa_mi_data() in the TOSA specification
+//
+// Each call to the generator returns the next generated value with an
+// auto incrementing index
//----------------------------------------------------------------------------//
-// S0 generator
+// Test set 0 generator
+// The aim of this generator is to check that sum of products with zero gives zero result.
class GeneratorS0 : public TosaReference::IDotProductGenerator
{
public:
GeneratorS0(uint32_t p)
: _p(p)
- , _s0(0) // set_data(2*S)
- , _s1(1) // set_data(2*S+1)
+ , _set_data0(2 * 0)
+ , _set_data1(2 * 0 + 1)
{}
float operator()(uint32_t k) override
{
unused(k);
- const float s0 = _s0();
- const float s1 = _s1();
+ const float s0 = _set_data0();
+ const float s1 = _set_data1();
if (_p == P0)
return s0 < 0.f ? 0.f : s1;
- else
+ else if (_p == P1)
return s0 < 0.f ? s1 : 0.f;
+ else
+ return 0.f;
+ }
+
+private:
+ uint32_t _p;
+ PrimitiveGenerator _set_data0;
+ PrimitiveGenerator _set_data1;
+};
+
+// Test set 1 generator
+// The aim of this test set is to check values with large exponents.
+class GeneratorS1 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS1(uint32_t p, uint32_t KS, float B)
+ : _p(p)
+ , _KS(KS)
+ , _B(B)
+ , _set_data(3 * 1 + p)
+ {}
+ float operator()(uint32_t k) override
+ {
+ unused(k);
+ const float s = _set_data();
+ float v = 0.75f + 0.25f * s;
+ if (_p != P2)
+ return (_B / std::sqrt(_KS + 1)) * v;
+ else
+ return (_B * _B / (_KS + 1)) * v;
+ }
+
+private:
+ uint32_t _p;
+ uint32_t _KS;
+ float _B;
+ PrimitiveGenerator _set_data;
+};
+
+// Test set 2 generator
+// The aim of this test set is to check rounding error when accumulating small values
+// onto a large value. In this case the small values are of similar magnitude. If the
+// implementation changes the order of the sum, then the test data must also be reordered
+// so that the largest values occur first in the sum.
+class GeneratorS2 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS2(uint32_t p, uint32_t KS)
+ : _p(p)
+ , _KS(KS)
+ , _set_data(2 * 2 + p)
+ {}
+ float operator()(uint32_t k) override
+ {
+ const float s = _set_data();
+ if (_p != P2)
+ return k == 0 ? 1.f : s / std::sqrt(_KS);
+ else
+ return 0.f;
}
private:
uint32_t _p;
- PrimitiveGenerator _s0;
- PrimitiveGenerator _s1;
+ uint32_t _KS;
+ PrimitiveGenerator _set_data;
};
+// Test set 3 generator
+// The aim of this test set is to check rounding error when accumulating small values
+// onto a large value. In this case the small values are of varying magnitude. If the
+// implementation changes the order of the sum, then the test data must also be reordered
+// so that the largest values occur first in the sum.
+class GeneratorS3 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS3(uint32_t p)
+ : _p(p)
+ , _set_data(2 * 3 + p)
+ {}
+ float operator()(uint32_t k) override
+ {
+ const float s0 = _set_data();
+ const float s1 = _set_data();
+ if (_p != P2)
+ return k == 0 ? 16.f : std::exp(2 * s0) * s1;
+ else
+ return 0.f;
+ }
+
+private:
+ uint32_t _p;
+ PrimitiveGenerator _set_data;
+};
+
+// Test set 4 generator
+// The aim of this test set is to check a mixture of zero and non-zero products.
+class GeneratorS4 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS4(uint32_t p, uint32_t KS, float B)
+ : _p(p)
+ , _KS(KS)
+ , _B(B)
+ , _set_data0(2 * 4 + 0)
+ , _set_data1(2 * 4 + 1)
+ {}
+ float operator()(uint32_t k) override
+ {
+ const float s0 = _set_data0();
+ const float s1 = _set_data1();
+ if (_p == P0)
+ return (k == _KS / 2) ? +0.5f : s0 < 0 ? 0.f : (_B / std::sqrt(_KS)) * s1;
+ else if (_p == P1)
+ return (k == _KS / 2) ? -0.5f : s0 < 0 ? (_B / std::sqrt(_KS)) * s1 : 0.f;
+ else
+ return 0.f;
+ }
+
+private:
+ uint32_t _p;
+ uint32_t _KS;
+ float _B;
+ PrimitiveGenerator _set_data0;
+ PrimitiveGenerator _set_data1;
+};
+
+// Test set 5 generator
+// The aim of this test set is to check signed inputs of large range.
+class GeneratorS5 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS5(uint32_t p, uint32_t KS, float B)
+ : _p(p)
+ , _KS(KS)
+ , _B(B)
+ , _set_data(3 * 5 + p)
+ {}
+ float operator()(uint32_t k) override
+ {
+ unused(k);
+ const float s = _set_data();
+ if (_p != P2)
+ return (_B / std::sqrt(_KS + 1)) * s;
+ else
+ return (_B * _B / (_KS + 1)) * s;
+ }
+
+private:
+ uint32_t _p;
+ uint32_t _KS;
+ float _B;
+ PrimitiveGenerator _set_data;
+};
+
+float getBoundParameter(const DType& dataType, const DType& accType)
+{
+ // Work out the bounds parameter value B for the given data and accumulator types
+ // Returns value > 0.f on success
+ float B = 0.f;
+ if (dataType == DType::DType_FP16)
+ {
+ if (accType == DType::DType_FP16)
+ B = 255.875f; // (1<<8) - (1/8);
+ else if (accType == DType::DType_FP32)
+ B = 65504.f; // (1<<16) - (1<<5);
+ }
+ else if (dataType == DType::DType_BF16)
+ {
+ if (accType == DType::DType_FP32)
+ B = 18374686479671623680.f; // (1<<64) - (1<<56)
+ }
+ else if (dataType == DType::DType_FP32)
+ {
+ if (accType == DType::DType_FP32)
+ B = 18446742974197923840.f; // (1<<64) - (1<<40)
+ }
+ return B;
+}
+
} // namespace
namespace TosaReference
@@ -105,15 +284,35 @@ namespace TosaReference
std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg)
{
+ // Generators can only support 3 inputs
+ if (cfg.inputPos > 2)
+ return nullptr;
+
const DotProductInfo& dpinfo = cfg.dotProductInfo;
- switch (dpinfo.s)
+
+ float B = getBoundParameter(cfg.dataType, dpinfo.accType);
+ if (B > 0.f)
{
- case 0:
- return std::make_unique<GeneratorS0>(cfg.inputPos);
- default:
- return nullptr;
+ // Create the generator
+ switch (dpinfo.s)
+ {
+ case 0:
+ return std::make_unique<GeneratorS0>(cfg.inputPos);
+ case 1:
+ return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B);
+ case 2:
+ return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks);
+ case 3:
+ return std::make_unique<GeneratorS3>(cfg.inputPos);
+ case 4:
+ return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B);
+ case 5:
+ return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B);
+ default:
+ return nullptr;
+ }
}
return nullptr;
}
-} // namespace TosaReference \ No newline at end of file
+} // namespace TosaReference
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index c52f051..c32d0fb 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -110,6 +110,8 @@ std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char*
int64_t numElementsFromShape(const std::vector<int32_t>& shape)
{
+ // Rank 0 shapes have no entries and so this will return 1
+ // Other ranked shapes will return the product of their dimensions
return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int64_t>());
}
diff --git a/reference_model/test/generate_tests.cpp b/reference_model/test/generate_tests.cpp
index 88dc979..bce7092 100644
--- a/reference_model/test/generate_tests.cpp
+++ b/reference_model/test/generate_tests.cpp
@@ -16,9 +16,46 @@
#include <doctest.h>
#include <array>
+#include <iostream>
#include <string>
#include <vector>
+namespace
+{
+template <typename T>
+void debug_vec_print(const std::vector<T>& vec)
+{
+ std::cout << "vector: ";
+ for (auto v = vec.begin(); v != vec.end(); ++v)
+ {
+ T f = *v;
+ std::cout << std::dec << f << " [" << std::hex << *(uint32_t*)&f << "] ";
+ }
+ std::cout << std::dec << '\n';
+}
+
+void update_json_template(std::string& str, const std::string& set)
+{
+ std::string find = "_SET_";
+ auto pos = str.find(find);
+ while (pos != std::string::npos)
+ {
+ str.replace(pos, find.length(), set);
+ pos = str.find(find);
+ }
+}
+
+template <typename T>
+void check_output(const std::vector<T>& results, const std::vector<uint32_t>& expected)
+{
+ for (size_t idx = 0; idx < expected.size(); ++idx)
+ {
+ REQUIRE_MESSAGE(expected[idx] == *(uint32_t*)&results[idx], "index: ", idx);
+ }
+}
+
+} // namespace
+
TEST_SUITE_BEGIN("generate");
TEST_CASE("negative - api")
@@ -34,7 +71,7 @@ TEST_CASE("negative - api")
"op" : "MATMUL",
"dot_product_info": {
"s": 0,
- "ks": 10,
+ "ks": 8,
"acc_type": "FP32"
}
}
@@ -81,35 +118,163 @@ TEST_CASE("negative - api")
TEST_CASE("positive - dot product")
{
- std::string json_cfg = R"({
+ std::string template_json_cfg = R"({
"tensors" : {
"in1" : {
"generator": "DOT_PRODUCT",
"data_type": "FP32",
"input_type": "VARIABLE",
- "shape" : [ 4, 8, 8 ],
+ "shape" : [ 4, 8, 2 ],
"input_pos": 0,
"op" : "MATMUL",
"dot_product_info": {
- "s": 0,
- "ks": 10,
+ "s": _SET_,
+ "ks": 2,
+ "acc_type": "FP32"
+ }
+ },
+ "in2" : {
+ "generator": "DOT_PRODUCT",
+ "data_type": "FP32",
+ "input_type": "VARIABLE",
+ "shape" : [ 4, 2, 5 ],
+ "input_pos": 1,
+ "op" : "MATMUL",
+ "dot_product_info": {
+ "s": _SET_,
+ "ks": 2,
"acc_type": "FP32"
}
}
+
}
})";
- const std::string tosaName = "in1";
- const size_t tosaElements = 4 * 8 * 8;
- const size_t tosaSize = tosaElements * 4;
+ const std::string tosaNameP0 = "in1";
+ const size_t tosaElementsP0 = 4 * 8 * 2;
+ const std::string tosaNameP1 = "in2";
+ const size_t tosaElementsP1 = 4 * 2 * 5;
- SUBCASE("matmul")
+ SUBCASE("matmul, set 0, param 0")
{
- std::vector<float> buffer(tosaElements);
- REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaName.c_str(), (void*)buffer.data(), tosaSize));
- REQUIRE(buffer[0] == (float)-0.950864);
- REQUIRE(buffer[1] == 0.f);
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "0");
+
+ std::vector<uint32_t> expected = { 0xbf665aa4, 0xbf736bd3, 0x0 };
+ std::vector<float> buffer(tosaElementsP0);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4));
+ check_output<float>(buffer, expected);
}
-}
+ SUBCASE("matmul, set 0, param 1")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "0");
-TEST_SUITE_END(); // generate \ No newline at end of file
+ std::vector<uint32_t> expected = { 0x0, 0x0, 0x3f34f2dd };
+ std::vector<float> buffer(tosaElementsP1);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 1, param 0")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "1");
+
+ std::vector<uint32_t> expected = { 0x5e97f1b0, 0x5ea6a18e, 0x5eb811af };
+ std::vector<float> buffer(tosaElementsP0);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 1, param 1")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "1");
+
+ std::vector<uint32_t> expected = { 0x5f128bb1, 0x5ef54579, 0x5ebd65b8 };
+ std::vector<float> buffer(tosaElementsP1);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 2, param 0")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "2");
+
+ std::vector<uint32_t> expected = { 0x3f800000, 0x3e66ed53, 0x3f800000 };
+ std::vector<float> buffer(tosaElementsP0);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 2, param 1")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "2");
+
+ std::vector<uint32_t> expected = { 0x3f800000, 0x3f800000, 0x3f800000 };
+ std::vector<float> buffer(tosaElementsP1);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 3, param 0")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "3");
+
+ // NOTE: Python test script produced 0xbf256686 - so off by 1
+ std::vector<uint32_t> expected = { 0x41800000, 0xbf256685, 0x41800000 };
+ std::vector<float> buffer(tosaElementsP0);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 3, param 1")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "3");
+
+ std::vector<uint32_t> expected = { 0x41800000, 0x41800000, 0x41800000 };
+ std::vector<float> buffer(tosaElementsP1);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 4, param 0")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "4");
+
+ std::vector<uint32_t> expected = { 0x0, 0x3f000000, 0x5f14e80c };
+ std::vector<float> buffer(tosaElementsP0);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 4, param 1")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "4");
+
+ std::vector<uint32_t> expected = { 0x5d5d0db2, 0xdf2c82a8, 0x0 };
+ std::vector<float> buffer(tosaElementsP1);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 5, param 0")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "5");
+
+ std::vector<uint32_t> expected = { 0x5df6c4b3, 0x5e6b4088, 0x5ed0fe71 };
+ std::vector<float> buffer(tosaElementsP0);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP0.c_str(), (void*)buffer.data(), tosaElementsP0 * 4));
+ check_output<float>(buffer, expected);
+ }
+ SUBCASE("matmul, set 5, param 1")
+ {
+ std::string json_cfg = template_json_cfg;
+ update_json_template(json_cfg, "5");
+
+ std::vector<uint32_t> expected = { 0xde086d85, 0x5e630878, 0x5eba5c7b };
+ std::vector<float> buffer(tosaElementsP1);
+ REQUIRE(tgd_generate_data(json_cfg.c_str(), tosaNameP1.c_str(), (void*)buffer.data(), tosaElementsP1 * 4));
+ check_output<float>(buffer, expected);
+ }
+}
+TEST_SUITE_END(); // generate
diff --git a/reference_model/test/verify_tests.cpp b/reference_model/test/verify_tests.cpp
index b75ddec..f36efbf 100644
--- a/reference_model/test/verify_tests.cpp
+++ b/reference_model/test/verify_tests.cpp
@@ -264,6 +264,7 @@ TEST_CASE("positive - reduce product")
"tensors" : {
"out1" : {
"mode": "REDUCE_PRODUCT",
+ "data_type": "FP32",
"reduce_product_info": {
"m": 23,
"n": 8