diff options
author | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-01-30 16:10:50 +0000 |
---|---|---|
committer | Jeremy Johnson <jeremy.johnson@arm.com> | 2024-02-08 11:12:55 +0000 |
commit | 6f57e6e665094959aed40c0e388ac81fbd118720 (patch) | |
tree | 82fdfa4b40baf370aa346e3d19fa3f1760294ee9 /reference_model/src/generate/generate_dot_product.cc | |
parent | 47ab1762d1c15a7b4c0c068d7294111c5c5f92a2 (diff) | |
download | reference_model-6f57e6e665094959aed40c0e388ac81fbd118720.tar.gz |
Main Compliance: RFFT2D support
Correct ref model to produce imaginery values of zero as specification
indicates at certain output positions.
Fix up precise and abs modes for RFFT2D in ref model to produce correct
results and bounds using abs weights.
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I33767e4219a260278f7933f28b1799223a95a3cc
Diffstat (limited to 'reference_model/src/generate/generate_dot_product.cc')
-rw-r--r-- | reference_model/src/generate/generate_dot_product.cc | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/reference_model/src/generate/generate_dot_product.cc b/reference_model/src/generate/generate_dot_product.cc index 7337969..117d49d 100644 --- a/reference_model/src/generate/generate_dot_product.cc +++ b/reference_model/src/generate/generate_dot_product.cc @@ -963,6 +963,63 @@ bool generateFFT2D(const TosaReference::GenerateConfig& cfg, return true; } +//---------------------------------------------------------------------------// +// RFFT2D // +//---------------------------------------------------------------------------// + +template <typename DataType> +bool generateRFFT2DReal(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + DataType* data, + size_t size) +{ + const int64_t T = TosaReference::numElementsFromShape(cfg.shape); + const uint32_t H = cfg.shape[1]; + const uint32_t W = cfg.shape[2]; + + for (int64_t t = 0; t < T; ++t) + { + uint32_t x = t % W; + uint32_t y = (t / W) % H; + uint32_t k = y * W + x; + + data[t] = static_cast<DataType>(generator(k)); + } + return true; +} + +bool generateRFFT2D(const TosaReference::GenerateConfig& cfg, + TosaReference::IDotProductGenerator& generator, + void* data, + size_t size) +{ + if (cfg.shape.size() != 3) + { + WARNING("[Generator][DP][RFFT2D] Tensor shape expected 3 dimensions."); + return false; + } + + switch (cfg.dataType) + { + case DType::DType_FP32: { + float* outData = reinterpret_cast<float*>(data); + switch (cfg.inputPos) + { + case 0: + return generateRFFT2DReal(cfg, generator, outData, size); + default: + WARNING("[Generator][DP][RFFT2D] Invalid input tensor slot position to operator."); + return false; + } + break; + } + default: + WARNING("[Generator][DP][RFFT2D] Only supports FP32."); + return false; + } + + return true; +} } // namespace namespace TosaReference @@ -1003,6 +1060,8 @@ bool generateDotProduct(const GenerateConfig& cfg, void* data, size_t size) return generateConv3D(cfg, *generator, data, size); case tosa::Op_FFT2D: return generateFFT2D(cfg, *generator, data, size); + case tosa::Op_RFFT2D: + return generateRFFT2D(cfg, *generator, data, size); default: WARNING("[Generator][DP] Unsupported operator."); return false; |