From 2b84be544e4a27f7e8e80827e9c85c8f0d58b4ce Mon Sep 17 00:00:00 2001 From: Manuel Bottini Date: Wed, 8 Apr 2020 10:15:51 +0100 Subject: COMPMID-3280: Make all ML primitives for CL use the new interface - Part 2 - CLFunctions have been updated Change-Id: Ie3256a6c775bc12f3126482bd8e8a46da54b267c Signed-off-by: Manuel Bottini Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3053 Reviewed-by: Georgios Pinitas Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- .../runtime/CL/functions/CLFullyConnectedLayer.h | 41 +++++++++++++++++++--- 1 file changed, 37 insertions(+), 4 deletions(-) (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h') diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h index cbd28603fc..188117f674 100644 --- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h +++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h @@ -52,6 +52,13 @@ public: * @param[out] output Destination tensor which stores the transposed input tensor. Data type supported: Same as @p input. */ void configure(const ICLTensor *input, ICLTensor *output); + /** Set the input and output tensors. + * + * @param[in] compile_context The compile context to be used. + * @param[in] input Weights tensor. The weights must be 2 dimensional. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32. + * @param[out] output Destination tensor which stores the transposed input tensor. Data type supported: Same as @p input. + */ + void configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output); /** Static function to check if given info will lead to a valid configuration of @ref CLFullyConnectedLayerReshapeWeights * * @param[in] input Weights tensor. The weights must be 2 dimensional. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32. @@ -100,7 +107,16 @@ public: */ void configure(const ICLTensor *input) { - _func.configure(input, &_output); + configure(CLKernelLibrary::get().get_compile_context(), input); + } + /** Configures the @ref CLFullyConnectedLayerReshapeWeights function + * + * @param[in] compile_context The compile context to be used. + * @param[in] input Source tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. + */ + void configure(const CLCompileContext &compile_context, const ICLTensor *input) + { + _func.configure(compile_context, input, &_output); } private: @@ -147,6 +163,23 @@ public: */ void configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); + /** Set the input and output tensors. + * + * @param[in] compile_context The compile context to be used. + * @param[in] input Source tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. + * @param[in] weights Weights tensor. The weights must be 2 dimensional. + * If this function is called after a Convolution Layer, the (transposed) weights will have as many rows as the product of the first 3 input's dimensions. + * If it is called after another FullyConnected Layer, the (transposed) weights will have as many rows as the input's first dimension. + * Data type supported: Same as @p input. + * @param[in] biases Bias tensor. Can be nullptr. Data type supported:Same as @p input. + * @param[out] output Destination tensor. Its shape should be equal to the output of a matrix multiplication between: + * - The output of im2col on the input and the (transposed) 2D weights, if the function is called after a Convolution Layer + * - The input tensor and the (transposed) 2D weights, if the function is called after another FullyConnected Layer. + * Data type supported: Same as @p input. + * @param[in] fc_info (Optional) Fully connected layer additional info + */ + void configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, + FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CLFullyConnectedLayer * * @param[in] input Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. @@ -171,9 +204,9 @@ public: void prepare() override; private: - void configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info); - void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info); - void configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info); + void configure_fc_fc(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info); + void configure_conv_fc(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info); + void configure_mm(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info); MemoryGroup _memory_group; IWeightsManager *_weights_manager; -- cgit v1.2.1