diff options
Diffstat (limited to 'src/core/NEON/kernels/NEReorderKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEReorderKernel.cpp | 277 |
1 files changed, 277 insertions, 0 deletions
diff --git a/src/core/NEON/kernels/NEReorderKernel.cpp b/src/core/NEON/kernels/NEReorderKernel.cpp new file mode 100644 index 0000000000..fe8882f59f --- /dev/null +++ b/src/core/NEON/kernels/NEReorderKernel.cpp @@ -0,0 +1,277 @@ +/* + * Copyright (c) 2023-2024 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#if defined(__aarch64__) + +#include "src/core/NEON/kernels/NEReorderKernel.h" + +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/runtime/Scheduler.h" + +#include "src/common/utils/Log.h" +#include "src/core/NEON/kernels/arm_gemm/transform.hpp" + +namespace arm_compute +{ + +void NEReorderKernel::run(const Window &window, const ThreadInfo &info) +{ + ARM_COMPUTE_UNUSED(info); + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); + switch (_input->info()->data_type()) + { + case DataType::F32: + { + const int ksize_rows_elements = _xmax * _ksize; + const int jump_rows = ksize_rows_elements * window.x().start(); + const int k_start = window.x().start() * _ksize; + const int k_end = std::min(window.x().end() * _ksize, _kmax); + const int stride = _kmax; + if (k_start < k_end) + { + switch (_output_wf) + { + case WeightFormat::OHWIo4: + { + switch (_output->info()->data_type()) + { + case DataType::F32: + arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>( + reinterpret_cast<float *>(_output->buffer()) + jump_rows, + reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + break; + case DataType::BFLOAT16: + arm_gemm::Transform<4, 4, true, arm_gemm::VLType::None>( + reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows, + reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + break; + default: + ARM_COMPUTE_ERROR("Unsupported data type!"); + } + break; + } +#if defined(ARM_COMPUTE_ENABLE_SVE) + case WeightFormat::OHWIo8: + { + switch (_output->info()->data_type()) + { + case DataType::F32: + arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>( + reinterpret_cast<float *>(_output->buffer()) + jump_rows, + reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + break; + case DataType::BFLOAT16: + arm_gemm::Transform<2, 4, true, arm_gemm::VLType::SVE>( + reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows, + reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + break; + default: + ARM_COMPUTE_ERROR("Unsupported data type!"); + } + break; + } +#endif /* ARM_COMPUTE_ENABLE_SVE */ + default: + { + ARM_COMPUTE_ERROR("Unsupported data type!"); + break; + } + } + } + break; + } + default: + ARM_COMPUTE_ERROR("Unsupported data type!"); + } +} + +NEReorderKernel::NEReorderKernel() + : _input(nullptr), + _output(nullptr), + _ksize(0), + _kmax(0), + _xmax(0), + _input_wf(WeightFormat::ANY), + _output_wf(WeightFormat::ANY) +{ +} + +void NEReorderKernel::configure(const ITensor *input, + ITensor *output, + arm_compute::WeightFormat input_wf, + arm_compute::WeightFormat output_wf) +{ + ARM_COMPUTE_LOG_PARAMS(input, output, input_wf, output_wf); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_ERROR_THROW_ON(validate(input->info(), output->info(), input_wf, output_wf)); + + // Set variables + _input = input; + _output = output; + _input_wf = input_wf; + _output_wf = output_wf; + + // Setting parameters for transform + auto dims = input->info()->num_dimensions(); + switch (dims) + { + case 2: + { + _xmax = input->info()->dimension(0); // Number of columns in input matrix + _kmax = input->info()->dimension(1); // Number of rows in input matrix + break; + } + case 4: + { + _xmax = input->info()->dimension(2); // Number of columns in input matrix + _kmax = input->info()->dimension(3); // Number of rows in input matrix + break; + } + default: + { + ARM_COMPUTE_ERROR("Only 2 or 4 dimensions supported."); + } + } + + // Configure kernel window + // Window size is set by rows / _ksize + Window win; + int window_size = 0; + switch (_output_wf) + { +#if defined(ARM_COMPUTE_ENABLE_SVE) + case WeightFormat::OHWIo8: + { + _ksize = 8; + window_size = _kmax / _ksize; + break; + } +#endif /* ARM_COMPUTE_ENABLE_SVE */ + case WeightFormat::OHWIo4: + { + _ksize = 4; + window_size = _kmax / _ksize; + break; + } + default: + { + ARM_COMPUTE_ERROR("Unsupported weight format."); + break; + } + } + if (_kmax % _ksize != 0) + { + window_size += 1; + } + + win.set(Window::DimX, Window::Dimension(0, window_size, 1)); + + INEKernel::configure(win); +} + +Status NEReorderKernel::validate(const ITensorInfo *input, + const ITensorInfo *output, + arm_compute::WeightFormat input_wf, + arm_compute::WeightFormat output_wf) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); + ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN); + if (output->tensor_shape().total_size() != 0) + { + ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != DataType::F32 && output->data_type() != DataType::BFLOAT16); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output); + // Only input WeightFormat OHWI supported + ARM_COMPUTE_RETURN_ERROR_ON(input_wf != arm_compute::WeightFormat::OHWI); + int input_x_dim; + int input_k_dim; + int output_x_dim; + int output_k_dim; + auto dims = output->num_dimensions(); + switch (dims) + { + case 2: + { + input_x_dim = input->dimension(0); // Number of columns in input matrix + input_k_dim = input->dimension(1); // Number of rows in input matrix + output_x_dim = output->dimension(0); // Number of columns in output matrix + output_k_dim = output->dimension(1); // Number of rows in output matrix + break; + } + case 4: + { + input_x_dim = input->dimension(2); // Number of columns in input matrix + input_k_dim = input->dimension(3); // Number of rows in input matrix + output_x_dim = output->dimension(2); // Number of columns in output matrix + output_k_dim = output->dimension(3); // Number of rows in output matrix + break; + } + default: + { + ARM_COMPUTE_RETURN_ERROR_MSG("Only 2 or 4 dimensions supported."); + } + } + + int ksize = 0; + switch (output_wf) + { +#if defined(ARM_COMPUTE_ENABLE_SVE) + case WeightFormat::OHWIo8: + { + if (Scheduler::get().cpu_info().has_sve() && arm_gemm::utils::get_vector_length<float>() == 8) + { + ksize = 8; + } + else + { + ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported weight format."); + } + break; + } +#endif /* ARM_COMPUTE_ENABLE_SVE */ + case WeightFormat::OHWIo4: + { + ksize = 4; + break; + } + default: + { + ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported weight format."); + break; + } + } + + // output k_dim needs to be same as input but multiple of ksize + int32_t rnd_up_input_kdim = arm_compute::ceil_to_multiple<int32_t, int32_t>(input_k_dim, ksize); + ARM_COMPUTE_RETURN_ERROR_ON(rnd_up_input_kdim != output_k_dim); + // output x_dim needs to be same as input + ARM_COMPUTE_RETURN_ERROR_ON(input_x_dim != output_x_dim); + } + return Status{}; +} + +} // namespace arm_compute + +#endif // defined(__aarch64__) |