aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLFullyConnectedLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLFullyConnectedLayer.cpp12
1 files changed, 9 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 151fa1b5fa..44bf28374f 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -133,7 +133,8 @@ void CLFullyConnectedLayer::configure_fc_fc(const ICLTensor *input, const ICLTen
configure_mm(input, weights, output);
}
-void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights, bool are_weights_reshaped)
+void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, bool transpose_weights, bool are_weights_reshaped,
+ bool retain_internal_weights)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
@@ -143,7 +144,8 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w
biases != nullptr ? biases->info() : nullptr,
output->info(),
transpose_weights,
- are_weights_reshaped));
+ are_weights_reshaped,
+ retain_internal_weights));
_are_weights_reshaped = transpose_weights ? are_weights_reshaped : true;
_is_fc_after_conv = true;
@@ -220,10 +222,14 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w
_gemmlowp_output_stage.configure(&_gemmlowp_output, biases, output, output_multiplier, output_shift, output->info()->quantization_info().offset);
_gemmlowp_output.allocator()->allocate();
}
+
+ _are_weights_reshaped = _are_weights_reshaped || retain_internal_weights;
}
-Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights, bool are_weights_reshaped)
+Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, bool transpose_weights, bool are_weights_reshaped,
+ bool retain_internal_weights)
{
+ ARM_COMPUTE_UNUSED(retain_internal_weights);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);