diff options
Diffstat (limited to 'src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp')
-rw-r--r-- | src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp | 41 |
1 files changed, 11 insertions, 30 deletions
diff --git a/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp b/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp index 5bf70dc9bf..5406356bc9 100644 --- a/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp +++ b/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp @@ -81,22 +81,6 @@ Status CpuConvertFullyConnectedWeightsKernel::validate(const ITensorInfo *src, c return Status{}; } -template <typename T> -void CpuConvertFullyConnectedWeightsKernel::run_convert_fc_weights(const ITensor *in, ITensor *out, const Window &window) -{ - const unsigned int dst_stride_x = out->info()->strides_in_bytes().x(); - const unsigned int dst_stride_y = out->info()->strides_in_bytes().y(); - - Iterator input(in, window); - Iterator output(out, window); - - execute_window_loop(window, [&](const Coordinates & id) - { - *reinterpret_cast<T *>(output.ptr() + id.x() * dst_stride_x + (id.y() % _factor1 * _factor2 + id.y() / _factor1) * dst_stride_y) = *reinterpret_cast<T *>(input.ptr()); - }, - input); -} - void CpuConvertFullyConnectedWeightsKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); @@ -106,21 +90,18 @@ void CpuConvertFullyConnectedWeightsKernel::run_op(ITensorPack &tensors, const W const auto src = tensors.get_const_tensor(TensorType::ACL_SRC); auto dst = tensors.get_tensor(TensorType::ACL_DST); - switch(src->info()->element_size()) + const unsigned int dst_stride_x = dst->info()->strides_in_bytes().x(); + const unsigned int dst_stride_y = dst->info()->strides_in_bytes().y(); + const unsigned int element_size = src->info()->element_size(); + + Iterator input(src, window); + Iterator output(dst, window); + + execute_window_loop(window, [&](const Coordinates & id) { - case 1: - run_convert_fc_weights<uint8_t>(src, dst, window); - break; - case 2: - run_convert_fc_weights<uint16_t>(src, dst, window); - break; - case 4: - run_convert_fc_weights<uint32_t>(src, dst, window); - break; - default: - ARM_COMPUTE_ERROR("Data type not supported."); - break; - } + memcpy(output.ptr() + id.x() * dst_stride_x + (id.y() % _factor1 * _factor2 + id.y() / _factor1) * dst_stride_y, input.ptr(), element_size); + }, + input); } const char *CpuConvertFullyConnectedWeightsKernel::name() const |