aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--reference_model/src/ops/tensor_ops.cc35
1 files changed, 29 insertions, 6 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b9e2fbe..8d8dac7 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2023, ARM Limited.
+// Copyright (c) 2020-2024, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -1684,7 +1684,8 @@ int OpFFT2d<Dtype>::eval()
in_real_batch, in_real_height, in_real_width, in_imag_batch, in_imag_height, in_imag_width,
out_real_batch, out_real_height, out_real_width, out_imag_batch, out_imag_height, out_imag_width);
- OutEigenType sum_real, sum_imag, a, sign_val = 1.0;
+ OutEigenType sum_real, sum_imag, sign_val = 1.0;
+ OutEigenType a, a_cos, a_sin, v_ir;
if (attribute->inverse())
{
@@ -1715,11 +1716,33 @@ int OpFFT2d<Dtype>::eval()
{
OutEigenType val_real = in_real_val(n, iy, ix);
OutEigenType val_imag = in_imag_val(n, iy, ix);
- // Use explicit cast to ensure intermmediate calculations are completed using OutEigenType
+ // Perform the periodic calculation in integer maths to keep
+ // the accuracy of the co-efficients similar for FP32 normal
+ // and FP64 precise mode
+ int32_t ay = (static_cast<int64_t>(iy) * static_cast<int64_t>(oy)) % in_real_height;
+ int32_t ax = (static_cast<int64_t>(ix) * static_cast<int64_t>(ox)) % in_real_width;
+
+ // Use explicit cast to ensure intermediate calculations are completed using OutEigenType
a = sign_val * 2 * M_PI *
- ((iy * (OutEigenType)oy) / in_real_height + (ix * (OutEigenType)ox) / in_real_width);
- sum_real += val_real * cos(a) + val_imag * sin(a);
- sum_imag += -val_real * sin(a) + val_imag * cos(a);
+ ((OutEigenType)ay / in_real_height + (OutEigenType)ax / in_real_width);
+ // Calculate weight values
+ a_cos = cos(a);
+ a_sin = sin(a);
+ if (g_func_config.abs_mode)
+ {
+ // Bounded op - Use abs weight values
+ a_cos = std::abs(a_cos);
+ a_sin = std::abs(a_sin);
+ // Bounded op - Use abs real value for imaginary calc
+ v_ir = val_real;
+ }
+ else
+ {
+ // Normal op - Use negative real value for imaginary calc
+ v_ir = -val_real;
+ }
+ sum_real += val_real * a_cos + val_imag * a_sin;
+ sum_imag += v_ir * a_sin + val_imag * a_cos;
}
}
this->out_real->getTensor()(n, oy, ox) = sum_real;