aboutsummaryrefslogtreecommitdiff
path: root/src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2018-04-13 14:28:08 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:54 +0000
commit164b65d3c8f61f1d6d404fb484c1998a20a2cbda (patch)
treeb60b9f49066ca8c008726dd193e4e0bd56ac1168 /src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp
parent0cbb927ac309e332ac6e6f1ab9170f041f0138ab (diff)
downloadComputeLibrary-164b65d3c8f61f1d6d404fb484c1998a20a2cbda.tar.gz
COMPMID-1043: Rework GCGEMMMatrixMultiplyKernel interface and allow auto initialization of the tensors
This patch also: - removes support for already reshaped weights in GCConvolutionLayer - makes GCConvolutionLayer similar to CLGEMMConvolutionLayer - enables usage of the GCGEMM function in GCConvolution instead of calling the GEMM kernels directly Change-Id: I3e4a64335555e86e18585d38d8fda4bfdb44e265 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127696 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp')
-rw-r--r--src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp11
1 files changed, 3 insertions, 8 deletions
diff --git a/src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp b/src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp
index 4c08873dcf..55bf9b754b 100644
--- a/src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp
+++ b/src/core/GLES_COMPUTE/kernels/GCWeightsReshapeKernel.cpp
@@ -31,11 +31,13 @@
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/GLES_COMPUTE/GCHelpers.h"
using namespace arm_compute;
using namespace arm_compute::gles_compute;
+using namespace arm_compute::misc::shape_calculator;
GCWeightsReshapeKernel::GCWeightsReshapeKernel()
: _input(nullptr), _biases(nullptr), _output(nullptr)
@@ -47,15 +49,8 @@ void GCWeightsReshapeKernel::configure(const IGCTensor *input, const IGCTensor *
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
ARM_COMPUTE_ERROR_ON_NULLPTR(output);
- // Calculate output shape
- TensorShape output_shape{ input->info()->tensor_shape() };
- output_shape.collapse(3);
- const size_t tmp_dim = output_shape[0];
- output_shape.set(0, output_shape[1]);
- output_shape.set(1, tmp_dim + (biases != nullptr ? 1 : 0));
-
// Output tensor auto inizialitation if not yet initialized
- auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(output_shape));
+ auto_init_if_empty(*output->info(), input->info()->clone()->set_tensor_shape(compute_weights_reshaped_shape(*input->info(), (biases != nullptr))));
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);