aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/NEON/kernels/NEGEMMInterleaveBlockedKernel.cpp80
1 files changed, 62 insertions, 18 deletions
diff --git a/src/core/NEON/kernels/NEGEMMInterleaveBlockedKernel.cpp b/src/core/NEON/kernels/NEGEMMInterleaveBlockedKernel.cpp
index a9c624abd0..2a4a46e76c 100644
--- a/src/core/NEON/kernels/NEGEMMInterleaveBlockedKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMInterleaveBlockedKernel.cpp
@@ -40,10 +40,14 @@ using namespace arm_compute;
namespace
{
-inline void gemm_interleave_8bit_elements(const ITensor *input, ITensor *output, const Window &window, unsigned int block_width, unsigned int block_height, bool transpose)
+inline void gemm_interleave_blocked_transposed_8bit(const ITensor *input, ITensor *output, const Window &window, unsigned int block_width, unsigned int block_height)
{
- const size_t in_stride = input->info()->strides_in_bytes()[1];
- const float scale_y_factor = 1.f / float(block_height);
+ const size_t in_stride = input->info()->strides_in_bytes()[1];
+
+ const unsigned int in_height = input->info()->dimension(1);
+ const unsigned int in_width = input->info()->dimension(0);
+
+ const float scale_y_factor = 1.f / float(block_height);
// Set window for output tensor
Window win_out(window);
@@ -52,30 +56,63 @@ inline void gemm_interleave_8bit_elements(const ITensor *input, ITensor *output,
win_out.set_dimension_step(Window::DimX, block_width * block_height);
Iterator out(output, win_out);
+
+ execute_window_loop(window, [&](const Coordinates &)
+ {
+ std::fill_n(out.ptr(), block_width * block_height, 0);
+ },
+ out);
+
execute_window_loop(window, [&](const Coordinates & id)
{
- int j = 0;
- for(unsigned int z = 0; z < block_height; ++z)
+ for(unsigned int z = id.y(); (z < in_width) && z < (id.y() + block_height); ++z)
{
- for(unsigned int b = 0; b < block_width; ++b)
+ int j = (z - id.y()) * block_width;
+ for(unsigned int b = id.x(); (b < in_height) && (b < (id.x() + block_width)); ++b)
{
- if(!transpose)
- {
- const bool inbounds = (id.x() + b) < input->info()->dimension(0) && (id.y() + z) < input->info()->dimension(1);
- *(out.ptr() + j++) = (inbounds) ? *(in.ptr() + z * in_stride + b) : 0;
- }
- else
- {
- const bool inbounds = (id.x() + b) < input->info()->dimension(1) && (id.y() + z) < input->info()->dimension(0);
- const uint8_t value = (inbounds) ? *(input->buffer() + (id.x() + b) * in_stride + (id.y() + z)) : 0;
- *(out.ptr() + j++) = value;
- }
+ *(out.ptr() + j++) = *(input->buffer() + b * in_stride + z);
}
}
},
in, out);
}
+inline void gemm_interleave_blocked_8bit(const ITensor *input, ITensor *output, const Window &window, unsigned int block_width, unsigned int block_height)
+{
+ const size_t in_stride = input->info()->strides_in_bytes()[1];
+
+ const unsigned int in_height = input->info()->dimension(1);
+ const unsigned int in_width = input->info()->dimension(0);
+
+ const float scale_y_factor = 1.f / float(block_height);
+
+ // Set window for output tensor
+ Window win_out(window);
+ win_out.scale(Window::DimY, scale_y_factor);
+ Iterator in(input, window);
+
+ win_out.set_dimension_step(Window::DimX, block_width * block_height);
+ Iterator out(output, win_out);
+
+ execute_window_loop(window, [&](const Coordinates &)
+ {
+ std::fill_n(out.ptr(), block_width * block_height, 0);
+ },
+ out);
+
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ for(unsigned int z = id.y(); (z < in_height) && z < (id.y() + block_height); ++z)
+ {
+ int j = (z - id.y()) * block_width;
+ for(unsigned int b = id.x(); (b < in_width) && (b < (id.x() + block_width)); ++b)
+ {
+ *(out.ptr() + j++) = *(input->buffer() + z * in_stride + b);
+ }
+ }
+ },
+ in, out);
+}
} // namespace
NEGEMMInterleaveBlockedKernel::NEGEMMInterleaveBlockedKernel()
@@ -127,5 +164,12 @@ void NEGEMMInterleaveBlockedKernel::run(const Window &window, const ThreadInfo &
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- gemm_interleave_8bit_elements(_input, _output, window, _block_width, _block_height, _transpose);
+ if(_transpose)
+ {
+ gemm_interleave_blocked_transposed_8bit(_input, _output, window, _block_width, _block_height);
+ }
+ else
+ {
+ gemm_interleave_blocked_8bit(_input, _output, window, _block_width, _block_height);
+ }
}