diff options
Diffstat (limited to 'src/cpu/kernels/CpuWeightsReshapeKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuWeightsReshapeKernel.cpp | 86 |
1 files changed, 45 insertions, 41 deletions
diff --git a/src/cpu/kernels/CpuWeightsReshapeKernel.cpp b/src/cpu/kernels/CpuWeightsReshapeKernel.cpp index 2ccc977995..297ba63826 100644 --- a/src/cpu/kernels/CpuWeightsReshapeKernel.cpp +++ b/src/cpu/kernels/CpuWeightsReshapeKernel.cpp @@ -25,6 +25,7 @@ #include "arm_compute/core/Helpers.h" #include "arm_compute/core/Validate.h" + #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" @@ -38,7 +39,7 @@ namespace { TensorShape get_output_shape(const ITensorInfo *src, bool has_bias) { - TensorShape output_shape{ src->tensor_shape() }; + TensorShape output_shape{src->tensor_shape()}; output_shape.collapse(3); const size_t tmp_dim = output_shape[0]; @@ -54,20 +55,22 @@ Status validate_arguments(const ITensorInfo *src, const ITensorInfo *biases, con //Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src) is not needed here as this kernel doesn't use CPU FP16 instructions. ARM_COMPUTE_RETURN_ERROR_ON(src->data_type() == DataType::UNKNOWN); - if(biases != nullptr) + if (biases != nullptr) { ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(src->data_type())); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, biases); ARM_COMPUTE_RETURN_ERROR_ON((src->num_dimensions() == 4) && (biases->num_dimensions() != 1)); ARM_COMPUTE_RETURN_ERROR_ON((src->num_dimensions() == 5) && (biases->num_dimensions() != 2)); ARM_COMPUTE_RETURN_ERROR_ON((src->num_dimensions() == 4) && (biases->dimension(0) != src->tensor_shape()[3])); - ARM_COMPUTE_RETURN_ERROR_ON((src->num_dimensions() == 5) && (biases->dimension(0) != src->tensor_shape()[3] || biases->dimension(1) != src->tensor_shape()[4])); + ARM_COMPUTE_RETURN_ERROR_ON((src->num_dimensions() == 5) && (biases->dimension(0) != src->tensor_shape()[3] || + biases->dimension(1) != src->tensor_shape()[4])); } // Checks performed when output is configured - if(dst->total_size() != 0) + if (dst->total_size() != 0) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), get_output_shape(src, biases != nullptr)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), + get_output_shape(src, biases != nullptr)); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst); } @@ -84,9 +87,7 @@ void CpuWeightsReshapeKernel::configure(const ITensorInfo *src, const ITensorInf auto_init_if_empty(*dst, src->clone()->set_tensor_shape(get_output_shape(src, (biases != nullptr)))); // Perform validation step - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, - biases, - dst)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, biases, dst)); // Configure kernel Window window = calculate_max_window(*src, Steps()); @@ -122,44 +123,47 @@ void CpuWeightsReshapeKernel::run_op(ITensorPack &tensors, const Window &window, // Create iterators Iterator in(src, window); - execute_window_loop(window, [&](const Coordinates & id) - { - // Get column index - const int kernel_idx = id[3]; - const int kernel_idz = id[4]; - - // Setup pointers - const uint8_t *tmp_input_ptr = in.ptr(); - uint8_t *tmp_output_ptr = dst->ptr_to_element(Coordinates(kernel_idx, 0, kernel_idz)); - const uint8_t *curr_input_row_ptr = tmp_input_ptr; - const uint8_t *curr_input_depth_ptr = tmp_input_ptr; - - // Linearize volume - for(unsigned int d = 0; d < kernel_depth; ++d) + execute_window_loop( + window, + [&](const Coordinates &id) { - for(unsigned int j = 0; j < kernel_size_y; ++j) + // Get column index + const int kernel_idx = id[3]; + const int kernel_idz = id[4]; + + // Setup pointers + const uint8_t *tmp_input_ptr = in.ptr(); + uint8_t *tmp_output_ptr = dst->ptr_to_element(Coordinates(kernel_idx, 0, kernel_idz)); + const uint8_t *curr_input_row_ptr = tmp_input_ptr; + const uint8_t *curr_input_depth_ptr = tmp_input_ptr; + + // Linearize volume + for (unsigned int d = 0; d < kernel_depth; ++d) { - for(unsigned int i = 0; i < kernel_size_x; ++i) + for (unsigned int j = 0; j < kernel_size_y; ++j) { - std::memcpy(tmp_output_ptr, tmp_input_ptr, src->info()->element_size()); - tmp_input_ptr += input_stride_x; - tmp_output_ptr += output_stride_y; + for (unsigned int i = 0; i < kernel_size_x; ++i) + { + std::memcpy(tmp_output_ptr, tmp_input_ptr, src->info()->element_size()); + tmp_input_ptr += input_stride_x; + tmp_output_ptr += output_stride_y; + } + curr_input_row_ptr += input_stride_y; + tmp_input_ptr = curr_input_row_ptr; } - curr_input_row_ptr += input_stride_y; - tmp_input_ptr = curr_input_row_ptr; + curr_input_depth_ptr += input_stride_z; + curr_input_row_ptr = curr_input_depth_ptr; + tmp_input_ptr = curr_input_depth_ptr; } - curr_input_depth_ptr += input_stride_z; - curr_input_row_ptr = curr_input_depth_ptr; - tmp_input_ptr = curr_input_depth_ptr; - } - // Add bias - if(biases != nullptr) - { - std::memcpy(tmp_output_ptr, biases->ptr_to_element(Coordinates(kernel_idx, kernel_idz)), src->info()->element_size()); - } - }, - in); + // Add bias + if (biases != nullptr) + { + std::memcpy(tmp_output_ptr, biases->ptr_to_element(Coordinates(kernel_idx, kernel_idz)), + src->info()->element_size()); + } + }, + in); } const char *CpuWeightsReshapeKernel::name() const { @@ -167,4 +171,4 @@ const char *CpuWeightsReshapeKernel::name() const } } // namespace kernels } // namespace cpu -} // namespace arm_compute
\ No newline at end of file +} // namespace arm_compute |