aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/generate/generate_dot_product_states.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/generate/generate_dot_product_states.cc')
-rw-r--r--reference_model/src/generate/generate_dot_product_states.cc48
1 files changed, 40 insertions, 8 deletions
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc
index 9ce32ff..b78be71 100644
--- a/reference_model/src/generate/generate_dot_product_states.cc
+++ b/reference_model/src/generate/generate_dot_product_states.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2023, ARM Limited.
+// Copyright (c) 2023-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -60,7 +60,7 @@ public:
return pseudo;
}
- uint32_t index()
+ uint32_t nextIndex()
{
return _index;
}
@@ -101,6 +101,11 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0")
+ return _set_data0.nextIndex();
+ }
private:
uint32_t _p;
@@ -129,6 +134,10 @@ public:
else
return (_B * _B / (_KS + 1)) * v;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -158,6 +167,10 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -186,6 +199,10 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -229,6 +246,11 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4")
+ return _set_data0.nextIndex();
+ }
private:
uint32_t _p;
@@ -258,6 +280,10 @@ public:
else
return 0.f;
}
+ uint32_t nextIndex()
+ {
+ return _set_data.nextIndex();
+ }
private:
uint32_t _p;
@@ -307,21 +333,27 @@ std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConf
float B = getBoundParameter(cfg.dataType, dpinfo.accType);
if (B > 0.f)
{
+ auto param = cfg.inputPos;
+ if (cfg.opType == Op_FFT2D)
+ {
+ // We only use param of zero for FFT2D tensors
+ param = 0;
+ }
// Create the generator
switch (dpinfo.s)
{
case 0:
- return std::make_unique<GeneratorS0>(cfg.inputPos);
+ return std::make_unique<GeneratorS0>(param);
case 1:
- return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B);
+ return std::make_unique<GeneratorS1>(param, dpinfo.ks, B);
case 2:
- return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks);
+ return std::make_unique<GeneratorS2>(param, dpinfo.ks);
case 3:
- return std::make_unique<GeneratorS3>(cfg.inputPos);
+ return std::make_unique<GeneratorS3>(param);
case 4:
- return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B);
+ return std::make_unique<GeneratorS4>(param, dpinfo.ks, B);
case 5:
- return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B);
+ return std::make_unique<GeneratorS5>(param, dpinfo.ks, B);
default:
WARNING("[Generator][DP] Unsupported dot product test series for generator.");
return nullptr;