about summary refs log tree commit diff
path: root/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp')
-rw-r--r-- src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp 41
1 files changed, 11 insertions, 30 deletions
diff --git a/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp b/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp
index 5bf70dc9bf..5406356bc9 100644
--- a/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp
+++ b/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp
@@ -81,22 +81,6 @@ Status CpuConvertFullyConnectedWeightsKernel::validate(const ITensorInfo *src, c
return Status{};
}
-template <typename T>
-void CpuConvertFullyConnectedWeightsKernel::run_convert_fc_weights(const ITensor *in, ITensor *out, const Window &window)
-{
- const unsigned int dst_stride_x = out->info()->strides_in_bytes().x();
- const unsigned int dst_stride_y = out->info()->strides_in_bytes().y();
-
- Iterator input(in, window);
- Iterator output(out, window);
-
- execute_window_loop(window, [&](const Coordinates & id)
- {
- *reinterpret_cast<T *>(output.ptr() + id.x() * dst_stride_x + (id.y() % _factor1 * _factor2 + id.y() / _factor1) * dst_stride_y) = *reinterpret_cast<T *>(input.ptr());
- },
- input);
-}
-
void CpuConvertFullyConnectedWeightsKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
@@ -106,21 +90,18 @@ void CpuConvertFullyConnectedWeightsKernel::run_op(ITensorPack &tensors, const W
const auto src = tensors.get_const_tensor(TensorType::ACL_SRC);
auto dst = tensors.get_tensor(TensorType::ACL_DST);
- switch(src->info()->element_size())
+ const unsigned int dst_stride_x = dst->info()->strides_in_bytes().x();
+ const unsigned int dst_stride_y = dst->info()->strides_in_bytes().y();
+ const unsigned int element_size = src->info()->element_size();
+
+ Iterator input(src, window);
+ Iterator output(dst, window);
+
+ execute_window_loop(window, [&](const Coordinates & id)
{
- case 1:
- run_convert_fc_weights<uint8_t>(src, dst, window);
- break;
- case 2:
- run_convert_fc_weights<uint16_t>(src, dst, window);
- break;
- case 4:
- run_convert_fc_weights<uint32_t>(src, dst, window);
- break;
- default:
- ARM_COMPUTE_ERROR("Data type not supported.");
- break;
- }
+ memcpy(output.ptr() + id.x() * dst_stride_x + (id.y() % _factor1 * _factor2 + id.y() / _factor1) * dst_stride_y, input.ptr(), element_size);
+ },
+ input);
}
const char *CpuConvertFullyConnectedWeightsKernel::name() const