diff options
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h')
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h | 71 |
1 files changed, 21 insertions, 50 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h b/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h index 984e8d68c0..dc6b22d717 100644 --- a/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h +++ b/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -24,19 +24,16 @@ #ifndef ARM_COMPUTE_NECONVERTFULLYCONNECTEDWEIGHTS_H #define ARM_COMPUTE_NECONVERTFULLYCONNECTEDWEIGHTS_H +#include "arm_compute/core/Types.h" #include "arm_compute/runtime/IFunction.h" -#include "arm_compute/runtime/ITransformWeights.h" -#include "arm_compute/runtime/NEON/NEScheduler.h" -#include "arm_compute/runtime/Tensor.h" -#include <memory> namespace arm_compute { // Forward declarations class ITensor; -class NEConvertFullyConnectedWeightsKernel; +class ITensorInfo; -/** Basic function to run @ref NEConvertFullyConnectedWeightsKernel. */ +/** Basic function to run @ref cpu::kernels::CpuConvertFullyConnectedWeightsKernel. */ class NEConvertFullyConnectedWeights : public IFunction { public: @@ -54,12 +51,22 @@ public: ~NEConvertFullyConnectedWeights(); /** Initialize the function. * + * Valid data layouts: + * - NHWC + * - NCHW + * + * Valid data type configurations: + * |src |dst | + * |:--------------|:--------------| + * |All |All | + * * @param[in] input Source weights tensor to convert. Must be 2 dimensional. Data types supported: All. * @param[out] output The converted weights tensor. Shape and Data Type: Same as @p input. * @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer). * @param[in] data_layout The data layout the weights have been trained in. */ - void configure(const ITensor *input, ITensor *output, const TensorShape &original_input_shape, DataLayout data_layout); + void + configure(const ITensor *input, ITensor *output, const TensorShape &original_input_shape, DataLayout data_layout); /** Static function to check if given info will lead to a valid configuration of @ref NEConvertFullyConnectedWeights * * @param[in] input Source weights tensor info to convert. Must be 2 dimensional. Data types supported: All. @@ -69,53 +76,17 @@ public: * * @return A Status */ - static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout); + static Status validate(const ITensorInfo *input, + const ITensorInfo *output, + const TensorShape &original_input_shape, + DataLayout data_layout); // Inherited methods overriden: void run() override; private: - std::unique_ptr<NEConvertFullyConnectedWeightsKernel> _kernel; -}; - -namespace weights_transformations -{ -/** Basic function to run @ref NEConvertFullyConnectedWeightsKernel. */ -class NEConvertFullyConnectedWeightsManaged : public ITransformWeights -{ -public: - void run() override - { - _output.allocator()->allocate(); - _func.run(); - _reshape_run = true; - } - - void release() override - { - _output.allocator()->free(); - } - - ITensor *get_weights() override - { - return &_output; - } - - uint32_t uid() override - { - return _uid; - } - - void configure(const ITensor *input, const TensorShape &original_input_shape, DataLayout data_layout) - { - _func.configure(input, &_output, original_input_shape, data_layout); - } - -private: - static constexpr uint32_t _uid = 0x4; - Tensor _output{}; - NEConvertFullyConnectedWeights _func{}; + struct Impl; + std::unique_ptr<Impl> _impl; }; -} // namespace weights_transformations } // namespace arm_compute #endif /* ARM_COMPUTE_NECONVERTFULLYCONNECTEDWEIGHTS_H */ |