author     morgolock <pablo.tello@arm.com>      2020-09-29 14:24:32 +0100
committer  Pablo Marquez <pablo.tello@arm.com>  2020-10-08 17:42:54 +0000
commit     4adaddbaa633a4025f29f2e0a63c7126d9d7c530 (patch)
tree       509da75143dcb2743a8eea2cc11f0a03c180c737
parent     ff4fca0d2ae523557a7b31db2014b48391f1d8c3 (diff)
COMPMID-3170: Remove padding in NEGEMMLowpMatrixMultiplyKernel
Change-Id: Ie95442c6c6a145c1a45937b03cbd433bf08e36ab
Signed-off-by: morgolock <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4094
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp')
-rw-r--r--  src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp | 337
1 file changed, 227 insertions(+), 110 deletions(-)
diff --git a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
index c5d7f10e55..f3ba2901cb 100644
--- a/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.cpp
@@ -23,7 +23,6 @@
*/
#include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
-#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
@@ -32,11 +31,7 @@
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
-
#include <arm_neon.h>
-#include <cstddef>
-#include <cstdint>
-#include <tuple>
using namespace arm_compute;
@@ -44,7 +39,7 @@ namespace arm_compute
{
namespace
{
-void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
+void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
execute_window_loop(window, [&](const Coordinates & id)
{
@@ -253,15 +248,29 @@ void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &ou
}
auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
- vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
- vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
- vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
- vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
+ if(id.x() < (width_out - 16))
+ {
+ vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
+ vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
+ vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
+ vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
+ }
+ else
+ {
+ auto left_over = width_out - id.x();
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(vec_out + k * 4 + j) = c0.val[k][j];
+ }
+ }
+ }
},
ina, inb, out);
}
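
The signed variant below applies the identical guard. Condensed into a standalone helper (a minimal sketch of the pattern, not the library's code; store_clipped and its parameters are illustrative), the store phase introduced by this hunk is:

#include <arm_neon.h>
#include <cstdint>

// Store one 16-wide block of int32 results starting at column x, clipping at width_out.
inline void store_clipped(int32_t *dst, const int32x4x4_t &acc, int x, int width_out)
{
    if(x < width_out - 16)
    {
        // The whole block fits: four full NEON stores.
        vst1q_s32(dst + 0, acc.val[0]);
        vst1q_s32(dst + 4, acc.val[1]);
        vst1q_s32(dst + 8, acc.val[2]);
        vst1q_s32(dst + 12, acc.val[3]);
    }
    else
    {
        // Edge block: scalar stores for the 1..16 elements still in bounds,
        // so no write ever lands in the (now removed) output padding.
        int left_over = width_out - x;
        for(int k = 0; k < 4 && left_over; ++k)
        {
            for(int j = 0; j < 4 && left_over; ++j, --left_over)
            {
                dst[k * 4 + j] = acc.val[k][j];
            }
        }
    }
}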
-void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, size_t stride_b, const Window &window)
+void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
execute_window_loop(window, [&](const Coordinates & id)
{
@@ -469,17 +478,34 @@ void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &ou
}
auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
- vst1q_s32(vec_out + 0, c0.val[0]);
- vst1q_s32(vec_out + 4, c0.val[1]);
- vst1q_s32(vec_out + 8, c0.val[2]);
- vst1q_s32(vec_out + 12, c0.val[3]);
+ if(id.x() < (width_out - 16))
+ {
+ vst1q_s32(vec_out + 0, c0.val[0]);
+ vst1q_s32(vec_out + 4, c0.val[1]);
+ vst1q_s32(vec_out + 8, c0.val[2]);
+ vst1q_s32(vec_out + 12, c0.val[3]);
+ }
+ else
+ {
+ auto left_over = width_out - id.x();
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(vec_out + k * 4 + j) = c0.val[k][j];
+ }
+ }
+ }
},
ina, inb, out);
}
-void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
+void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
- execute_window_loop(window, [&](const Coordinates &)
+ const auto width_out = static_cast<int>(out_info.dimension(0));
+ const auto height_out = static_cast<int>(out_info.dimension(1));
+ const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
+ execute_window_loop(window, [&](const Coordinates & id)
{
const uint8_t *mtx_a0 = ina.ptr();
const uint8_t *mtx_b0 = inb.ptr();
@@ -574,32 +600,93 @@ void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int
}
auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
- vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
- vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
- vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
- vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
- vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
- vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
- vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
- vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
- vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
- vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
- vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
- vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
- vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
- vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
- vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
- vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
+
+ if(id.y() < height_out && id.x() < (width_out - 16))
+ {
+ vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
+ vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
+ vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
+ vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
+ if(id.y() + 1 < height_out)
+ {
+ vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
+ vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
+ vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
+ vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
+ if(id.y() + 2 < height_out)
+ {
+ vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
+ vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
+ vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
+ vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
+ if(id.y() + 3 < height_out)
+ {
+ vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
+ vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
+ vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
+ vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
+ }
+ }
+ }
+ }
+ else
+ {
+ const auto left_over_value = width_out - id.x();
+ auto left_over = left_over_value;
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(mtx_out + k * 4 + j) = c0.val[k][j];
+ }
+ }
+ if(id.y() + 1 < height_out)
+ {
+ left_over = left_over_value;
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
+ }
+ }
+ if(id.y() + 2 < height_out)
+ {
+ left_over = left_over_value;
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
+ }
+ }
+ if(id.y() + 3 < height_out)
+ {
+ left_over = left_over_value;
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
+ }
+ }
+ }
+ }
+ }
+ }
},
ina, inb, out);
}
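
The nested if-ladder above is the row-wise version of the same idea: row r of the 4x16 tile is stored only while id.y() + r stays below height_out, and each stored row clips its columns exactly as in the vector path. Flattened into loops (an equivalent sketch; store_tile and its parameters are illustrative, not the library's code):

#include <arm_neon.h>
#include <cstdint>

// acc[r] holds row r of one 4x16 int32 output tile starting at (x, y).
inline void store_tile(int32_t *dst, size_t out_stride, const int32x4x4_t acc[4],
                       int x, int y, int width_out, int height_out)
{
    const int rows = (height_out - y) < 4 ? (height_out - y) : 4;
    for(int r = 0; r < rows; ++r)
    {
        int32_t *row = dst + r * out_stride;
        if(x < width_out - 16)
        {
            for(int k = 0; k < 4; ++k)
            {
                vst1q_s32(row + 4 * k, acc[r].val[k]); // full 16-wide row
            }
        }
        else
        {
            for(int e = 0; e < width_out - x; ++e)
            {
                row[e] = acc[r].val[e / 4][e % 4];     // clipped row, scalar
            }
        }
    }
}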
-void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, size_t out_stride, const Window &window)
+void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
+ const auto width_out = static_cast<int>(out_info.dimension(0));
+ const auto height_out = static_cast<int>(out_info.dimension(1));
+ const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
// The implementation assumes that matrix A and matrix B have been reshaped with NEGEMMInterleave4x4 and NEGEMMTranspose1xW respectively
// The reshaping makes the implementation cache friendly and avoids the data re-arrangement needed to compute 16x4 elements per iteration
// All the values needed to compute a single 4x4 block are read from consecutive memory positions
- execute_window_loop(window, [&](const Coordinates &)
+ execute_window_loop(window, [&](const Coordinates & id)
{
auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());
@@ -692,32 +779,86 @@ void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int
c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
}
-
auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
- vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
- vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
- vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
- vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
- vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
- vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
- vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
- vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
- vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
- vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
- vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
- vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
- vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
- vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
- vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
- vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
+ if(id.y() < height_out && id.x() < (width_out - 16))
+ {
+ vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
+ vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
+ vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
+ vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
+ if(id.y() + 1 < height_out)
+ {
+ vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
+ vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
+ vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
+ vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
+ if(id.y() + 2 < height_out)
+ {
+ vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
+ vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
+ vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
+ vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
+ if(id.y() + 3 < height_out)
+ {
+ vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
+ vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
+ vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
+ vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
+ }
+ }
+ }
+ }
+ else if(id.y() < height_out)
+ {
+ const auto left_over_value = width_out - id.x();
+ auto left_over = left_over_value;
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(mtx_out + k * 4 + j) = c0.val[k][j];
+ }
+ }
+ if(id.y() + 1 < height_out)
+ {
+ left_over = left_over_value;
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
+ }
+ }
+ if(id.y() + 2 < height_out)
+ {
+ left_over = left_over_value;
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
+ }
+ }
+ if(id.y() + 3 < height_out)
+ {
+ left_over = left_over_value;
+ for(auto k = 0; k < 4 && left_over; ++k)
+ {
+ for(auto j = 0; j < 4 && left_over; ++j, --left_over)
+ {
+ *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
+ }
+ }
+ }
+ }
+ }
+ }
+
},
ina, inb, out);
}
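
The comment at the top of this function describes the layout the loads rely on: after NEGEMMInterleave4x4, four consecutive bytes of A hold the same column of four consecutive rows (which is also why run() later scales matrix A's Y range by 4), and after NEGEMMTranspose1xW, sixteen consecutive bytes of B hold one row chunk feeding the four b00 q-registers. A toy re-creation of the interleave, for illustration only (the real reshape is performed by the NEGEMMInterleave4x4 kernel, which also handles edges and padding):

#include <cstdint>
#include <vector>

// Pack each band of four A rows column by column, so that byte offset 4*c
// within a band holds a(r,c), a(r+1,c), a(r+2,c), a(r+3,c) back to back.
std::vector<int8_t> interleave4x4(const int8_t *a, int rows, int cols)
{
    std::vector<int8_t> out;
    out.reserve(static_cast<size_t>(rows) * cols);
    for(int r = 0; r < rows; r += 4) // assumes rows % 4 == 0 for brevity
    {
        for(int c = 0; c < cols; ++c)
        {
            for(int k = 0; k < 4; ++k)
            {
                out.push_back(a[(r + k) * cols + c]);
            }
        }
    }
    return out;
}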
} // namespace
-class Coordinates;
-} // namespace arm_compute
-
namespace
{
Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
@@ -748,50 +889,6 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
return Status{};
}
-
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
-{
- constexpr unsigned int num_elems_processed_per_iteration_x = 16;
- constexpr unsigned int num_elems_processed_per_iteration_y = 4;
-
- Window win;
- bool window_changed = false;
-
- // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
- if((output->dimension(1) == 1))
- {
- // Configure kernel window
- win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x));
-
- // We cannot read out-of-bound elements from matrix A as we use the left-over for loop
- AccessWindowStatic in0_access(input0, 0, 0, input0->tensor_shape().x(), 1);
- AccessWindowHorizontal in1_access(input1, 0, num_elems_processed_per_iteration_x);
- AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration_x);
-
- window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);
-
- Coordinates coord;
- coord.set_num_dimensions(output->num_dimensions());
- output_access.set_valid_region(win, ValidRegion(coord, output->tensor_shape()));
- }
- else
- {
- win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
-
- unsigned int num_k_iterations = ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x) / 16;
- // For each iteration of "k" we increment the input pointer by 4, and we load 8 elements at a time:
- AccessWindowStatic in0_access(input0, 0, 0, (num_k_iterations - 1) * 4 + 8, input0->dimension(1));
- AccessWindowStatic in1_access(input1, 0, 0, ceil_to_multiple(input1->dimension(0), num_elems_processed_per_iteration_x), input1->dimension(1));
- AccessWindowRectangle output_access(output, 0, 0, num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y);
-
- window_changed = update_window_and_padding(win, in0_access, in1_access, output_access);
-
- output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
- }
-
- Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
- return std::make_pair(err, win);
-}
} // namespace
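
For reference, the removed function above is what used to impose the padding: it grew the input access windows beyond the tensor shapes and reported "Insufficient Padding!" when they could not be satisfied. Plugging its own formulas into a small program shows the over-read it demanded (illustrative numbers only):

#include <cstdio>

int main()
{
    const int width_b = 20;                          // actual columns of B
    const int padded_b = ((width_b + 15) / 16) * 16; // ceil_to_multiple(width_b, 16)
    const int num_k_iterations = padded_b / 16;
    const int cols_a_read = (num_k_iterations - 1) * 4 + 8;
    std::printf("B width %d padded to %d; %d columns of A accessed\n",
                width_b, padded_b, cols_a_read);
    return 0;
}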
NEGEMMLowpMatrixMultiplyKernel::NEGEMMLowpMatrixMultiplyKernel()
@@ -812,16 +909,33 @@ void NEGEMMLowpMatrixMultiplyKernel::configure(const ITensor *input0, const ITen
_output = output;
_slide_matrix_b = in1_shape[2] != 1;
- // Configure kernel window
- auto win_config = validate_and_configure_window(input0->info(), input1->info(), output->info());
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
- INEKernel::configure(win_config.second);
+ constexpr unsigned int num_elems_processed_per_iteration_x = 16;
+ constexpr unsigned int num_elems_processed_per_iteration_y = 4;
+
+ Window win;
+
+ // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication
+ if((output->info()->dimension(1) == 1))
+ {
+ // Configure kernel window
+ win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x));
+
+ Coordinates coord;
+ coord.set_num_dimensions(output->info()->num_dimensions());
+ output->info()->set_valid_region(ValidRegion(coord, output->info()->tensor_shape()));
+ }
+ else
+ {
+ win = calculate_max_window(*output->info(), Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
+ output->info()->set_valid_region(ValidRegion(Coordinates(), output->info()->tensor_shape()));
+ }
+
+ INEKernel::configure(win);
}
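
Note that calculate_max_window still rounds the window up to the 16x4 steps, so the execution window can extend past the tensor edge; the guards added in the store code are what make that safe without padded buffers. A toy illustration of the rounding (assuming ceiling rounding to the step, as the window calculation performs):

#include <cstdio>

int main()
{
    const int width_out = 20, step_x = 16;
    const int end_x = ((width_out + step_x - 1) / step_x) * step_x; // rounded to 32
    for(int x = 0; x < end_x; x += step_x)
    {
        const int valid = (width_out - x) < step_x ? (width_out - x) : step_x;
        std::printf("block at x=%2d: %2d of %d lanes valid\n", x, valid, step_x);
    }
    return 0;
}
// block at x= 0: 16 of 16 lanes valid
// block at x=16:  4 of 16 lanes valid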
Status NEGEMMLowpMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);
return Status{};
}
@@ -837,6 +951,7 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
{
const auto width_matrix_a = static_cast<int>(_input0->info()->dimension(0));
const auto width_matrix_b = static_cast<int>(_input1->info()->dimension(0));
+ const auto width_out = static_cast<int>(_output->info()->dimension(0));
const auto in_b_stride = static_cast<int>(_input1->info()->strides_in_bytes()[1] / data_size_from_type(_input1->info()->data_type()));
// The implementation computes 16 elements per iteration
@@ -872,13 +987,13 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
case DataType::S8:
case DataType::QASYMM8_SIGNED:
{
- vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
+ vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
break;
}
case DataType::U8:
case DataType::QASYMM8:
{
- vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, in_b_stride, window);
+ vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
break;
}
default:
@@ -891,7 +1006,7 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
else
{
const size_t in_b_stride = _input1->info()->strides_in_bytes()[1];
- const size_t out_stride = _output->info()->strides_in_bytes()[1] / _output->info()->element_size();
+ const int width_b = _input1->info()->dimension(0);
// Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the output matrix
Window win_a(window);
@@ -914,19 +1029,18 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
Iterator inb(_input1, win_b);
Iterator out(_output, window);
- const int width_b = _input1->info()->dimension(0);
switch(_input0->info()->data_type())
{
case DataType::S8:
case DataType::QASYMM8_SIGNED:
{
- matrix_multiply_s8(ina, inb, out, width_b, out_stride, window);
+ matrix_multiply_s8(ina, inb, out, width_b, *_output->info(), window);
break;
}
case DataType::U8:
case DataType::QASYMM8:
{
- matrix_multiply_u8(ina, inb, out, width_b, out_stride, window);
+ matrix_multiply_u8(ina, inb, out, width_b, *_output->info(), window);
break;
}
default:
@@ -937,3 +1051,6 @@ void NEGEMMLowpMatrixMultiplyKernel::run(const Window &window, const ThreadInfo
}
}
}
+} // namespace arm_compute
+
+