diff options
Diffstat (limited to 'arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h')
-rw-r--r-- | arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h | 51 |
1 files changed, 51 insertions, 0 deletions
diff --git a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h index 43abb6769b..e4e6f0760e 100644 --- a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h +++ b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h @@ -25,7 +25,9 @@ #define __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ #include "arm_compute/core/CL/kernels/CLConvertFullyConnectedWeightsKernel.h" +#include "arm_compute/runtime/CL/CLTensor.h" #include "arm_compute/runtime/CL/ICLSimpleFunction.h" +#include "arm_compute/runtime/ITransformWeights.h" namespace arm_compute { @@ -54,5 +56,54 @@ public: */ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout); }; + +namespace weights_transformations +{ +/** Basic function to run @ref CLConvertFullyConnectedWeightsKernel. */ +class CLConvertFullyConnectedWeightsManaged : public ITransformWeights +{ +public: + //Inherited method override + void run() override + { + _output.allocator()->allocate(); + _func.run(); + _reshape_run = true; + } + + //Inherited method override + void release() override + { + _output.allocator()->free(); + } + + //Inherited method override + ICLTensor *get_weights() override + { + return &_output; + } + + //Inherited method override + uint32_t uid() override + { + return _uid; + } + /** Configures the @ref CLConvertFullyConnectedWeights function + * + * @param[in] input Source weights tensor info to convert. Data type supported: U8/S8/QASYMM8/U16/S16/U32/S32/F16/F32. + * @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer). + * @param[in] data_layout The data layout the weights have been trained in. + */ + void configure(const ICLTensor *input, const TensorShape &original_input_shape, DataLayout data_layout) + { + _func.configure(input, &_output, original_input_shape, data_layout); + } + +private: + static constexpr uint32_t _uid = 0x5; + CLTensor _output{}; + CLConvertFullyConnectedWeights _func{}; +}; +} // namespace weights_transformations } // namespace arm_compute #endif /* __ARM_COMPUTE_CLCONVERTFULLYCONNECTEDWEIGHTS_H__ */ |