diff options
Diffstat (limited to 'src/core/NEON/kernels/NEStackLayerKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEStackLayerKernel.cpp | 33 |
1 files changed, 4 insertions, 29 deletions
diff --git a/src/core/NEON/kernels/NEStackLayerKernel.cpp b/src/core/NEON/kernels/NEStackLayerKernel.cpp index 0c33f36983..3447d59bcc 100644 --- a/src/core/NEON/kernels/NEStackLayerKernel.cpp +++ b/src/core/NEON/kernels/NEStackLayerKernel.cpp @@ -87,7 +87,7 @@ inline Coordinates shift_from_axis_and_replace_coordinate(const Coordinates &id, } // namespace NEStackLayerKernel::NEStackLayerKernel() - : _input(nullptr), _output(nullptr), _axis(), _idx_input(), _func(nullptr) + : _input(nullptr), _output(nullptr), _axis(), _idx_input() { } @@ -101,22 +101,6 @@ void NEStackLayerKernel::configure(const ITensor *input, unsigned int axis, unsi _axis = axis; _idx_input = idx_input; - switch(input->info()->element_size()) - { - case 1: - _func = &NEStackLayerKernel::run_stack<uint8_t>; - break; - case 2: - _func = &NEStackLayerKernel::run_stack<uint16_t>; - break; - case 4: - _func = &NEStackLayerKernel::run_stack<uint32_t>; - break; - default: - ARM_COMPUTE_ERROR("Element size not supported"); - break; - } - // Configure kernel window auto win_config = validate_and_configure_window(input->info(), axis, num_tensors, output->info()); @@ -137,15 +121,6 @@ void NEStackLayerKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - if(_func != nullptr) - { - (this->*_func)(window); - } -} - -template <typename T> -void NEStackLayerKernel::run_stack(const Window &window) -{ Window window_out; window_out.use_tensor_dimensions(_output->info()->tensor_shape()); @@ -160,9 +135,9 @@ void NEStackLayerKernel::run_stack(const Window &window) execute_window_loop(window, [&](const Coordinates & id) { - Coordinates id_out = shift_from_axis_and_replace_coordinate(id, _axis, _idx_input); - const int idx = id_out[0] * stride_x + id_out[1] * stride_y + id_out[2] * stride_z + id_out[3] * stride_w + id_out[4] * stride_k; - *(reinterpret_cast<T *>(output.ptr() + idx)) = *(reinterpret_cast<const T *>(input.ptr())); + Coordinates id_out = shift_from_axis_and_replace_coordinate(id, _axis, _idx_input); + const int idx = id_out[0] * stride_x + id_out[1] * stride_y + id_out[2] * stride_z + id_out[3] * stride_w + id_out[4] * stride_k; + std::memcpy(output.ptr() + idx, input.ptr(), _input->info()->element_size()); }, input); } |