aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuGemmTranspose1xWKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuGemmTranspose1xWKernel.cpp')
-rw-r--r-- src/cpu/kernels/CpuGemmTranspose1xWKernel.cpp 43
1 file changed, 24 insertions, 19 deletions
diff --git a/src/cpu/kernels/CpuGemmTranspose1xWKernel.cpp b/src/cpu/kernels/CpuGemmTranspose1xWKernel.cpp
index 62d5d5f5e9..c47746bc4b 100644
--- a/src/cpu/kernels/CpuGemmTranspose1xWKernel.cpp
+++ b/src/cpu/kernels/CpuGemmTranspose1xWKernel.cpp
@@ -24,9 +24,10 @@
#include "src/cpu/kernels/CpuGemmTranspose1xWKernel.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
@@ -63,9 +64,10 @@ Status CpuGemmTranspose1xWKernel::validate(const ITensorInfo *src, const ITensor
ARM_COMPUTE_RETURN_ERROR_ON(src->data_type() == DataType::UNKNOWN);
//Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src) is not needed here as this kernel doesn't use CPU FP16 instructions.
- if(dst->total_size() != 0)
+ if (dst->total_size() != 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), compute_transpose1xW_with_element_size_shape(*src));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(),
+ compute_transpose1xW_with_element_size_shape(*src));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_QUANTIZATION_INFO(src, dst);
}
@@ -107,25 +109,28 @@ void CpuGemmTranspose1xWKernel::run_op(ITensorPack &tensors, const Window &windo
const size_t out_stride = dst->info()->strides_in_bytes()[1];
const size_t vector_size = 16 / element_size;
- execute_window_loop(window, [&](const Coordinates & id)
- {
- const uint8_t *in_ptr = in.ptr();
- uint8_t *const out_ptr = out.ptr() + (id.y() * vector_size) * element_size + (id.x() / vector_size) * out_stride;
-
- for(size_t k = 0; k < vector_size; ++k)
+ execute_window_loop(
+ window,
+ [&](const Coordinates &id)
{
- // If the src width is not multiple of W, we fill the reference with 0s
- if((id.x() + k) >= in_width)
- {
- std::memset(out_ptr + k * element_size, 0, element_size);
- }
- else
+ const uint8_t *in_ptr = in.ptr();
+ uint8_t *const out_ptr =
+ out.ptr() + (id.y() * vector_size) * element_size + (id.x() / vector_size) * out_stride;
+
+ for (size_t k = 0; k < vector_size; ++k)
{
- std::memcpy(out_ptr + k * element_size, in_ptr + k * element_size, element_size);
+ // If the src width is not multiple of W, we fill the reference with 0s
+ if ((id.x() + k) >= in_width)
+ {
+ std::memset(out_ptr + k * element_size, 0, element_size);
+ }
+ else
+ {
+ std::memcpy(out_ptr + k * element_size, in_ptr + k * element_size, element_size);
+ }
}
- }
- },
- in, out);
+ },
+ in, out);
}
const char *CpuGemmTranspose1xWKernel::name() const