diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-09-27 11:04:27 +0100 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-10-03 15:59:01 +0000 |
commit | b27e13a0ad630d3d9b3143c0374b5ff5000eebc0 (patch) | |
tree | 86defdbcd080fb8ab7f22c8c46e7793eeac80640 /arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h | |
parent | 2ff0009ca9245304c48889c8ba8d3a39d42febed (diff) | |
download | ComputeLibrary-b27e13a0ad630d3d9b3143c0374b5ff5000eebc0.tar.gz |
COMPMID-2685: [CL] Use Weights manager
Change-Id: Ia1818e6ecd9386e96378e64f14d02592fe3cdf0f
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1997
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h index 43abb6769b..e4e6f0760e 100644 --- a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h +++ b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h @@ -25,7 +25,9 @@ #define __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ #include "arm_compute/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.h" +#include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/ICLSimpleFunction.h" +#include "arm_compute/runtime/ITransformWeights.h" namespace arm_compute { @@ -54,5 +56,54 @@ public: */ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout); }; + +namespace weights_transformations +{ +/** Basic function to run @ref CLConvertFullyConnectedWeightsKernel. */ +class CLConvertFullyConnectedWeightsManaged : public ITransformWeights +{ +public: + //Inherited method override + void run() override + { + _output.allocator()->allocate(); + _func.run(); + _reshape_run = true; + } + + //Inherited method override + void release() override + { + _output.allocator()->free(); + } + + //Inherited method override + ICLTensor *get_weights() override + { + return &_output; + } + + //Inherited method override + uint32_t uid() override + { + return _uid; + } + /** Configures the @ref CLConvertFullyConnectedWeights function + * + * @param[in] input Source weights tensor info to convert. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32. + * @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer). + * @param[in] data_layout The data layout the weights have been trained in. + */ + void configure(const ICLTensor *input, const TensorShape &original_input_shape, DataLayout data_layout) + { + _func.configure(input, &_output, original_input_shape, data_layout); + } + +private: + static constexpr uint32_t _uid = 0x5; + CLTensor _output{}; + CLConvertFullyConnectedWeights _func{}; +}; +} // namespace weights_transformations } // namespace arm_compute #endif /* __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ */ |