diff options
Diffstat (limited to 'src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp')
-rw-r--r-- | src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp | 35 |
1 files changed, 21 insertions, 14 deletions
diff --git a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp index a300033bb2..ab2c6c2813 100644 --- a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp +++ b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp @@ -40,7 +40,7 @@ void GCFullyConnectedLayerReshapeWeights::configure(const IGCTensor *input, IGCT GCFullyConnectedLayer::GCFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager) : _memory_group(std::move(memory_manager)), _im2col_kernel(), _reshape_weights_kernel(), _mm_kernel(), _accumulate_biases_kernel(), _im2col_output(), _reshape_weights_output(), - _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false) + _original_weights(nullptr), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false) { } @@ -86,6 +86,7 @@ void GCFullyConnectedLayer::configure(const IGCTensor *input, const IGCTensor *w ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output); ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 2); + _original_weights = weights; _are_weights_reshaped = transpose_weights ? are_weights_reshaped : true; _is_fc_after_conv = true; _accumulate_biases = false; @@ -141,25 +142,13 @@ void GCFullyConnectedLayer::configure(const IGCTensor *input, const IGCTensor *w configure_fc_fc(input, weights_to_use, output); } - // Allocate the transpose tensor if the are_weights_reshaped flag is false and once all the configure methods have been called - if(!_are_weights_reshaped && !retain_internal_weights) - { - // Allocate the tensor for the weights reshaped - _reshape_weights_output.allocator()->allocate(); - } - ARM_COMPUTE_ERROR_ON(retain_internal_weights && _reshape_weights_output.gc_buffer() == 0); _are_weights_reshaped = _are_weights_reshaped || retain_internal_weights; } void GCFullyConnectedLayer::run() { - // Reshape of the weights (happens only once) - if(!_are_weights_reshaped) - { - _are_weights_reshaped = true; - _reshape_weights_kernel.run(); - } + prepare(); _memory_group.acquire(); @@ -187,3 +176,21 @@ void GCFullyConnectedLayer::run() _memory_group.release(); } + +void GCFullyConnectedLayer::prepare() +{ + // Reshape of the weights (happens only once) + if(!_are_weights_reshaped) + { + ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); + + // Run reshape weights kernel and mark weights as unused + _reshape_weights_output.allocator()->allocate(); + _reshape_weights_kernel.run(); + + // Mark original weights tensor as unused + _original_weights->mark_as_unused(); + + _are_weights_reshaped = true; + } +}
\ No newline at end of file |