diff options
Diffstat (limited to 'reference_model/src/generate')
4 files changed, 131 insertions, 9 deletions
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc index fe829e3..046007e 100644 --- a/reference_model/src/generate/generate_dot_product.cc +++ b/reference_model/src/generate/generate_dot_product.cc @@ -736,6 +736,92 @@ bool generateTransposeConv2D(const TosaReference::GenerateConfig& cfg, return false; } } +//---------------------------------------------------------------------------// +// FFT2D // +//---------------------------------------------------------------------------// + +template <typename DataType> +bool generateFFT2DReal(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + DataType* data, + size_t size) +{ + const int64_t T = TosaReference::numElementsFromShape(cfg.shape); + const uint32_t H = cfg.shape[1]; + const uint32_t W = cfg.shape[2]; + + for (int64_t t = 0; t < T; ++t) + { + uint32_t x = t % W; + uint32_t y = (t / W) % H; + uint32_t k = y * W + x; + + data[t] = static_cast<DataType>(generator(k)); + } + return true; +} + +template <typename DataType> +bool generateFFT2DImag(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + DataType* data, + size_t size) +{ + const int64_t T = TosaReference::numElementsFromShape(cfg.shape); + const uint32_t H = cfg.shape[1]; + const uint32_t W = cfg.shape[2]; + + // The index expression of ((1*N+n)*H+y)*W+x in the spec equates to + // using the values after those used for the Real tensor, but we need + // to iterate through all those values to get to the Imaginary data + for (int64_t n = 0; n < 2; ++n) + { + for (int64_t t = 0; t < T; ++t) + { + uint32_t x = t % W; + uint32_t y = (t / W) % H; + uint32_t k = y * W + x; + + data[t] = static_cast<DataType>(generator(k)); + } + } + return true; +} + +bool generateFFT2D(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + void* data, + size_t size) +{ + if (cfg.shape.size() != 3) + { + WARNING("[Generator][DP][FFT2D] Tensor shape expected 3 dimensions."); + return false; + } + + switch (cfg.dataType) + { + case DType::DType_FP32: { + float* outData = reinterpret_cast<float*>(data); + switch (cfg.inputPos) + { + case 0: + return generateFFT2DReal(cfg, generator, outData, size); + case 1: + return generateFFT2DImag(cfg, generator, outData, size); + default: + WARNING("[Generator][DP][FFT2D] Invalid input tensor slot position to operator."); + return false; + } + break; + } + default: + WARNING("[Generator][DP][FFT2D] Only supports FP32."); + return false; + } + + return true; +} } // namespace namespace TosaReference @@ -772,6 +858,8 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size) return generateDepthwiseConv2D(cfg, *generator, data, size); case tosa::Op_TRANSPOSE_CONV2D: return generateTransposeConv2D(cfg, *generator, data, size); + case tosa::Op_FFT2D: + return generateFFT2D(cfg, *generator, data, size); default: WARNING("[Generator][DP] Unsupported operator."); return false; diff --git a/reference_model/src/generate/generate_dot_product.h b/reference_model/src/generate/generate_dot_product.h index cd9d4ba..bf1b1ff 100644 --- a/reference_model/src/generate/generate_dot_product.h +++ b/reference_model/src/generate/generate_dot_product.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ class IDotProductGenerator public: virtual float operator()(uint32_t k) = 0; virtual ~IDotProductGenerator() = default; + virtual uint32_t nextIndex() = 0; }; /// \brief Dot-product stage generator selector diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc index 9ce32ff..b78be71 100644 --- a/reference_model/src/generate/generate_dot_product_states.cc +++ b/reference_model/src/generate/generate_dot_product_states.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -60,7 +60,7 @@ public: return pseudo; } - uint32_t index() + uint32_t nextIndex() { return _index; } @@ -101,6 +101,11 @@ public: else return 0.f; } + uint32_t nextIndex() + { + ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0") + return _set_data0.nextIndex(); + } private: uint32_t _p; @@ -129,6 +134,10 @@ public: else return (_B * _B / (_KS + 1)) * v; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -158,6 +167,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -186,6 +199,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -229,6 +246,11 @@ public: else return 0.f; } + uint32_t nextIndex() + { + ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4") + return _set_data0.nextIndex(); + } private: uint32_t _p; @@ -258,6 +280,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -307,21 +333,27 @@ std::unique_ptr<IDotProductGenerator> pickDotProductGenerator(const GenerateConf float B = getBoundParameter(cfg.dataType, dpinfo.accType); if (B > 0.f) { + auto param = cfg.inputPos; + if (cfg.opType == Op_FFT2D) + { + // We only use param of zero for FFT2D tensors + param = 0; + } // Create the generator switch (dpinfo.s) { case 0: - return std::make_unique<GeneratorS0>(cfg.inputPos); + return std::make_unique<GeneratorS0>(param); case 1: - return std::make_unique<GeneratorS1>(cfg.inputPos, dpinfo.ks, B); + return std::make_unique<GeneratorS1>(param, dpinfo.ks, B); case 2: - return std::make_unique<GeneratorS2>(cfg.inputPos, dpinfo.ks); + return std::make_unique<GeneratorS2>(param, dpinfo.ks); case 3: - return std::make_unique<GeneratorS3>(cfg.inputPos); + return std::make_unique<GeneratorS3>(param); case 4: - return std::make_unique<GeneratorS4>(cfg.inputPos, dpinfo.ks, B); + return std::make_unique<GeneratorS4>(param, dpinfo.ks, B); case 5: - return std::make_unique<GeneratorS5>(cfg.inputPos, dpinfo.ks, B); + return std::make_unique<GeneratorS5>(param, dpinfo.ks, B); default: WARNING("[Generator][DP] Unsupported dot product test series for generator."); return nullptr; diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index a8b472a..2e40b04 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -54,6 +54,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op, { Op::Op_ERF, "ERF" }, { Op::Op_EXP, "EXP" }, { Op::Op_FLOOR, "FLOOR" }, + { Op::Op_FFT2D, "FFT2D" }, { Op::Op_FULLY_CONNECTED, "FULLY_CONNECTED" }, { Op::Op_GATHER, "GATHER" }, { Op::Op_GREATER, "GREATER" }, |