From c8330811352f753e36f2ee7be4c7d0e6002f21e7 Mon Sep 17 00:00:00 2001 From: Jeremy Johnson Date: Thu, 18 Jan 2024 16:57:28 +0000 Subject: Main Compliance: FFT2D support Improve access to DOT_PRODUCT generator index and location for debugging. Enable multiple result files for compliance and improve output. Fix up precise and abs modes for FFT2D in ref model to produce correct results and bounds using abs weights. Signed-off-by: Jeremy Johnson Change-Id: Ide0c9f9f80397e5f1e07ca30a1036d6014b5784d --- .../src/generate/generate_dot_product.cc | 88 ++++++++++++++++++++++ .../src/generate/generate_dot_product.h | 3 +- .../src/generate/generate_dot_product_states.cc | 48 ++++++++++-- reference_model/src/generate/generate_utils.cc | 1 + 4 files changed, 131 insertions(+), 9 deletions(-) (limited to 'reference_model/src/generate') diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc index fe829e3..046007e 100644 --- a/reference_model/src/generate/generate_dot_product.cc +++ b/reference_model/src/generate/generate_dot_product.cc @@ -736,6 +736,92 @@ bool generateTransposeConv2D(const TosaReference::GenerateConfig& cfg, return false; } } +//---------------------------------------------------------------------------// +// FFT2D // +//---------------------------------------------------------------------------// + +template +bool generateFFT2DReal(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + DataType* data, + size_t size) +{ + const int64_t T = TosaReference::numElementsFromShape(cfg.shape); + const uint32_t H = cfg.shape[1]; + const uint32_t W = cfg.shape[2]; + + for (int64_t t = 0; t < T; ++t) + { + uint32_t x = t % W; + uint32_t y = (t / W) % H; + uint32_t k = y * W + x; + + data[t] = static_cast(generator(k)); + } + return true; +} + +template +bool generateFFT2DImag(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + DataType* data, + size_t size) +{ + const int64_t T = TosaReference::numElementsFromShape(cfg.shape); + const uint32_t H = cfg.shape[1]; + const uint32_t W = cfg.shape[2]; + + // The index expression of ((1*N+n)*H+y)*W+x in the spec equates to + // using the values after those used for the Real tensor, but we need + // to iterate through all those values to get to the Imaginary data + for (int64_t n = 0; n < 2; ++n) + { + for (int64_t t = 0; t < T; ++t) + { + uint32_t x = t % W; + uint32_t y = (t / W) % H; + uint32_t k = y * W + x; + + data[t] = static_cast(generator(k)); + } + } + return true; +} + +bool generateFFT2D(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + void* data, + size_t size) +{ + if (cfg.shape.size() != 3) + { + WARNING("[Generator][DP][FFT2D] Tensor shape expected 3 dimensions."); + return false; + } + + switch (cfg.dataType) + { + case DType::DType_FP32: { + float* outData = reinterpret_cast(data); + switch (cfg.inputPos) + { + case 0: + return generateFFT2DReal(cfg, generator, outData, size); + case 1: + return generateFFT2DImag(cfg, generator, outData, size); + default: + WARNING("[Generator][DP][FFT2D] Invalid input tensor slot position to operator."); + return false; + } + break; + } + default: + WARNING("[Generator][DP][FFT2D] Only supports FP32."); + return false; + } + + return true; +} } // namespace namespace TosaReference @@ -772,6 +858,8 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size) return generateDepthwiseConv2D(cfg, *generator, data, size); case tosa::Op_TRANSPOSE_CONV2D: return generateTransposeConv2D(cfg, *generator, data, size); + case tosa::Op_FFT2D: + return generateFFT2D(cfg, *generator, data, size); default: WARNING("[Generator][DP] Unsupported operator."); return false; diff --git a/reference_model/src/generate/generate_dot_product.h b/reference_model/src/generate/generate_dot_product.h index cd9d4ba..bf1b1ff 100644 --- a/reference_model/src/generate/generate_dot_product.h +++ b/reference_model/src/generate/generate_dot_product.h @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -28,6 +28,7 @@ class IDotProductGenerator public: virtual float operator()(uint32_t k) = 0; virtual ~IDotProductGenerator() = default; + virtual uint32_t nextIndex() = 0; }; /// \brief Dot-product stage generator selector diff --git a/reference_model/src/generate/generate_dot_product_states.cc b/reference_model/src/generate/generate_dot_product_states.cc index 9ce32ff..b78be71 100644 --- a/reference_model/src/generate/generate_dot_product_states.cc +++ b/reference_model/src/generate/generate_dot_product_states.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2023, ARM Limited. +// Copyright (c) 2023-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -60,7 +60,7 @@ public: return pseudo; } - uint32_t index() + uint32_t nextIndex() { return _index; } @@ -101,6 +101,11 @@ public: else return 0.f; } + uint32_t nextIndex() + { + ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS0") + return _set_data0.nextIndex(); + } private: uint32_t _p; @@ -129,6 +134,10 @@ public: else return (_B * _B / (_KS + 1)) * v; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -158,6 +167,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -186,6 +199,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -229,6 +246,11 @@ public: else return 0.f; } + uint32_t nextIndex() + { + ASSERT_MSG(_set_data0.nextIndex() == _set_data1.nextIndex(), "Internal index inconsistency in GeneratorS4") + return _set_data0.nextIndex(); + } private: uint32_t _p; @@ -258,6 +280,10 @@ public: else return 0.f; } + uint32_t nextIndex() + { + return _set_data.nextIndex(); + } private: uint32_t _p; @@ -307,21 +333,27 @@ std::unique_ptr pickDotProductGenerator(const GenerateConf float B = getBoundParameter(cfg.dataType, dpinfo.accType); if (B > 0.f) { + auto param = cfg.inputPos; + if (cfg.opType == Op_FFT2D) + { + // We only use param of zero for FFT2D tensors + param = 0; + } // Create the generator switch (dpinfo.s) { case 0: - return std::make_unique(cfg.inputPos); + return std::make_unique(param); case 1: - return std::make_unique(cfg.inputPos, dpinfo.ks, B); + return std::make_unique(param, dpinfo.ks, B); case 2: - return std::make_unique(cfg.inputPos, dpinfo.ks); + return std::make_unique(param, dpinfo.ks); case 3: - return std::make_unique(cfg.inputPos); + return std::make_unique(param); case 4: - return std::make_unique(cfg.inputPos, dpinfo.ks, B); + return std::make_unique(param, dpinfo.ks, B); case 5: - return std::make_unique(cfg.inputPos, dpinfo.ks, B); + return std::make_unique(param, dpinfo.ks, B); default: WARNING("[Generator][DP] Unsupported dot product test series for generator."); return nullptr; diff --git a/reference_model/src/generate/generate_utils.cc b/reference_model/src/generate/generate_utils.cc index a8b472a..2e40b04 100644 --- a/reference_model/src/generate/generate_utils.cc +++ b/reference_model/src/generate/generate_utils.cc @@ -54,6 +54,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Op, { Op::Op_ERF, "ERF" }, { Op::Op_EXP, "EXP" }, { Op::Op_FLOOR, "FLOOR" }, + { Op::Op_FFT2D, "FFT2D" }, { Op::Op_FULLY_CONNECTED, "FULLY_CONNECTED" }, { Op::Op_GATHER, "GATHER" }, { Op::Op_GREATER, "GREATER" }, -- cgit v1.2.1