diff options
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 35 |
1 files changed, 29 insertions, 6 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index b9e2fbe..8d8dac7 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2023, ARM Limited. +// Copyright (c) 2020-2024, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -1684,7 +1684,8 @@ int OpFFT2d<Dtype>::eval() in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width, out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width); - OutEigenType sum_real, sum_imag, a, sign_val = 1.0; + OutEigenType sum_real, sum_imag, sign_val = 1.0; + OutEigenType a, a_cos, a_sin, v_ir; if (attribute->inverse()) { @@ -1715,11 +1716,33 @@ int OpFFT2d<Dtype>::eval() { OutEigenType val_real = in_real_val(n, iy, ix); OutEigenType val_imag = in_imag_val(n, iy, ix); - // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType + // Perform the periodic calculation in integer maths to keep + // the accuracy of the co-efficients similar for FP32 normal + // and FP64 precise mode + int32_t ay = (static_cast<int64_t>(iy) * static_cast<int64_t>(oy)) % in_real_height; + int32_t ax = (static_cast<int64_t>(ix) * static_cast<int64_t>(ox)) % in_real_width; + + // Use explicit cast to ensure intermediate calculations are completed using OutEigenType a = sign_val * 2 * M_PI * - ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width); - sum_real += val_real * cos(a) + val_imag * sin(a); - sum_imag += -val_real * sin(a) + val_imag * cos(a); + ((OutEigenType)ay / in_real_height + (OutEigenType)ax / in_real_width); + // Calculate weight values + a_cos = cos(a); + a_sin = sin(a); + if (g_func_config.abs_mode) + { + // Bounded op - Use abs weight values + a_cos = std::abs(a_cos); + a_sin = std::abs(a_sin); + // Bounded op - Use abs real value for imaginary calc + v_ir = val_real; + } + else + { + // Normal op - Use negative real value for imaginary calc + v_ir = -val_real; + } + sum_real += val_real * a_cos + val_imag * a_sin; + sum_imag += v_ir * a_sin + val_imag * a_cos; } } this->out_real->getTensor()(n, oy, ox) = sum_real; |