diff options
Diffstat (limited to 'src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp')
-rw-r--r-- | src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp index 0f8f8e6c94..a300033bb2 100644 --- a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp +++ b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp @@ -79,7 +79,8 @@ void GCFullyConnectedLayer::configure_fc_fc(const IGCTensor *input, const IGCTen _mm_kernel.configure(input, weights, output, 1.0f, false); } -void GCFullyConnectedLayer::configure(const IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, bool transpose_weights, bool are_weights_reshaped) +void GCFullyConnectedLayer::configure(const IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, + bool transpose_weights, bool are_weights_reshaped, bool retain_internal_weights) { ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32, DataType::F16); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output); @@ -141,11 +142,14 @@ void GCFullyConnectedLayer::configure(const IGCTensor *input, const IGCTensor *w } // Allocate the transpose tensor if the are_weights_reshaped flag is false and once all the configure methods have been called - if(!_are_weights_reshaped) + if(!_are_weights_reshaped && !retain_internal_weights) { // Allocate the tensor for the weights reshaped _reshape_weights_output.allocator()->allocate(); } + + ARM_COMPUTE_ERROR_ON(retain_internal_weights && _reshape_weights_output.gc_buffer() == 0); + _are_weights_reshaped = _are_weights_reshaped || retain_internal_weights; } void GCFullyConnectedLayer::run() @@ -158,6 +162,7 @@ void GCFullyConnectedLayer::run() } _memory_group.acquire(); + // Linearize input if it comes from a convolutional layer if(_is_fc_after_conv) { @@ -179,5 +184,6 @@ void GCFullyConnectedLayer::run() GCScheduler::get().dispatch(_accumulate_biases_kernel); } + _memory_group.release(); } |