diff options
author | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-09-10 17:20:34 +0100 |
---|---|---|
committer | Michalis Spyrou <michalis.spyrou@arm.com> | 2019-09-26 10:17:30 +0000 |
commit | 1a569a30a2f456ff1a3e0a665201e1c3ab92df80 (patch) | |
tree | 9d68934f461579edefbe65246f6ee435aaa18808 /arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h | |
parent | f1cf394ae882e6e8fb2e0986f88d2548b82a85bb (diff) | |
download | ComputeLibrary-1a569a30a2f456ff1a3e0a665201e1c3ab92df80.tar.gz |
COMPMID-2161 [NEON] Create IWeightManager class
Change-Id: I1a9a46da2f98e896b825099151b56d1d8271dd31
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1915
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h')
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h | 48 |
1 files changed, 46 insertions, 2 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h b/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h index 8f261421e6..50a86bd7c4 100644 --- a/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h +++ b/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -26,7 +26,9 @@ #include "arm_compute/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.h" #include "arm_compute/runtime/IFunction.h" +#include "arm_compute/runtime/ITransformWeights.h" #include "arm_compute/runtime/NEON/NEScheduler.h" +#include "arm_compute/runtime/Tensor.h" namespace arm_compute { @@ -52,6 +54,8 @@ public: * @param[in] output The converted weights tensor info. Shape and Data Type: Same as @p input. * @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer). * @param[in] data_layout The data layout the weights have been trained in. + * + * @return A Status */ static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout); @@ -61,5 +65,45 @@ public: private: NEConvertFullyConnectedWeightsKernel _kernel; }; -} + +namespace weights_transformations +{ +/** Basic function to run @ref NEConvertFullyConnectedWeightsKernel. */ +class NEConvertFullyConnectedWeightsManaged : public ITransformWeights +{ +public: + void run() override + { + _output.allocator()->allocate(); + _func.run(); + _reshape_run = true; + } + + void release() override + { + _output.allocator()->free(); + } + + ITensor *get_weights() override + { + return &_output; + } + + uint32_t uid() override + { + return _uid; + } + + void configure(const ITensor *input, const TensorShape &original_input_shape, DataLayout data_layout) + { + _func.configure(input, &_output, original_input_shape, data_layout); + } + +private: + static constexpr uint32_t _uid = 0x4; + Tensor _output{}; + NEConvertFullyConnectedWeights _func{}; +}; +} // namespace weights_transformations +} // namespace arm_compute #endif /* __ARM_COMPUTE_NECONVERTFULLYCONNECTEDWEIGHTS_H__ */ |