diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2018-05-31 17:31:05 +0100 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:52:54 +0000 |
commit | b62280aca3148dd6762e57e5af3da0cb0a9e2db5 (patch) | |
tree | aa10c3750dcb8b13151d40529facf92667c336c9 /src/runtime/CL/functions/CLFullyConnectedLayer.cpp | |
parent | da2491fb6d3cefb69846f220356fff282486495c (diff) | |
download | ComputeLibrary-b62280aca3148dd6762e57e5af3da0cb0a9e2db5.tar.gz |
COMPMID-1244: Allow retaining weights in CLGEMMConvolutionLayer and CLFullyConnectedLayer
Change-Id: I1c3b2197906cd4b905309bbd5f2012bbae6a7dba
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/133730
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLFullyConnectedLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp index 151fa1b5fa..44bf28374f 100644 --- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp +++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp @@ -133,7 +133,8 @@ void CLFullyConnectedLayer::configure_fc_fc(const ICLTensor *input, const ICLTen configure_mm(input, weights, output); } -void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights, bool are_weights_reshaped) +void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights, bool are_weights_reshaped, + bool retain_internal_weights) { ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); @@ -143,7 +144,8 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w biases != nullptr ? biases->info() : nullptr, output->info(), transpose_weights, - are_weights_reshaped)); + are_weights_reshaped, + retain_internal_weights)); _are_weights_reshaped = transpose_weights ? are_weights_reshaped : true; _is_fc_after_conv = true; @@ -220,10 +222,14 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w _gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, output_multiplier, output_shift, output->info()->quantization_info().offset); _gemmlowp_output.allocator()->allocate(); } + + _are_weights_reshaped = _are_weights_reshaped || retain_internal_weights; } -Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights, bool are_weights_reshaped) +Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights, bool are_weights_reshaped, + bool retain_internal_weights) { + ARM_COMPUTE_UNUSED(retain_internal_weights); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output); |