From b27e13a0ad630d3d9b3143c0374b5ff5000eebc0 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Fri, 27 Sep 2019 11:04:27 +0100 Subject: COMPMID-2685: [CL] Use Weights manager Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/1997 Comments-Addressed: Arm Jenkins Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- .../CL/functions/CLConvertFullyConnectedWeights.h | 51 ++++++++++++++++++++++ 1 file changed, 51 insertions(+) (limited to 'arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h') diff --git a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h index 43abb6769b..e4e6f0760e 100644 --- a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h +++ b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h @@ -25,7 +25,9 @@ #define __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ #include "arm_compute/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.h" +#include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/ICLSimpleFunction.h" +#include "arm_compute/runtime/ITransformWeights.h" namespace arm_compute { @@ -54,5 +56,54 @@ public: */ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout); }; + +namespace weights_transformations +{ +/** Basic function to run @ref CLConvertFullyConnectedWeightsKernel. */ +class CLConvertFullyConnectedWeightsManaged : public ITransformWeights +{ +public: + //Inherited method override + void run() override + { + _output.allocator()->allocate(); + _func.run(); + _reshape_run = true; + } + + //Inherited method override + void release() override + { + _output.allocator()->free(); + } + + //Inherited method override + ICLTensor *get_weights() override + { + return &_output; + } + + //Inherited method override + uint32_t uid() override + { + return _uid; + } + /** Configures the @ref CLConvertFullyConnectedWeights function + * + * @param[in] input Source weights tensor info to convert. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32. + * @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer). + * @param[in] data_layout The data layout the weights have been trained in. + */ + void configure(const ICLTensor *input, const TensorShape &original_input_shape, DataLayout data_layout) + { + _func.configure(input, &_output, original_input_shape, data_layout); + } + +private: + static constexpr uint32_t _uid = 0x5; + CLTensor _output{}; + CLConvertFullyConnectedWeights _func{}; +}; +} // namespace weights_transformations } // namespace arm_compute #endif /* __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ */ -- cgit v1.2.1