aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEReorderKernel.cpp
diff options
context:
space:
mode:
author: Renato Arantes <renato.arantes@arm.com> 2023-11-23 11:12:51 +0000
committer: Jakub Sujak <jakub.sujak@arm.com> 2024-01-12 10:17:30 +0000
commit: 0eb9cfbe6e78f80008164cb0ee18afa09a7fe4eb (patch)
tree: 2dcfa82fdc905c91b9263e02815dc3394063ca18 /src/core/NEON/kernels/NEReorderKernel.cpp
parent: c5df0c6c5d41a1c4c42ed9b9106d4a2c87689b38 (diff)
download: ComputeLibrary-0eb9cfbe6e78f80008164cb0ee18afa09a7fe4eb.tar.gz
[ONCPUML-1387] Add ACL based reorder for f32 to bf16 data type conversion.
The reorders supported at the moment are: ab->BA4b4a and ab->BA8b4a.

Co-Authored-By: David Mansell <David.Mansell@arm.com>
Change-Id: Ic466465629ce3bcdcee0089e251485b79b60e1f3
Signed-off-by: Renato Arantes <renato.arantes@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/10775
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Jakub Sujak <jakub.sujak@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEReorderKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEReorderKernel.cpp 41
1 files changed, 33 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEReorderKernel.cpp b/src/core/NEON/kernels/NEReorderKernel.cpp
index 6c2c987eb7..f5bea3e163 100644
--- a/src/core/NEON/kernels/NEReorderKernel.cpp
+++ b/src/core/NEON/kernels/NEReorderKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,17 +54,41 @@ void NEReorderKernel::run(const Window &window, const ThreadInfo &info)
{
case WeightFormat::OHWIo4:
{
- arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>(
- reinterpret_cast<float *>(_output->buffer()) + jump_rows,
- reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ switch (_output->info()->data_type())
+ {
+ case DataType::F32:
+ arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>(
+ reinterpret_cast<float *>(_output->buffer()) + jump_rows,
+ reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ break;
+ case DataType::BFLOAT16:
+ arm_gemm::Transform<4, 4, true, arm_gemm::VLType::None>(
+ reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
+ reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type!");
+ }
break;
}
#if defined(ARM_COMPUTE_ENABLE_SVE)
case WeightFormat::OHWIo8:
{
- arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>(
- reinterpret_cast<float *>(_output->buffer()) + jump_rows,
- reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ switch (_output->info()->data_type())
+ {
+ case DataType::F32:
+ arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>(
+ reinterpret_cast<float *>(_output->buffer()) + jump_rows,
+ reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ break;
+ case DataType::BFLOAT16:
+ arm_gemm::Transform<2, 4, true, arm_gemm::VLType::SVE>(
+ reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
+ reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type!");
+ }
break;
}
#endif /* ARM_COMPUTE_ENABLE_SVE */
@@ -175,7 +199,8 @@ Status NEReorderKernel::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
if (output->tensor_shape().total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != DataType::F32 && output->data_type() != DataType::BFLOAT16);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
// Only input WeightFormat OHWI supported
ARM_COMPUTE_RETURN_ERROR_ON(input_wf != arm_compute::WeightFormat::OHWI);