about | summary | refs | log | tree | commit | diff
path: root/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp | 21
1 files changed, 15 insertions, 6 deletions
diff --git a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
index 845fcef4f3..b2ffb01e99 100644
--- a/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
+++ b/src/core/NEON/kernels/NEFFTDigitReverseKernel.cpp
@@ -37,7 +37,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(idx, 1, DataType::U32);
- ARM_COMPUTE_RETURN_ERROR_ON(axis != 0);
+ ARM_COMPUTE_RETURN_ERROR_ON(axis > 1);
// Checks performed when output is configured
if((output != nullptr) && (output->total_size() != 0))
@@ -96,15 +96,24 @@ void NEFFTDigitReverseKernel::run(const Window &window, const ThreadInfo &info)
Iterator out(_output, window);
const size_t element_size = _input->info()->element_size();
+ // Pointers to the buffers
+ const size_t offset = _input->info()->offset_first_element_in_bytes();
+ auto *idx_ptr = reinterpret_cast<unsigned int *>(_idx->buffer());
+ uint8_t *input_ptr = offset + _input->buffer();
+
+ // Strides
+ const size_t stride_x = _input->info()->strides_in_bytes()[0];
+ const size_t stride_y = _input->info()->strides_in_bytes()[1];
+ const size_t stride_z = _input->info()->strides_in_bytes()[2];
+ const size_t stride_w = _input->info()->strides_in_bytes()[3];
+
execute_window_loop(window, [&](const Coordinates & id)
{
- unsigned int in_index_1d = *reinterpret_cast<unsigned int *>(_idx->ptr_to_element(Coordinates(id.x())));
-
- auto reverse_id = id;
+ unsigned int in_index_1d = idx_ptr[id[_axis]];
+ auto reverse_id = id;
reverse_id.set(_axis, in_index_1d);
- memcpy(out.ptr(), _input->ptr_to_element(reverse_id), 2 * element_size);
-
+ memcpy(out.ptr(), input_ptr + reverse_id.x() * stride_x + reverse_id.y() * stride_y + reverse_id.z() * stride_z + reverse_id[3] * stride_w, element_size);
},
out);