diff options
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h | 39 |
1 files changed, 19 insertions, 20 deletions
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h index 3357868968..6b8d7a97ec 100644 --- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h +++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h @@ -31,6 +31,7 @@ #include "arm_compute/core/CL/kernels/CLTransposeKernel.h" #include "arm_compute/runtime/CL/CLMemoryGroup.h" #include "arm_compute/runtime/CL/CLTensor.h" +#include "arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h" #include "arm_compute/runtime/CL/functions/CLGEMM.h" #include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h" #include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h" @@ -86,32 +87,26 @@ public: CLFullyConnectedLayer &operator=(CLFullyConnectedLayer &&) = default; /** Set the input and output tensors. * - * @param[in] input Source tensor. Data type supported: QASYMM8/F16/F32. - * @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input - * @param[in] biases Bias tensor. It can be nullptr. Data type supported:Same as @p input. - * @param[out] output Destination tensor. Data type supported: Same as @p input. - * @param[in] transpose_weights (Optional) Transpose weights if true. Defaults to true. - * @param[in] are_weights_reshaped (Optional) Reshape the weights tensor if false. Defaults to false. - * @param[in] retain_internal_weights (Optional) Retain internal reshaped weights. Defaults to false. - * Used for reconfiguration purposes. + * @param[in] input Source tensor. Data type supported: QASYMM8/F16/F32. + * @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input + * @param[in] biases Bias tensor. It can be nullptr. Data type supported:Same as @p input. + * @param[out] output Destination tensor. Data type supported: Same as @p input. + * @param[in] fc_info (Optional) Fully connected layer additional info */ - void configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights = true, bool are_weights_reshaped = false, - bool retain_internal_weights = false); + void configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, + FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CLFullyConnectedLayer * - * @param[in] input Source tensor. Data type supported: QASYMM8/F16/F32. - * @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input - * @param[in] biases Bias tensor. It can be nullptr. Data type supported:Same as @p input. - * @param[in] output Destination tensor. Data type supported: Same as @p input. - * @param[in] transpose_weights (Optional) Transpose weights if true. Defaults to true. - * @param[in] are_weights_reshaped (Optional) Reshape the weights tensor if false. Defaults to false. - * @param[in] retain_internal_weights (Optional) Retain internal reshaped weights. Defaults to false. - * Used for reconfiguration purposes. + * @param[in] input Source tensor. Data type supported: QASYMM8/F16/F32. + * @param[in] weights Weights tensor. The weights must be 2 dimensional. Data type supported: Same as @p input + * @param[in] biases Bias tensor. It can be nullptr. Data type supported:Same as @p input. + * @param[in] output Destination tensor. Data type supported: Same as @p input. + * @param[in] fc_info (Optional) Fully connected layer additional info * * @return a status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights = true, bool are_weights_reshaped = false, - bool retain_internal_weights = false); + static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, + FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo()); //Inherited methods override void run() override; @@ -124,6 +119,7 @@ private: CLMemoryGroup _memory_group; CLIm2ColKernel _im2col_kernel; + CLConvertFullyConnectedWeights _convert_weights; CLFullyConnectedLayerReshapeWeights _reshape_weights_kernel; CLGEMM _mm_gemm; CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp; @@ -131,11 +127,14 @@ private: CLGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel; CLTensor _im2col_output; CLTensor _gemmlowp_output; + CLTensor _converted_weights_output; CLTensor _reshape_weights_output; + bool _are_weights_converted; bool _are_weights_reshaped; bool _is_fc_after_conv; bool _accumulate_biases; bool _is_quantized; + bool _is_prepared; const ICLTensor *_original_weights; }; } |