aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2021-02-25 20:15:01 +0000
committerGeorgios Pinitas <georgios.pinitas@arm.com>2021-03-29 20:23:11 +0000
commit2788609b8a10306e9eae47543b39812a7b075aaa (patch)
tree81515046e0c06d6a21ecdcebfe083ea5922fea0c /arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
parentf9a611a1fd309bb9a906c99eede5e6b7bceba26b (diff)
downloadComputeLibrary-2788609b8a10306e9eae47543b39812a7b075aaa.tar.gz
Port ClTranspose to new API
Partially Resolves: COMPMID-4277 (1/2) Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I704c2303135cbe1ba46d2fd5642c84c562204bc7 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5194 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h27
1 files changed, 23 insertions, 4 deletions
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index 3f17e4a921..a640e344d4 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,15 +36,27 @@
namespace arm_compute
{
-/** Basic function to reshape the weights of Fully Connected layer with OpenCL. This function calls the following kernels:
+/** Function to reshape the weights of Fully Connected layer with OpenCL by transposing input tensor. This function calls the following kernel:
*
- * -# @ref CLTransposeKernel
+ * -# @ref opencl::kernels::ClTransposeKernel
*
* @note The fully connected layer accepts "weights" tensors only with 2 dimensions.
*/
-class CLFullyConnectedLayerReshapeWeights : public ICLSimpleFunction
+class CLFullyConnectedLayerReshapeWeights : public IFunction
{
public:
+ /** Constructor */
+ CLFullyConnectedLayerReshapeWeights();
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLFullyConnectedLayerReshapeWeights(const CLFullyConnectedLayerReshapeWeights &) = delete;
+ /** Default move constructor */
+ CLFullyConnectedLayerReshapeWeights(CLFullyConnectedLayerReshapeWeights &&) = default;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ CLFullyConnectedLayerReshapeWeights &operator=(const CLFullyConnectedLayerReshapeWeights &) = delete;
+ /** Default move assignment operator */
+ CLFullyConnectedLayerReshapeWeights &operator=(CLFullyConnectedLayerReshapeWeights &&) = default;
+ /** Default destructor */
+ ~CLFullyConnectedLayerReshapeWeights();
/** Set the input and output tensors.
*
* @param[in] input Weights tensor. The weights must be 2 dimensional. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
@@ -66,6 +78,13 @@ public:
* @return a status
*/
static Status validate(const ITensorInfo *input, const ITensorInfo *output);
+
+ // Inherited methods overridden
+ void run() override;
+
+private:
+ struct Impl;
+ std::unique_ptr<Impl> _impl;
};
namespace weights_transformations