aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/DFT.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/DFT.cpp')
-rw-r--r--tests/validation/reference/DFT.cpp34
1 files changed, 24 insertions, 10 deletions
diff --git a/tests/validation/reference/DFT.cpp b/tests/validation/reference/DFT.cpp
index 7221312641..2b03c270ac 100644
--- a/tests/validation/reference/DFT.cpp
+++ b/tests/validation/reference/DFT.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2020 ARM Limited.
+ * Copyright (c) 2019-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "Permute.h"
#include "Reverse.h"
#include "SliceOperations.h"
+#include "support/ToolchainSupport.h"
#include <cmath>
@@ -267,8 +268,9 @@ SimpleTensor<T> complex_mul_and_reduce(const SimpleTensor<T> &input, const Simpl
output_shape.set(2, Co);
SimpleTensor<T> dst(output_shape, input.data_type(), input.num_channels());
- // MemSet dst memory to zero
- std::memset(dst.data(), 0, dst.size());
+ // dst memory to zero
+ const auto total_element_count = dst.num_channels() * dst.num_elements();
+ std::fill_n(dst.data(), total_element_count, 0);
for(uint32_t b = 0; b < N; ++b)
{
@@ -318,7 +320,7 @@ SimpleTensor<T> ridft_1d(const SimpleTensor<T> &src, bool is_odd)
{
auto dst = rdft_1d_core(src, FFTDirection::Inverse, is_odd);
- const T scaling_factor = dst.shape()[0];
+ const T scaling_factor = T(dst.shape()[0]);
scale(dst, scaling_factor);
return dst;
@@ -330,7 +332,7 @@ SimpleTensor<T> dft_1d(const SimpleTensor<T> &src, FFTDirection direction)
auto dst = dft_1d_core(src, direction);
if(direction == FFTDirection::Inverse)
{
- const T scaling_factor = dst.shape()[0];
+ const T scaling_factor = T(dst.shape()[0]);
scale(dst, scaling_factor);
}
return dst;
@@ -359,7 +361,7 @@ SimpleTensor<T> ridft_2d(const SimpleTensor<T> &src, bool is_odd)
auto transposed_2 = permute(first_pass, PermutationVector(1U, 0U));
auto dst = rdft_1d_core(transposed_2, direction, is_odd);
- const T scaling_factor = dst.shape()[0] * dst.shape()[1];
+ const T scaling_factor = T(dst.shape()[0] * dst.shape()[1]);
scale(dst, scaling_factor);
return dst;
}
@@ -383,7 +385,7 @@ SimpleTensor<T> dft_2d(const SimpleTensor<T> &src, FFTDirection direction)
auto transposed_2 = permute(first_pass, PermutationVector(1U, 0U));
auto dst = dft_1d_core(transposed_2, direction);
- const T scaling_factor = dst.shape()[0] * dst.shape()[1];
+ const T scaling_factor = T(dst.shape()[0] * dst.shape()[1]);
scale(dst, scaling_factor);
return dst;
@@ -398,10 +400,10 @@ SimpleTensor<T> conv2d_dft(const SimpleTensor<T> &src, const SimpleTensor<T> &w,
auto padded_src = pad_layer(src, padding_in);
// Flip weights
- std::vector<uint32_t> axis_v = { 0, 1 };
- SimpleTensor<uint32_t> axis{ TensorShape(2U), DataType::U32 };
+ std::vector<uint32_t> axis_v = { 0, 1 };
+ SimpleTensor<int32_t> axis{ TensorShape(2U), DataType::S32 };
std::copy(axis_v.begin(), axis_v.begin() + axis.shape().x(), axis.data());
- auto flipped_w = reverse(w, axis);
+ auto flipped_w = reverse(w, axis, /* use_inverted_axis */ false);
// Pad weights to have the same size as input
const PaddingList paddings_w = { { 0, src.shape()[0] - 1 }, { 0, src.shape()[1] - 1 } };
@@ -425,6 +427,7 @@ SimpleTensor<T> conv2d_dft(const SimpleTensor<T> &src, const SimpleTensor<T> &w,
return slice(conv_res, Coordinates(start_left, start_top), Coordinates(end_right, end_botton));
}
+// FP32
template SimpleTensor<float> rdft_1d(const SimpleTensor<float> &src);
template SimpleTensor<float> ridft_1d(const SimpleTensor<float> &src, bool is_odd);
template SimpleTensor<float> dft_1d(const SimpleTensor<float> &src, FFTDirection direction);
@@ -434,6 +437,17 @@ template SimpleTensor<float> ridft_2d(const SimpleTensor<float> &src, bool is_od
template SimpleTensor<float> dft_2d(const SimpleTensor<float> &src, FFTDirection direction);
template SimpleTensor<float> conv2d_dft(const SimpleTensor<float> &src, const SimpleTensor<float> &w, const PadStrideInfo &conv_info);
+
+// FP16
+template SimpleTensor<half> rdft_1d(const SimpleTensor<half> &src);
+template SimpleTensor<half> ridft_1d(const SimpleTensor<half> &src, bool is_odd);
+template SimpleTensor<half> dft_1d(const SimpleTensor<half> &src, FFTDirection direction);
+
+template SimpleTensor<half> rdft_2d(const SimpleTensor<half> &src);
+template SimpleTensor<half> ridft_2d(const SimpleTensor<half> &src, bool is_odd);
+template SimpleTensor<half> dft_2d(const SimpleTensor<half> &src, FFTDirection direction);
+
+template SimpleTensor<half> conv2d_dft(const SimpleTensor<half> &src, const SimpleTensor<half> &w, const PadStrideInfo &conv_info);
} // namespace reference
} // namespace validation
} // namespace test