diff options
Diffstat (limited to 'reference_model/src/generate/generate_dot_product_states.cc')
-rw-r--r-- | reference_model/src/generate/generate_dot_product_states.cc | 48 |
1 files changed, 40 insertions, 8 deletions
diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc index 9ce32ff..b78be71 100644 --- a/reference_model/src/generate/generate_dot_product_states.cc +++ b/reference_model/src/generate/generate_dot_product_states.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -60,7 +60,7 @@ public: return pseudo; } - uint32_t index() + uint32_t nextIndex() { return _index; } @@ -101,6 +101,11 @@ public: else return 0.f; } + uint32_t nextIndex() + { + ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0") + return _set_data0.nextIndex(); + } private: uint32_t _p; @@ -129,6 +134,10 @@ public: else return (_B * _B / (_KS + 1)) * v; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -158,6 +167,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -186,6 +199,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -229,6 +246,11 @@ public: else return 0.f; } + uint32_t nextIndex() + { + ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4") + return _set_data0.nextIndex(); + } private: uint32_t _p; @@ -258,6 +280,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -307,21 +333,27 @@ std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConf float B = getBoundParameter(cfg.dataType, dpinfo.accType); if (B > 0.f) { + auto param = cfg.inputPos; + if (cfg.opType == Op_FFT2D) + { + // We only use param of zero for FFT2D tensors + param = 0; + } // Create the generator switch (dpinfo.s) { case 0: - return std::make_unique<GeneratorS0>(cfg.inputPos); + return std::make_unique<GeneratorS0>(param); case 1: - return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B); + return std::make_unique<GeneratorS1>(param, dpinfo.ks, B); case 2: - return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks); + return std::make_unique<GeneratorS2>(param, dpinfo.ks); case 3: - return std::make_unique<GeneratorS3>(cfg.inputPos); + return std::make_unique<GeneratorS3>(param); case 4: - return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B); + return std::make_unique<GeneratorS4>(param, dpinfo.ks, B); case 5: - return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B); + return std::make_unique<GeneratorS5>(param, dpinfo.ks, B); default: WARNING("[Generator][DP] Unsupported dot product test series for generator."); return nullptr; |