Diffstat (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h | 89 |
1 file changed, 70 insertions, 19 deletions
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index d54304ed77..9512b22c08 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -64,6 +64,54 @@ public:
     static Status validate(const ITensorInfo *input, const ITensorInfo *output);
 };
 
+namespace weights_transformations
+{
+/** Basic function to manage the reshape weights generated from @ref CLFullyConnectedLayerReshapeWeights */
+class CLFullyConnectedLayerReshapeWeightsManaged : public ITransformWeights
+{
+public:
+    //Inherited method override
+    void run() override
+    {
+        _output.allocator()->allocate();
+        _func.run();
+        _reshape_run = true;
+    }
+
+    //Inherited method override
+    void release() override
+    {
+        _output.allocator()->free();
+    }
+
+    //Inherited method override
+    ICLTensor *get_weights() override
+    {
+        return &_output;
+    }
+
+    //Inherited method override
+    uint32_t uid() override
+    {
+        return _uid;
+    }
+
+    /** Configures the @ref CLFullyConnectedLayerReshapeWeights function
+     *
+     * @param[in] input Source tensor. Data type supported: QASYMM8/F16/F32.
+     */
+    void configure(const ICLTensor *input)
+    {
+        _func.configure(input, &_output);
+    }
+
+private:
+    static constexpr uint32_t _uid = 0x0;
+    CLTensor                            _output{};
+    CLFullyConnectedLayerReshapeWeights _func{};
+};
+} // namespace weights_transformations
+
 /** Basic function to compute a Fully Connected layer on OpenCL. This function calls the following OpenCL kernels:
  *
  * -# @ref CLIm2ColKernel (called when the input comes from a convolutional layer)
@@ -130,25 +178,28 @@ private:
     void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights);
     void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights);
 
-    MemoryGroup                                         _memory_group;
-    CLConvertFullyConnectedWeights                      _convert_weights;
-    CLFlattenLayer                                      _flatten_layer;
-    CLFullyConnectedLayerReshapeWeights                 _reshape_weights_kernel;
-    CLGEMM                                              _mm_gemm;
-    CLGEMMLowpMatrixMultiplyCore                        _mm_gemmlowp;
-    CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
-    CLGEMMMatrixAccumulateBiasesKernel                  _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer
-    CLTensor                                            _flatten_output;
-    CLTensor                                            _gemmlowp_output;
-    CLTensor                                            _converted_weights_output;
-    CLTensor                                            _reshape_weights_output;
-    bool                                                _are_weights_converted;
-    bool                                                _are_weights_reshaped;
-    bool                                                _is_fc_after_conv;
-    bool                                                _accumulate_biases;
-    bool                                                _is_quantized;
-    bool                                                _is_prepared;
-    const ICLTensor                                     *_original_weights;
+    MemoryGroup                                                          _memory_group;
+    IWeightsManager                                                     *_weights_manager;
+    CLConvertFullyConnectedWeights                                       _convert_weights;
+    weights_transformations::CLConvertFullyConnectedWeightsManaged      _convert_weights_managed;
+    weights_transformations::CLFullyConnectedLayerReshapeWeightsManaged _reshape_weights_managed_function;
+    CLFlattenLayer                                                       _flatten_layer;
+    CLFullyConnectedLayerReshapeWeights                                  _reshape_weights_function;
+    CLGEMM                                                               _mm_gemm;
+    CLGEMMLowpMatrixMultiplyCore                                         _mm_gemmlowp;
+    CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint                  _gemmlowp_output_stage;
+    CLGEMMMatrixAccumulateBiasesKernel                                   _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer
+    CLTensor                                                             _flatten_output;
+    CLTensor                                                             _gemmlowp_output;
+    CLTensor                                                             _converted_weights_output;
+    CLTensor                                                             _reshape_weights_output;
+    bool                                                                 _are_weights_converted;
+    bool                                                                 _are_weights_reshaped;
+    bool                                                                 _is_fc_after_conv;
+    bool                                                                 _accumulate_biases;
+    bool                                                                 _is_quantized;
+    bool                                                                 _is_prepared;
+    const ICLTensor                                                     *_original_weights;
 };
 } // namespace arm_compute
 #endif /* __ARM_COMPUTE_CLFULLYCONNECTEDLAYER_H__ */
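For context, here is a rough usage sketch of the lifecycle the new managed wrapper exposes (configure, run, get_weights, release). This is an illustrative assumption, not code from the patch: in practice CLFullyConnectedLayer hands the transformation to an IWeightsManager rather than driving it by hand, and the tensor shape below is made up for demonstration.

// Hypothetical standalone driver for the managed reshape transformation; not part of the patch.
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

using namespace arm_compute;

int main()
{
    CLScheduler::get().default_init(); // set up the OpenCL context and queue

    // Example 2D weights tensor for a fully connected layer (shape chosen arbitrarily).
    CLTensor weights;
    weights.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::F32));
    weights.allocator()->allocate();

    // The managed transformation owns the reshaped output tensor internally.
    weights_transformations::CLFullyConnectedLayerReshapeWeightsManaged reshape_managed;
    reshape_managed.configure(&weights);

    // Normally an IWeightsManager decides when to run and release the transformation.
    reshape_managed.run();                                // allocate the output and reshape the weights
    ICLTensor *reshaped = reshape_managed.get_weights();  // reshaped weights, ready for the GEMM path
    (void)reshaped;
    reshape_managed.release();                            // free the reshaped copy once it has been consumed

    return 0;
}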