aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEReorderKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEReorderKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEReorderKernel.cpp 41
1 files changed, 33 insertions, 8 deletions
diff --git a/src/core/NEON/kernels/NEReorderKernel.cpp b/src/core/NEON/kernels/NEReorderKernel.cpp
index 6c2c987eb7..f5bea3e163 100644
--- a/src/core/NEON/kernels/NEReorderKernel.cpp
+++ b/src/core/NEON/kernels/NEReorderKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,17 +54,41 @@ void NEReorderKernel::run(const Window &window, const ThreadInfo &info)
{
case WeightFormat::OHWIo4:
{
- arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>(
- reinterpret_cast<float *>(_output->buffer()) + jump_rows,
- reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ switch (_output->info()->data_type())
+ {
+ case DataType::F32:
+ arm_gemm::Transform<4, 1, true, arm_gemm::VLType::None>(
+ reinterpret_cast<float *>(_output->buffer()) + jump_rows,
+ reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ break;
+ case DataType::BFLOAT16:
+ arm_gemm::Transform<4, 4, true, arm_gemm::VLType::None>(
+ reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
+ reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type!");
+ }
break;
}
#if defined(ARM_COMPUTE_ENABLE_SVE)
case WeightFormat::OHWIo8:
{
- arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>(
- reinterpret_cast<float *>(_output->buffer()) + jump_rows,
- reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ switch (_output->info()->data_type())
+ {
+ case DataType::F32:
+ arm_gemm::Transform<1, 1, true, arm_gemm::VLType::SVE>(
+ reinterpret_cast<float *>(_output->buffer()) + jump_rows,
+ reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ break;
+ case DataType::BFLOAT16:
+ arm_gemm::Transform<2, 4, true, arm_gemm::VLType::SVE>(
+ reinterpret_cast<bfloat16 *>(_output->buffer()) + jump_rows,
+ reinterpret_cast<float *>(_input->buffer()), stride, k_start, k_end, 0, _xmax);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type!");
+ }
break;
}
#endif /* ARM_COMPUTE_ENABLE_SVE */
@@ -175,7 +199,8 @@ Status NEReorderKernel::validate(const ITensorInfo *input,
ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
if (output->tensor_shape().total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() != DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != DataType::F32 && output->data_type() != DataType::BFLOAT16);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(input, output);
// Only input WeightFormat OHWI supported
ARM_COMPUTE_RETURN_ERROR_ON(input_wf != arm_compute::WeightFormat::OHWI);