aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEWeightsReshapeKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEWeightsReshapeKernel.cpp128
1 files changed, 49 insertions, 79 deletions
diff --git a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
index 4a0cf27592..624833adfb 100644
--- a/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
+++ b/src/core/NEON/kernels/NEWeightsReshapeKernel.cpp
@@ -34,59 +34,6 @@ using namespace arm_compute;
namespace
{
-template <typename T>
-void weights_reshape(const ITensor *input, const ITensor *bias, ITensor *output, const Window &window)
-{
- const unsigned int kernel_size_x = input->info()->dimension(0);
- const unsigned int kernel_size_y = input->info()->dimension(1);
- const unsigned int kernel_depth = input->info()->dimension(2);
- const unsigned int input_stride_x = input->info()->strides_in_bytes().x();
- const unsigned int input_stride_y = input->info()->strides_in_bytes().y();
- const unsigned int input_stride_z = input->info()->strides_in_bytes().z();
- const unsigned int output_stride_y = output->info()->strides_in_bytes().y();
-
- // Create iterators
- Iterator in(input, window);
- execute_window_loop(window, [&](const Coordinates & id)
- {
- // Get column index
- const int kernel_idx = id[3];
- const int kernel_idz = id[4];
-
- // Setup pointers
- const uint8_t *tmp_input_ptr = in.ptr();
- uint8_t *tmp_output_ptr = output->ptr_to_element(Coordinates(kernel_idx, 0, kernel_idz));
- const uint8_t *curr_input_row_ptr = tmp_input_ptr;
- const uint8_t *curr_input_depth_ptr = tmp_input_ptr;
-
- // Linearize volume
- for(unsigned int d = 0; d < kernel_depth; ++d)
- {
- for(unsigned int j = 0; j < kernel_size_y; ++j)
- {
- for(unsigned int i = 0; i < kernel_size_x; ++i)
- {
- *(reinterpret_cast<T *>(tmp_output_ptr)) = *(reinterpret_cast<const T *>(tmp_input_ptr));
- tmp_input_ptr += input_stride_x;
- tmp_output_ptr += output_stride_y;
- }
- curr_input_row_ptr += input_stride_y;
- tmp_input_ptr = curr_input_row_ptr;
- }
- curr_input_depth_ptr += input_stride_z;
- curr_input_row_ptr = curr_input_depth_ptr;
- tmp_input_ptr = curr_input_depth_ptr;
- }
-
- // Add bias
- if(bias != nullptr)
- {
- *(reinterpret_cast<T *>(tmp_output_ptr)) = *(reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(kernel_idx, kernel_idz))));
- }
- },
- in);
-}
-
TensorShape get_output_shape(const ITensorInfo *input, bool has_bias)
{
TensorShape output_shape{ input->tensor_shape() };
@@ -141,7 +88,7 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
} // namespace
NEWeightsReshapeKernel::NEWeightsReshapeKernel()
- : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr)
+ : _input(nullptr), _bias(nullptr), _output(nullptr)
{
}
@@ -161,30 +108,6 @@ void NEWeightsReshapeKernel::configure(const ITensor *input, const ITensor *bias
_bias = bias;
_output = output;
- switch(_input->info()->element_size())
- {
- case 4:
- {
- _func = &weights_reshape<uint32_t>;
- break;
- }
- case 2:
- {
- _func = &weights_reshape<uint16_t>;
- break;
- }
- case 1:
- {
- _func = &weights_reshape<uint8_t>;
- break;
- }
- default:
- {
- ARM_COMPUTE_ERROR_ON("Element size not supported");
- break;
- }
- }
-
// Configure kernel
auto win_config = validate_and_configure_window(input->info(), output->info());
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
@@ -205,5 +128,52 @@ void NEWeightsReshapeKernel::run(const Window &window, const ThreadInfo &info)
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- (*_func)(_input, _bias, _output, window);
+ const unsigned int kernel_size_x = _input->info()->dimension(0);
+ const unsigned int kernel_size_y = _input->info()->dimension(1);
+ const unsigned int kernel_depth = _input->info()->dimension(2);
+ const unsigned int input_stride_x = _input->info()->strides_in_bytes().x();
+ const unsigned int input_stride_y = _input->info()->strides_in_bytes().y();
+ const unsigned int input_stride_z = _input->info()->strides_in_bytes().z();
+ const unsigned int output_stride_y = _output->info()->strides_in_bytes().y();
+
+ // Create iterators
+ Iterator in(_input, window);
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ // Get column index
+ const int kernel_idx = id[3];
+ const int kernel_idz = id[4];
+
+ // Setup pointers
+ const uint8_t *tmp_input_ptr = in.ptr();
+ uint8_t *tmp_output_ptr = _output->ptr_to_element(Coordinates(kernel_idx, 0, kernel_idz));
+ const uint8_t *curr_input_row_ptr = tmp_input_ptr;
+ const uint8_t *curr_input_depth_ptr = tmp_input_ptr;
+
+ // Linearize volume
+ for(unsigned int d = 0; d < kernel_depth; ++d)
+ {
+ for(unsigned int j = 0; j < kernel_size_y; ++j)
+ {
+ for(unsigned int i = 0; i < kernel_size_x; ++i)
+ {
+ std::memcpy(tmp_output_ptr, tmp_input_ptr, _input->info()->element_size());
+ tmp_input_ptr += input_stride_x;
+ tmp_output_ptr += output_stride_y;
+ }
+ curr_input_row_ptr += input_stride_y;
+ tmp_input_ptr = curr_input_row_ptr;
+ }
+ curr_input_depth_ptr += input_stride_z;
+ curr_input_row_ptr = curr_input_depth_ptr;
+ tmp_input_ptr = curr_input_depth_ptr;
+ }
+
+ // Add bias
+ if(_bias != nullptr)
+ {
+ std::memcpy(tmp_output_ptr, _bias->ptr_to_element(Coordinates(kernel_idx, kernel_idz)), _input->info()->element_size());
+ }
+ },
+ in);
}