diff options
author | Renato Arantes <renato.arantes@arm.com> | 2023-11-23 11:12:51 +0000 |
---|---|---|
committer | Jakub Sujak <jakub.sujak@arm.com> | 2024-01-12 10:17:30 +0000 |
commit | 0eb9cfbe6e78f80008164cb0ee18afa09a7fe4eb (patch) | |
tree | 2dcfa82fdc905c91b9263e02815dc3394063ca18 /src/core/NEON/kernels/NEReorderKernel.cpp | |
parent | c5df0c6c5d41a1c4c42ed9b9106d4a2c87689b38 (diff) | |
download | ComputeLibrary-0eb9cfbe6e78f80008164cb0ee18afa09a7fe4eb.tar.gz |
[ONCPUML-1387] Add ACL based reorder for f32 to bf16 data type conversion.
The reorders supported at the moment are:
ab->BA4b4a
ab->BA8b4a
Co-Authored-By: David Mansell <David.Mansell@arm.com>
Change-Id: Ic466465629ce3bcdcee0089e251485b79b60e1f3
Signed-off-by: Renato Arantes <renato.arantes@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10775
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEReorderKernel.cpp')
-rw-r--r-- | src/core/NEON/kernels/NEReorderKernel.cpp | 41 |
1 files changed, 33 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEReorderKernel.cpp b/src/core/NEON/kernels/NEReorderKernel.cpp
index 6c2c987eb7..f5bea3e163 100644
--- a/src/core/NEON/kernels/NEReorderKernel.cpp
+++ b/src/core/NEON/kernels/NEReorderKernel.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -54,17 +54,41 @@ void NEReorderKernel::run(const Window &window, const ThreadInfo &info)
         {
             case WeightFormat::OHWIo4:
             {
-                arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>(
-                    reinterpret_cast<float *>(_output->buffer()) + jump_rows,
-                    reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+                switch (_output->info()->data_type())
+                {
+                    case DataType::F32:
+                        arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>(
+                            reinterpret_cast<float *>(_output->buffer()) + jump_rows,
+                            reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+                        break;
+                    case DataType::BFLOAT16:
+                        arm_gemm::Transform<4, 4, true, arm_gemm::VLType::None>(
+                            reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
+                            reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported data type!");
+                }
                 break;
             }
 #if defined(ARM_COMPUTE_ENABLE_SVE)
             case WeightFormat::OHWIo8:
             {
-                arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>(
-                    reinterpret_cast<float *>(_output->buffer()) + jump_rows,
-                    reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+                switch (_output->info()->data_type())
+                {
+                    case DataType::F32:
+                        arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>(
+                            reinterpret_cast<float *>(_output->buffer()) + jump_rows,
+                            reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+                        break;
+                    case DataType::BFLOAT16:
+                        arm_gemm::Transform<2, 4, true, arm_gemm::VLType::SVE>(
+                            reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
+                            reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+                        break;
+                    default:
+                        ARM_COMPUTE_ERROR("Unsupported data type!");
+                }
                 break;
             }
 #endif /* ARM_COMPUTE_ENABLE_SVE */
@@ -175,7 +199,8 @@ Status NEReorderKernel::validate(const ITensorInfo *input,
     ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
     if (output->tensor_shape().total_size() != 0)
     {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+        ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
+        ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != DataType::F32 && output->data_type() != DataType::BFLOAT16);
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
         // Only input WeightFormat OHWI supported
         ARM_COMPUTE_RETURN_ERROR_ON(input_wf != arm_compute::WeightFormat::OHWI);