diff options
Diffstat (limited to 'src')
3 files changed, 12 insertions, 39 deletions
diff --git a/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp b/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp index 5bf70dc9bf..5406356bc9 100644 --- a/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp +++ b/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.cpp @@ -81,22 +81,6 @@ Status CpuConvertFullyConnectedWeightsKernel::validate(const ITensorInfo *src, c return Status{}; } -template <typename T> -void CpuConvertFullyConnectedWeightsKernel::run_convert_fc_weights(const ITensor *in, ITensor *out, const Window &window) -{ - const unsigned int dst_stride_x = out->info()->strides_in_bytes().x(); - const unsigned int dst_stride_y = out->info()->strides_in_bytes().y(); - - Iterator input(in, window); - Iterator output(out, window); - - execute_window_loop(window, [&](const Coordinates & id) - { - *reinterpret_cast<T *>(output.ptr() + id.x() * dst_stride_x + (id.y() % _factor1 * _factor2 + id.y() / _factor1) * dst_stride_y) = *reinterpret_cast<T *>(input.ptr()); - }, - input); -} - void CpuConvertFullyConnectedWeightsKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); @@ -106,21 +90,18 @@ void CpuConvertFullyConnectedWeightsKernel::run_op(ITensorPack &tensors, const W const auto src = tensors.get_const_tensor(TensorType::ACL_SRC); auto dst = tensors.get_tensor(TensorType::ACL_DST); - switch(src->info()->element_size()) + const unsigned int dst_stride_x = dst->info()->strides_in_bytes().x(); + const unsigned int dst_stride_y = dst->info()->strides_in_bytes().y(); + const unsigned int element_size = src->info()->element_size(); + + Iterator input(src, window); + Iterator output(dst, window); + + execute_window_loop(window, [&](const Coordinates & id) { - case 1: - run_convert_fc_weights<uint8_t>(src, dst, window); - break; - case 2: - run_convert_fc_weights<uint16_t>(src, dst, window); - break; - case 4: - run_convert_fc_weights<uint32_t>(src, dst, window); - break; - default: - ARM_COMPUTE_ERROR("Data type not supported."); - break; - } + memcpy(output.ptr() + id.x() * dst_stride_x + (id.y() % _factor1 * _factor2 + id.y() / _factor1) * dst_stride_y, input.ptr(), element_size); + }, + input); } const char *CpuConvertFullyConnectedWeightsKernel::name() const diff --git a/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.h b/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.h index 3ba3162c34..7baaf13417 100644 --- a/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.h +++ b/src/core/cpu/kernels/CpuConvertFullyConnectedWeightsKernel.h @@ -69,15 +69,6 @@ public: private: unsigned int _factor1{ 0 }; /* equals to the number of elements per original src plane if @p data_layout == NCHW; its number of channels otherwise */ unsigned int _factor2{ 0 }; /* equals to the number of elements per original src plane if @p data_layout == NHWC; its number of channels otherwise */ - - /** Template function to run the permute - * - * @param[in] in Source weights tensor info to convert. Must be 2 dimensional. Data types supported: All. - * @param[in] out The converted weights tensor info. Shape and Data Type: Same as @p in. - * @param[in] window Region on which to execute the kernel. (Must be a valid region of the window returned by window()). - */ - template <typename T> - void run_convert_fc_weights(const ITensor *in, ITensor *out, const Window &window); }; } // namespace kernels } // namespace cpu diff --git a/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp b/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp index f2253d8be4..1f6b3c94e2 100644 --- a/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp +++ b/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp @@ -23,6 +23,7 @@ */ #include "arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h" +#include "arm_compute/core/Validate.h" #include "src/runtime/cpu/operators/CpuConvertFullyConnectedWeights.h" namespace arm_compute |