aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h27
1 files changed, 23 insertions, 4 deletions
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index 3f17e4a921..a640e344d4 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,15 +36,27 @@
namespace arm_compute
{
-/** Basic function to reshape the weights of Fully Connected layer with OpenCL. This function calls the following kernels:
+/** Function to reshape the weights of Fully Connected layer with OpenCL by transposing input tensor. This function calls the following kernel:
*
- * -# @ref CLTransposeKernel
+ * -# @ref opencl::kernels::ClTransposeKernel
*
* @note The fully connected layer accepts "weights" tensors only with 2 dimensions.
*/
-class CLFullyConnectedLayerReshapeWeights : public ICLSimpleFunction
+class CLFullyConnectedLayerReshapeWeights : public IFunction
{
public:
+ /** Constructor */
+ CLFullyConnectedLayerReshapeWeights();
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLFullyConnectedLayerReshapeWeights(const CLFullyConnectedLayerReshapeWeights &) = delete;
+ /** Default move constructor */
+ CLFullyConnectedLayerReshapeWeights(CLFullyConnectedLayerReshapeWeights &&) = default;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLFullyConnectedLayerReshapeWeights &operator=(const CLFullyConnectedLayerReshapeWeights &) = delete;
+ /** Default move assignment operator */
+ CLFullyConnectedLayerReshapeWeights &operator=(CLFullyConnectedLayerReshapeWeights &&) = default;
+ /** Default destructor */
+ ~CLFullyConnectedLayerReshapeWeights();
/** Set the input and output tensors.
*
* @param[in] input Weights tensor. The weights must be 2 dimensional. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
@@ -66,6 +78,13 @@ public:
* @return a status
*/
static Status validate(const ITensorInfo *input, const ITensorInfo *output);
+
+ // Inherited methods overridden
+ void run() override;
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> _impl;
};
namespace weights_transformations