diff options
Diffstat (limited to 'src/runtime/NEON/functions/NEFullyConnectedLayer.cpp')
-rw-r--r-- | src/runtime/NEON/functions/NEFullyConnectedLayer.cpp | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp index 26b7271710..b310ad35e3 100644 --- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp +++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp @@ -132,7 +132,7 @@ void NEFullyConnectedLayerReshapeWeights::run() NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager) : _memory_group(std::move(memory_manager)), _im2col_kernel(), _reshape_weights_kernel(), _interleave4x4_kernel(), _mm_kernel(), _accumulate_biases_kernel(), _im2col_output(), _interleave4x4_output(), - _reshape_weights_output(), _are_weights_reshaped(false), _is_batched_fc_layer(false), _linearize_input(false), _accumulate_biases(false) + _reshape_weights_output(), _are_weights_reshaped(false), _is_batched_fc_layer(false), _linearize_input(false), _accumulate_biases(false), _original_weights(nullptr) { } @@ -163,6 +163,7 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh const int num_input_dimensions = input->info()->tensor_shape().num_dimensions() - num_batch_dimensions; const size_t linear_input_size = input->info()->tensor_shape().total_size_lower(num_input_dimensions); + _original_weights = weights; _linearize_input = (input->info()->tensor_shape().x() != linear_input_size) || (num_input_dimensions > 1 && linear_input_size == 1); _are_weights_reshaped = are_weights_reshaped; _accumulate_biases = biases != nullptr; @@ -324,8 +325,13 @@ void NEFullyConnectedLayer::run() // Reshape of the weights (happens only once) if(!_are_weights_reshaped) { + ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); + _are_weights_reshaped = true; _reshape_weights_kernel.run(); + + // Mark original weights tensor as unused + _original_weights->mark_as_unused(); } _memory_group.acquire(); |