aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
diff options
context:
space:
mode:
authorGian Marco Iodice <gianmarco.iodice@arm.com>2017-08-15 11:45:22 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitedfa9f463bed084f8b0953557202b2a1e56da817 (patch)
tree5d1e92926d112fde05dcbc61324d96f73f692390 /arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
parentdc460f13ee65e27b2a428e44c2d80afb1f516a99 (diff)
downloadComputeLibrary-edfa9f463bed084f8b0953557202b2a1e56da817.tar.gz
COMPMID-477 - Optimized batched case in CLConvolutionLayer
Change-Id: I4ef18f49f1da0cb816aaa0762466b940792c15ed Reviewed-on: http://mpd-gerrit.cambridge.arm.com/84162 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h42
1 file changed, 12 insertions, 30 deletions
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index a29f68fcf1..e076f51b26 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -24,12 +24,10 @@
#ifndef __ARM_COMPUTE_CLFULLYCONNECTEDLAYER_H__
#define __ARM_COMPUTE_CLFULLYCONNECTEDLAYER_H__
-#include "arm_compute/runtime/IFunction.h"
+#include "arm_compute/runtime/CL/ICLSimpleFunction.h"
-#include "arm_compute/core/CL/kernels/CLGEMMInterleave4x4Kernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixAccumulateBiasesKernel.h"
#include "arm_compute/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
-#include "arm_compute/core/CL/kernels/CLGEMMTranspose1xWKernel.h"
#include "arm_compute/core/CL/kernels/CLIm2ColKernel.h"
#include "arm_compute/core/CL/kernels/CLTransposeKernel.h"
#include "arm_compute/runtime/CL/CLTensor.h"
@@ -38,41 +36,25 @@ namespace arm_compute
{
/** Basic function to reshape the weights of Fully Connected layer with OpenCL. This function calls the following kernels:
*
- * -# @ref CLTransposeKernel (if @p transpose_weights is set to true)
- * -# @ref CLGEMMTranspose1xWKernel (if @p is_batched_fc_layer is set to true)
+ * -# @ref CLTransposeKernel
*
* @note The fully connected layer accepts "weights" tensors only with 2 dimensions.
*/
-class CLFullyConnectedLayerReshapeWeights : public IFunction
+class CLFullyConnectedLayerReshapeWeights : public ICLSimpleFunction
{
public:
- /** Constructor */
- CLFullyConnectedLayerReshapeWeights();
/** Set the input and output tensors.
*
- * @param[in] input Weights tensor. The weights must be 2 dimensional. Data types supported: QS8/QS16/F16/F32.
- * @param[out] output Destination tensor. Data type supported: Same as @p input.
- * @param[in] transpose_weights True if the weights must be transposed. Data types supported: Same as @p weights.
- * @param[in] is_batched_fc_layer True if it is a batched fully connected layer
+ * @param[in] input Weights tensor. The weights must be 2 dimensional. Data types supported: QS8/QS16/F16/F32.
+ * @param[out] output Destination tensor which stores the transposed input tensor. Data type supported: Same as @p input.
*/
- void configure(const ICLTensor *input, ICLTensor *output, bool transpose_weights, bool is_batched_fc_layer);
-
- // Inherited methods overridden:
- void run() override;
-
-private:
- CLTransposeKernel _transpose_kernel;
- CLGEMMTranspose1xWKernel _transpose1xW_kernel;
- CLTensor _transpose_output;
- bool _transpose_weights;
- bool _is_batched_fc_layer;
+ void configure(const ICLTensor *input, ICLTensor *output);
};
/** Basic function to compute a Fully Connected layer on OpenCL. This function calls the following OpenCL kernels:
*
* -# @ref CLIm2ColKernel (called when the input comes from a convolutional layer)
- * -# @ref CLFullyConnectedLayerReshapeWeights (if @p are_weights_reshaped is set to false) (called once)
- * -# @ref CLGEMMInterleave4x4Kernel (called if we have a multi-batch input)
+ * -# @ref CLFullyConnectedLayerReshapeWeights (if @p are_weights_reshaped is set to false and @p transpose_weights is set to true) (called once)
* -# @ref CLGEMMMatrixMultiplyKernel
* -# @ref CLGEMMMatrixAccumulateBiasesKernel (if @p biases is not equal to nullptr)
*
@@ -85,7 +67,7 @@ public:
CLFullyConnectedLayer();
/** Set the input and output tensors.
*
- * @param[in] input Source tensor. Data type supported: QS8/F16/F32.
+ * @param[in] input Source tensor. Data type supported: QS8/QS16/F16/F32.
* @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input
* @param[in] biases Bias tensor. It can be nullptr. Data type supported:Same as @p input.
* @param[out] output Destination tensor. Data type supported: Same as @p input.
@@ -98,17 +80,17 @@ public:
void run() override;
private:
+ void configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output);
+ void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output);
+
CLIm2ColKernel _im2col_kernel;
CLFullyConnectedLayerReshapeWeights _reshape_weights_kernel;
- CLGEMMInterleave4x4Kernel _interleave4x4_kernel;
CLGEMMMatrixMultiplyKernel _mm_kernel;
CLGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel;
CLTensor _im2col_output;
- CLTensor _interleave4x4_output;
CLTensor _reshape_weights_output;
bool _are_weights_reshaped;
- bool _is_batched_fc_layer;
- bool _linearize_input;
+ bool _is_fc_after_conv;
bool _accumulate_biases;
};
}