aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEReductionOperationKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEReductionOperationKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEReductionOperationKernel.cpp38
1 files changed, 4 insertions, 34 deletions
diff --git a/src/core/NEON/kernels/NEReductionOperationKernel.cpp b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
index ffa4fa3565..16cd6f77b4 100644
--- a/src/core/NEON/kernels/NEReductionOperationKernel.cpp
+++ b/src/core/NEON/kernels/NEReductionOperationKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -742,23 +742,8 @@ struct RedOpYZW
for(unsigned int dim = 0; dim < in_info.dimension(axis); ++dim)
{
- T *in_ptr;
- switch(axis)
- {
- case 1:
- in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, dim)));
- break;
- case 2:
- in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, dim)));
- break;
- case 3:
- in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, dim)));
- break;
- default:
- ARM_COMPUTE_ERROR("Not supported");
- }
+ const T *in_ptr = reinterpret_cast<T *>(input.ptr() + in_info.strides_in_bytes()[axis] * dim);
const auto vec_elements = wrapper::vloadq(in_ptr);
-
switch(op)
{
case ReductionOperation::SUM:
@@ -907,23 +892,8 @@ struct RedOpYZW_qasymm8
for(unsigned int index_dim = 0; index_dim < in_info.dimension(axis); ++index_dim)
{
- uint8_t *in_ptr;
- switch(axis)
- {
- case 1:
- in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, index_dim));
- break;
- case 2:
- in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, index_dim));
- break;
- case 3:
- in_ptr = input.ptr() + in_info.offset_element_in_bytes(Coordinates(0, 0, 0, index_dim));
- break;
- default:
- ARM_COMPUTE_ERROR("Not supported");
- }
- const auto vec_elements = wrapper::vloadq(in_ptr);
-
+ const uint8_t *in_ptr = input.ptr() + in_info.strides_in_bytes()[axis] * index_dim;
+ const auto vec_elements = wrapper::vloadq(in_ptr);
switch(op)
{
case ReductionOperation::SUM: