diff options
Diffstat (limited to 'src/runtime/CL/functions/CLFullyConnectedLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp index 2b4670b98c..676706fb17 100644 --- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp @@ -76,7 +76,7 @@ Status CLFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager) : _memory_group(memory_manager), _im2col_kernel(), _reshape_weights_kernel(), _mm_kernel(), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _im2col_output(), - _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false) + _gemmlowp_output(), _reshape_weights_output(), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _original_weights(nullptr) { } @@ -152,6 +152,7 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w _is_fc_after_conv = true; _accumulate_biases = false; _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type()); + _original_weights = weights; // Configure gemmlowp output if(_is_quantized) @@ -316,8 +317,13 @@ void CLFullyConnectedLayer::run() // Reshape of the weights (happens only once) if(!_are_weights_reshaped) { + ARM_COMPUTE_ERROR_ON(!_original_weights->is_used()); + _are_weights_reshaped = true; _reshape_weights_kernel.run(); + + // Mark original weights tensor as unused + _original_weights->mark_as_unused(); } _memory_group.acquire(); |