diff options
Diffstat (limited to 'src/runtime/CL/functions/CLFullyConnectedLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 46 |
1 files changed, 28 insertions, 18 deletions
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp index 9b3bf48bca..151fa1b5fa 100644 --- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp @@ -220,13 +220,6 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w _gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, output_multiplier, output_shift, output->info()->quantization_info().offset); _gemmlowp_output.allocator()->allocate(); } - - // Allocate the transpose tensor if the are_weights_reshaped flag is false and once all the configure methods have been called - if(!_are_weights_reshaped) - { - // Allocate the tensor for the weights reshaped - _reshape_weights_output.allocator()->allocate(); - } } Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights, bool are_weights_reshaped) @@ -311,17 +304,7 @@ Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn void CLFullyConnectedLayer::run() { - // Reshape of the weights (happens only once) - if(!_are_weights_reshaped) - { - ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); - - _are_weights_reshaped = true; - _reshape_weights_kernel.run(); - - // Mark original weights tensor as unused - _original_weights->mark_as_unused(); - } + prepare(); _memory_group.acquire(); @@ -356,3 +339,30 @@ void CLFullyConnectedLayer::run() _memory_group.release(); } + +void CLFullyConnectedLayer::prepare() +{ + // Reshape of the weights (happens only once) + if(!_are_weights_reshaped) + { + ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); + + // Run reshape weights kernel and mark weights as unused + _reshape_weights_output.allocator()->allocate(); + _reshape_weights_kernel.run(); + _original_weights->mark_as_unused(); + + // Prepare GEMM prepare and release unused weights + if(!_is_quantized) + { + _mm_gemm.prepare(); + if(!_reshape_weights_output.is_used()) + { + _reshape_weights_output.allocator()->free(); + } + } + + CLScheduler::get().queue().finish(); + _are_weights_reshaped = true; + } +} |