aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2023-10-04 14:17:26 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2023-10-05 14:13:36 +0100
commit59b307d5b1090680d1918745ee54c8466df4861d (patch)
treed11d4e0a6f9dbfc55a632803ac86db5656499fed /reference_model/src/generate
parentb20b0c9cb4c85bb9a3c901d5acaf421d84656850 (diff)
downloadreference_model-59b307d5b1090680d1918745ee54c8466df4861d.tar.gz
Expand TOSA MI generator support for MATMUL
Fixed PrimitiveGenerator starting point and added test sets 1-5. Fixed verify_test reduce_product missing data_type. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Iaf080ce5c1adb5819f70d1a285d04baa36016092
Diffstat (limited to 'reference_model/src/generate')
-rw-r--r--reference_model/src/generate/generate_dot_product_states.cc231
-rw-r--r--reference_model/src/generate/generate_utils.cc2
2 files changed, 217 insertions, 16 deletions
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc
index cd9ffba..d3eeb6d 100644
--- a/reference_model/src/generate/generate_dot_product_states.cc
+++ b/reference_model/src/generate/generate_dot_product_states.cc
@@ -15,6 +15,7 @@
#include "generate_dot_product.h"
#include "generate_utils.h"
+#include <cmath>
#include <cstdint>
namespace
@@ -23,6 +24,7 @@ namespace
// Input index global variables
inline constexpr uint32_t P0 = 0;
inline constexpr uint32_t P1 = 1;
+inline constexpr uint32_t P2 = 2;
// Unused helper function
template <typename... Args>
@@ -48,10 +50,12 @@ public:
[[nodiscard]] float operator()()
{
- _r = _r * _m + 1;
float sign = (_r >> 31) == 0 ? +1 : -1;
float pseudo = sign * (float)(_r & 0x7FFFFFFF) / (float)(0x7FFFFFFF);
+
+ // Move index and calculate r value for the next index
++_index;
+ _r = _r * _m + 1;
return pseudo;
}
@@ -69,35 +73,210 @@ private:
};
//----------------------------------------------------------------------------//
-// State generators
+// State generators - equivalent to tosa_mi_data() in the TOSA specification
+//
+// Each call to the generator returns the next generated value with an
+// auto incrementing index
//----------------------------------------------------------------------------//
-// S0 generator
+// Test set 0 generator
+// The aim of this generator is to check that sum of products with zero gives zero result.
class GeneratorS0 : public TosaReference::IDotProductGenerator
{
public:
GeneratorS0(uint32_t p)
: _p(p)
- , _s0(0) // set_data(2*S)
- , _s1(1) // set_data(2*S+1)
+ , _set_data0(2 * 0)
+ , _set_data1(2 * 0 + 1)
{}
float operator()(uint32_t k) override
{
unused(k);
- const float s0 = _s0();
- const float s1 = _s1();
+ const float s0 = _set_data0();
+ const float s1 = _set_data1();
if (_p == P0)
return s0 < 0.f ? 0.f : s1;
- else
+ else if (_p == P1)
return s0 < 0.f ? s1 : 0.f;
+ else
+ return 0.f;
+ }
+
+private:
+ uint32_t _p;
+ PrimitiveGenerator _set_data0;
+ PrimitiveGenerator _set_data1;
+};
+
+// Test set 1 generator
+// The aim of this test set is to check values with large exponents.
+class GeneratorS1 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS1(uint32_t p, uint32_t KS, float B)
+ : _p(p)
+ , _KS(KS)
+ , _B(B)
+ , _set_data(3 * 1 + p)
+ {}
+ float operator()(uint32_t k) override
+ {
+ unused(k);
+ const float s = _set_data();
+ float v = 0.75f + 0.25f * s;
+ if (_p != P2)
+ return (_B / std::sqrt(_KS + 1)) * v;
+ else
+ return (_B * _B / (_KS + 1)) * v;
+ }
+
+private:
+ uint32_t _p;
+ uint32_t _KS;
+ float _B;
+ PrimitiveGenerator _set_data;
+};
+
+// Test set 2 generator
+// The aim of this test set is to check rounding error when accumulating small values
+// onto a large value. In this case the small values are of similar magnitude. If the
+// implementation changes the order of the sum, then the test data must also be reordered
+// so that the largest values occur first in the sum.
+class GeneratorS2 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS2(uint32_t p, uint32_t KS)
+ : _p(p)
+ , _KS(KS)
+ , _set_data(2 * 2 + p)
+ {}
+ float operator()(uint32_t k) override
+ {
+ const float s = _set_data();
+ if (_p != P2)
+ return k == 0 ? 1.f : s / std::sqrt(_KS);
+ else
+ return 0.f;
}
private:
uint32_t _p;
- PrimitiveGenerator _s0;
- PrimitiveGenerator _s1;
+ uint32_t _KS;
+ PrimitiveGenerator _set_data;
};
+// Test set 3 generator
+// The aim of this test set is to check rounding error when accumulating small values
+// onto a large value. In this case the small values are of varying magnitude. If the
+// implementation changes the order of the sum, then the test data must also be reordered
+// so that the largest values occur first in the sum.
+class GeneratorS3 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS3(uint32_t p)
+ : _p(p)
+ , _set_data(2 * 3 + p)
+ {}
+ float operator()(uint32_t k) override
+ {
+ const float s0 = _set_data();
+ const float s1 = _set_data();
+ if (_p != P2)
+ return k == 0 ? 16.f : std::exp(2 * s0) * s1;
+ else
+ return 0.f;
+ }
+
+private:
+ uint32_t _p;
+ PrimitiveGenerator _set_data;
+};
+
+// Test set 4 generator
+// The aim of this test set is to check a mixture of zero and non-zero products.
+class GeneratorS4 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS4(uint32_t p, uint32_t KS, float B)
+ : _p(p)
+ , _KS(KS)
+ , _B(B)
+ , _set_data0(2 * 4 + 0)
+ , _set_data1(2 * 4 + 1)
+ {}
+ float operator()(uint32_t k) override
+ {
+ const float s0 = _set_data0();
+ const float s1 = _set_data1();
+ if (_p == P0)
+ return (k == _KS / 2) ? +0.5f : s0 < 0 ? 0.f : (_B / std::sqrt(_KS)) * s1;
+ else if (_p == P1)
+ return (k == _KS / 2) ? -0.5f : s0 < 0 ? (_B / std::sqrt(_KS)) * s1 : 0.f;
+ else
+ return 0.f;
+ }
+
+private:
+ uint32_t _p;
+ uint32_t _KS;
+ float _B;
+ PrimitiveGenerator _set_data0;
+ PrimitiveGenerator _set_data1;
+};
+
+// Test set 5 generator
+// The aim of this test set is to check signed inputs of large range.
+class GeneratorS5 : public TosaReference::IDotProductGenerator
+{
+public:
+ GeneratorS5(uint32_t p, uint32_t KS, float B)
+ : _p(p)
+ , _KS(KS)
+ , _B(B)
+ , _set_data(3 * 5 + p)
+ {}
+ float operator()(uint32_t k) override
+ {
+ unused(k);
+ const float s = _set_data();
+ if (_p != P2)
+ return (_B / std::sqrt(_KS + 1)) * s;
+ else
+ return (_B * _B / (_KS + 1)) * s;
+ }
+
+private:
+ uint32_t _p;
+ uint32_t _KS;
+ float _B;
+ PrimitiveGenerator _set_data;
+};
+
+float getBoundParameter(const DType& dataType, const DType& accType)
+{
+ // Work out the bounds parameter value B for the given data and accumulator types
+ // Returns value > 0.f on success
+ float B = 0.f;
+ if (dataType == DType::DType_FP16)
+ {
+ if (accType == DType::DType_FP16)
+ B = 255.875f; // (1<<8) - (1/8);
+ else if (accType == DType::DType_FP32)
+ B = 65504.f; // (1<<16) - (1<<5);
+ }
+ else if (dataType == DType::DType_BF16)
+ {
+ if (accType == DType::DType_FP32)
+ B = 18374686479671623680.f; // (1<<64) - (1<<56)
+ }
+ else if (dataType == DType::DType_FP32)
+ {
+ if (accType == DType::DType_FP32)
+ B = 18446742974197923840.f; // (1<<64) - (1<<40)
+ }
+ return B;
+}
+
} // namespace
namespace TosaReference
@@ -105,15 +284,35 @@ namespace TosaReference
std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConfig& cfg)
{
+ // Generators can only support 3 inputs
+ if (cfg.inputPos > 2)
+ return nullptr;
+
const DotProductInfo& dpinfo = cfg.dotProductInfo;
- switch (dpinfo.s)
+
+ float B = getBoundParameter(cfg.dataType, dpinfo.accType);
+ if (B > 0.f)
{
- case 0:
- return std::make_unique<GeneratorS0>(cfg.inputPos);
- default:
- return nullptr;
+ // Create the generator
+ switch (dpinfo.s)
+ {
+ case 0:
+ return std::make_unique<GeneratorS0>(cfg.inputPos);
+ case 1:
+ return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B);
+ case 2:
+ return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks);
+ case 3:
+ return std::make_unique<GeneratorS3>(cfg.inputPos);
+ case 4:
+ return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B);
+ case 5:
+ return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B);
+ default:
+ return nullptr;
+ }
}
return nullptr;
}
-} // namespace TosaReference \ No newline at end of file
+} // namespace TosaReference
diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc
index c52f051..c32d0fb 100644
--- a/reference_model/src/generate/generate_utils.cc
+++ b/reference_model/src/generate/generate_utils.cc
@@ -110,6 +110,8 @@ std::optional<GenerateConfig> parseGenerateConfig(const char* json, const char*
int64_t numElementsFromShape(const std::vector<int32_t>& shape)
{
+ // Rank 0 shapes have no entries and so this will return 1
+ // Other ranked shapes will return the product of their dimensions
return std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<int64_t>());
}