aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
diff options
context:
space:
mode:
authorIoan-Cristian Szabo <ioan-cristian.szabo@arm.com>2017-11-30 17:17:17 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:47:40 +0000
commitb4e3e1c371d8091e86ee1c6e704057559bbe1554 (patch)
treed072c9f9d7471e4df9ef5aa6b50cb09c35b0c361 /src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
parentc1b6e37233e0ebd21cb44bf8863a09c0ba5feeb1 (diff)
downloadComputeLibrary-b4e3e1c371d8091e86ee1c6e704057559bbe1554.tar.gz
COMPMID-617: Add validate support for NEON FullyConnectedLayer
Change-Id: I08987022c8d4cc335c00b8af27bd3edb8fe64d3b Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111596 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Alexander Gilday <alexander.gilday@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp85
1 files changed, 71 insertions, 14 deletions
diff --git a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
index aa5e2dd0dd..69b052a9bd 100644
--- a/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
+++ b/src/core/NEON/kernels/NEGEMMMatrixMultiplyKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,6 +36,8 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+
#include <arm_neon.h>
#include <cstddef>
#include <cstdint>
@@ -1409,27 +1411,73 @@ void matrix_matrix_multiply_qs16(const ITensor *input0, const ITensor *input1, I
ina, inb, out);
}
-Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
+inline Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
+ ARM_COMPUTE_UNUSED(alpha);
+
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::F16, DataType::F32, DataType::QS8, DataType::QS16);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1, output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, input1, output);
- ARM_COMPUTE_UNUSED(input0);
- ARM_COMPUTE_UNUSED(input1);
- ARM_COMPUTE_UNUSED(output);
- if(output->dimension(1) == 1)
+ if(!is_interleaved)
{
ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(0) != input1->dimension(1));
+
+ if(output->total_size() != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(input1->dimension(0) != output->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON(input0->dimension(1) != output->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
+ }
+ }
+ else
+ {
+ const int m = reshape_info.m();
+ const int n = reshape_info.n();
+ const int k = reshape_info.k();
+ const int mult_transpose1xW_width = reshape_info.mult_transpose1xW_width();
+ const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
+
+ /* Interleave */
+ TensorShape tensor_shape0{ input0->tensor_shape() };
+ tensor_shape0.set(0, k);
+ tensor_shape0.set(1, m);
+
+ const TensorInfo tensor_info0 = input0->clone()->set_tensor_shape(tensor_shape0);
+ const TensorInfo tensor_info_reshaped0 = input0->clone()->set_tensor_shape(misc::shape_calculator::compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input0, &tensor_info_reshaped0);
+
+ if(n != 0) /* Transpose */
+ {
+ TensorShape tensor_shape1{ input1->tensor_shape() };
+ tensor_shape1.set(0, n);
+ tensor_shape1.set(1, k);
+
+ const TensorInfo tensor_info1 = input1->clone()->set_tensor_shape(tensor_shape1);
+ const TensorInfo tensor_info_reshaped1 = input1->clone()->set_tensor_shape(misc::shape_calculator::compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, &tensor_info_reshaped1);
+ }
+
+ if(output->total_size() != 0)
+ {
+ if(n != 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(0) != static_cast<size_t>(n));
+ }
+ ARM_COMPUTE_RETURN_ERROR_ON(output->dimension(1) != static_cast<size_t>(m));
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input0, output);
+ }
}
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
+inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input0, ITensorInfo *input1, ITensorInfo *output)
{
- Window win = Window();
- bool window_changed = false;
+ bool window_changed{};
+ Window win{};
unsigned int num_elems_processed_per_iteration_x = 0;
const unsigned int num_elems_processed_per_iteration_y = 4;
@@ -1538,11 +1586,19 @@ NEGEMMMatrixMultiplyKernel::NEGEMMMatrixMultiplyKernel()
{
}
-void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha)
+void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor *input1, ITensor *output, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
{
- // Perform validate step
ARM_COMPUTE_ERROR_ON_NULLPTR(input0, input1, output);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info()));
+
+ // Output tensor auto inizialitation if not yet initialized
+ TensorShape tensor_shape{ input0->info()->tensor_shape() };
+ tensor_shape.set(0, is_interleaved ? reshape_info.n() : input1->info()->dimension(0));
+ tensor_shape.set(1, is_interleaved ? reshape_info.m() : input0->info()->dimension(1));
+
+ auto_init_if_empty(*output->info(), input0->info()->clone()->set_tensor_shape(tensor_shape));
+
+ // Perform validate step
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input0->info(), input1->info(), output->info(), alpha, is_interleaved, reshape_info));
_input0 = input0;
_input1 = input1;
@@ -1555,9 +1611,10 @@ void NEGEMMMatrixMultiplyKernel::configure(const ITensor *input0, const ITensor
INEKernel::configure(win_config.second);
}
-Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output)
+Status NEGEMMMatrixMultiplyKernel::validate(const ITensorInfo *input0, const ITensorInfo *input1, const ITensorInfo *output, float alpha, bool is_interleaved,
+ const GEMMReshapeInfo &reshape_info)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input0, input1, output, alpha, is_interleaved, reshape_info));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input0->clone().get(), input1->clone().get(), output->clone().get()).first);
return Status{};