diff options
Diffstat (limited to 'src/core/NEON/kernels/NEReorderKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEReorderKernel.cpp | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEReorderKernel.cpp b/src/core/NEON/kernels/NEReorderKernel.cpp index 6c2c987eb7..f5bea3e163 100644 --- a/src/core/NEON/kernels/NEReorderKernel.cpp +++ b/src/core/NEON/kernels/NEReorderKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023 Arm Limited. + * Copyright (c) 2023-2024 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -54,17 +54,41 @@ void NEReorderKernel::run(const Window &window, const ThreadInfo &info) { case WeightFormat::OHWIo4: { - arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>( - reinterpret_cast<float *>(_output->buffer()) + jump_rows, - reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + switch (_output->info()->data_type()) + { + case DataType::F32: + arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>( + reinterpret_cast<float *>(_output->buffer()) + jump_rows, + reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + break; + case DataType::BFLOAT16: + arm_gemm::Transform<4, 4, true, arm_gemm::VLType::None>( + reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows, + reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + break; + default: + ARM_COMPUTE_ERROR("Unsupported data type!"); + } break; } #if defined(ARM_COMPUTE_ENABLE_SVE) case WeightFormat::OHWIo8: { - arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>( - reinterpret_cast<float *>(_output->buffer()) + jump_rows, - reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + switch (_output->info()->data_type()) + { + case DataType::F32: + arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>( + reinterpret_cast<float *>(_output->buffer()) + jump_rows, + reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + break; + case DataType::BFLOAT16: + arm_gemm::Transform<2, 4, true, arm_gemm::VLType::SVE>( + reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows, + reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax); + break; + default: + ARM_COMPUTE_ERROR("Unsupported data type!"); + } break; } #endif /* ARM_COMPUTE_ENABLE_SVE */ @@ -175,7 +199,8 @@ Status NEReorderKernel::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN); if (output->tensor_shape().total_size() != 0) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != DataType::F32 && output->data_type() != DataType::BFLOAT16); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output); // Only input WeightFormat OHWI supported ARM_COMPUTE_RETURN_ERROR_ON(input_wf != arm_compute::WeightFormat::OHWI); |