aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h')
-rw-r--r--arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h89
1 files changed, 70 insertions, 19 deletions
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index d54304ed77..9512b22c08 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -64,6 +64,54 @@ public:
static Status validate(const ITensorInfo *input, const ITensorInfo *output);
};
+namespace weights_transformations
+{
+/** Basic function to manage the reshape weights generated from @ref CLFullyConnectedLayerReshapeWeights */
+class CLFullyConnectedLayerReshapeWeightsManaged : public ITransformWeights
+{
+public:
+ //Inherited method override
+ void run() override
+ {
+ _output.allocator()->allocate();
+ _func.run();
+ _reshape_run = true;
+ }
+
+ //Inherited method override
+ void release() override
+ {
+ _output.allocator()->free();
+ }
+
+ //Inherited method override
+ ICLTensor *get_weights() override
+ {
+ return &_output;
+ }
+
+ //Inherited method override
+ uint32_t uid() override
+ {
+ return _uid;
+ }
+
+ /** Configures the @ref CLFullyConnectedLayerReshapeWeights function
+ *
+ * @param[in] input Source tensor. Data type supported: QASYMM8/F16/F32.
+ */
+ void configure(const ICLTensor *input)
+ {
+ _func.configure(input, &_output);
+ }
+
+private:
+ static constexpr uint32_t _uid = 0x0;
+ CLTensor _output{};
+ CLFullyConnectedLayerReshapeWeights _func{};
+};
+} // namespace weights_transformations
+
/** Basic function to compute a Fully Connected layer on OpenCL. This function calls the following OpenCL kernels:
*
* -# @ref CLIm2ColKernel (called when the input comes from a convolutional layer)
@@ -130,25 +178,28 @@ private:
void configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights);
void configure_mm(const ICLTensor *input, const ICLTensor *weights, ICLTensor *output, bool retain_internal_weights);
- MemoryGroup _memory_group;
- CLConvertFullyConnectedWeights _convert_weights;
- CLFlattenLayer _flatten_layer;
- CLFullyConnectedLayerReshapeWeights _reshape_weights_kernel;
- CLGEMM _mm_gemm;
- CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
- CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
- CLGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer
- CLTensor _flatten_output;
- CLTensor _gemmlowp_output;
- CLTensor _converted_weights_output;
- CLTensor _reshape_weights_output;
- bool _are_weights_converted;
- bool _are_weights_reshaped;
- bool _is_fc_after_conv;
- bool _accumulate_biases;
- bool _is_quantized;
- bool _is_prepared;
- const ICLTensor *_original_weights;
+ MemoryGroup _memory_group;
+ IWeightsManager *_weights_manager;
+ CLConvertFullyConnectedWeights _convert_weights;
+ weights_transformations::CLConvertFullyConnectedWeightsManaged _convert_weights_managed;
+ weights_transformations::CLFullyConnectedLayerReshapeWeightsManaged _reshape_weights_managed_function;
+ CLFlattenLayer _flatten_layer;
+ CLFullyConnectedLayerReshapeWeights _reshape_weights_function;
+ CLGEMM _mm_gemm;
+ CLGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
+ CLGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
+ CLGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel; // TODO(COMPMID-1889): Use CLGEMM to add bias in CLFullyConnectedLayer
+ CLTensor _flatten_output;
+ CLTensor _gemmlowp_output;
+ CLTensor _converted_weights_output;
+ CLTensor _reshape_weights_output;
+ bool _are_weights_converted;
+ bool _are_weights_reshaped;
+ bool _is_fc_after_conv;
+ bool _accumulate_biases;
+ bool _is_quantized;
+ bool _is_prepared;
+ const ICLTensor *_original_weights;
};
} // namespace arm_compute
#endif /* __ARM_COMPUTE_CLFULLYCONNECTEDLAYER_H__ */