author    Giorgio Arena <giorgio.arena@arm.com>  2020-02-07 13:46:45 +0000
committer Giorgio Arena <giorgio.arena@arm.com>  2020-03-02 15:51:39 +0000
commit    1856ff7ebb29e04c3549b74d7ced336111cbf05e (patch)
tree      c94654f0d8535930a81712bf7aadffd757c82577 /src/runtime/CL/functions/CLFullyConnectedLayer.cpp
parent    3c4bf0c4eab5ead756c472f17ddf008b882cc905 (diff)
download  ComputeLibrary-1856ff7ebb29e04c3549b74d7ced336111cbf05e.tar.gz
COMPMID-3097 Fuse activation with fully connected layer CL
Change-Id: I447030e69b9e565f2f81529a41af8c5e7ece7ecf
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2702
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
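For context, this change routes the activation through FullyConnectedLayerInfo::activation_info so it is fused into the GEMM output stage instead of running as a separate CLActivationLayer. A minimal usage sketch follows; the tensor shapes, data type and the BOUNDED_RELU cap of 6.f are illustrative assumptions, not taken from this commit:

#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

using namespace arm_compute;

int main()
{
    CLScheduler::get().default_init();

    // Illustrative shapes: 128 inputs, 64 outputs, batch of 1.
    CLTensor input, weights, bias, output;
    input.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::F32));
    weights.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::F32));
    bias.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F32));
    output.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F32));

    // Request a fused bounded ReLU (cap at 6.f) rather than a
    // standalone activation layer after the fully connected layer.
    FullyConnectedLayerInfo fc_info{};
    fc_info.activation_info = ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f);

    CLFullyConnectedLayer fc;
    fc.configure(&input, &weights, &bias, &output, fc_info);

    input.allocator()->allocate();
    weights.allocator()->allocate();
    bias.allocator()->allocate();
    output.allocator()->allocate();

    fc.run();
    CLScheduler::get().sync();
    return 0;
}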
Diffstat (limited to 'src/runtime/CL/functions/CLFullyConnectedLayer.cpp')
-rw-r--r--  src/runtime/CL/functions/CLFullyConnectedLayer.cpp | 42
1 file changed, 33 insertions(+), 9 deletions(-)
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index dcaa12645e..9b7de8df1b 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -41,7 +41,7 @@ using namespace arm_compute::utils::cast;
namespace
{
Status construct_gemmlowp_output_stage(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo &output,
- GEMMLowpOutputStageInfo &gemmlowp_output_stage)
+ GEMMLowpOutputStageInfo &gemmlowp_output_stage, ActivationLayerInfo activation_info)
{
gemmlowp_output_stage.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
gemmlowp_output_stage.gemmlowp_offset = 0;
@@ -53,13 +53,14 @@ Status construct_gemmlowp_output_stage(const ITensorInfo &input, const ITensorIn
// Configure output stage for quantized case
if(is_data_type_quantized_asymmetric(data_type))
{
- const UniformQuantizationInfo iq_info = input.quantization_info().uniform();
- const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();
- const UniformQuantizationInfo oq_info = output.quantization_info().uniform();
+ const QuantizationInfo oq_info = output.quantization_info();
+ const UniformQuantizationInfo iq_unif = input.quantization_info().uniform();
+ const UniformQuantizationInfo wq_unif = weights.quantization_info().uniform();
+ const UniformQuantizationInfo oq_unif = oq_info.uniform();

- const auto output_quant_info = (output.total_size() == 0) ? iq_info : oq_info;
+ const auto output_quant_info = (output.total_size() == 0) ? iq_unif : oq_unif;

- const float multiplier = (iq_info.scale * wq_info.scale) / output_quant_info.scale;
+ const float multiplier = (iq_unif.scale * wq_unif.scale) / output_quant_info.scale;
int output_multiplier = 0;
int output_shift = 0;
ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
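The requantization math in this hunk is worth spelling out: the float multiplier (iq_unif.scale * wq_unif.scale) / output_quant_info.scale is decomposed into an integer mantissa and a shift, so the output stage can run in pure fixed-point arithmetic on the GPU. A self-contained sketch of one such decomposition follows; decompose_multiplier is a hypothetical stand-in for quantization::calculate_quantized_multiplier, and the scales are made-up values:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decompose a positive real multiplier into a Q31 fixed-point mantissa
// and a right shift, so that multiplier ~= (mantissa / 2^31) * 2^-shift.
void decompose_multiplier(float multiplier, int32_t *quantized_multiplier, int *shift)
{
    int         exponent = 0;
    const float mantissa = std::frexp(multiplier, &exponent); // mantissa in [0.5, 1)
    *shift               = -exponent;
    auto q = static_cast<int64_t>(std::llround(mantissa * static_cast<float>(1LL << 31)));
    if(q == (1LL << 31)) // rounding can push the mantissa to exactly 1.0
    {
        q /= 2;
        --(*shift);
    }
    *quantized_multiplier = static_cast<int32_t>(q);
}

int main()
{
    // Illustrative scales: input 0.5, weights 0.25, output 0.2
    const float multiplier = (0.5f * 0.25f) / 0.2f; // = 0.625
    int32_t     m          = 0;
    int         s          = 0;
    decompose_multiplier(multiplier, &m, &s);
    std::printf("multiplier=%f -> mantissa=%d shift=%d\n", multiplier, m, s);
    return 0;
}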
@@ -68,6 +69,27 @@ Status construct_gemmlowp_output_stage(const ITensorInfo &input, const ITensorIn
PixelValue type_max{};
std::tie(type_min, type_max) = get_min_max(data_type);

+ if(activation_info.enabled())
+ {
+ switch(activation_info.activation())
+ {
+ case ActivationLayerInfo::ActivationFunction::RELU:
+ type_min = PixelValue(oq_unif.offset);
+ break;
+ case ActivationLayerInfo::ActivationFunction::BOUNDED_RELU:
+ type_min = PixelValue(oq_unif.offset);
+ type_max = PixelValue(activation_info.a(), data_type, oq_info);
+ break;
+ case ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU:
+ type_min = PixelValue(activation_info.b(), data_type, oq_info);
+ type_max = PixelValue(activation_info.a(), data_type, oq_info);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Activation function not supported.");
+ break;
+ }
+ }
+
// Set the GEMMLowp output stage info
gemmlowp_output_stage.gemmlowp_offset = output_quant_info.offset;
gemmlowp_output_stage.gemmlowp_multiplier = output_multiplier;
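The switch above is the heart of the fusion: rather than emitting a separate activation kernel, it folds the activation into the output stage's saturation bounds. RELU raises the lower bound to the quantized zero point, and the bounded variants quantize the real-valued limits a() and b() into the output's quantization space. A standalone sketch of the same mapping for QASYMM8; quantize_qasymm8_like is a hypothetical helper and the scale/offset are made-up values:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helper: quantize a real value into QASYMM8 space.
uint8_t quantize_qasymm8_like(float value, float scale, int offset)
{
    const int q = static_cast<int>(std::lround(value / scale)) + offset;
    return static_cast<uint8_t>(std::min(std::max(q, 0), 255));
}

int main()
{
    // Made-up output quantization: scale 0.05, zero point 10.
    const float scale  = 0.05f;
    const int   offset = 10;

    // RELU: clamp low at the zero point, keep the type maximum.
    const uint8_t relu_min = static_cast<uint8_t>(offset);
    const uint8_t relu_max = 255;

    // BOUNDED_RELU with a = 6.0: additionally clamp high at quantize(6.0).
    const uint8_t brelu_min = static_cast<uint8_t>(offset);
    const uint8_t brelu_max = quantize_qasymm8_like(6.f, scale, offset); // 6/0.05 + 10 = 130

    // LU_BOUNDED_RELU with b = -1.0, a = 1.0: quantize both limits.
    const uint8_t lu_min = quantize_qasymm8_like(-1.f, scale, offset); // saturates to 0
    const uint8_t lu_max = quantize_qasymm8_like(1.f, scale, offset);  // 1/0.05 + 10 = 30

    std::printf("RELU [%u,%u] BRELU [%u,%u] LU [%u,%u]\n",
                relu_min, relu_max, brelu_min, brelu_max, lu_min, lu_max);
    return 0;
}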
@@ -84,7 +106,7 @@ Status construct_gemmlowp_output_stage(const ITensorInfo &input, const ITensorIn
Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const ITensorInfo *bias, const ITensorInfo &output, const FullyConnectedLayerInfo &fc_info)
{
GEMMLowpOutputStageInfo gemmlowp_output_stage;
- ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(input, weights, output, gemmlowp_output_stage));
+ ARM_COMPUTE_RETURN_ON_ERROR(construct_gemmlowp_output_stage(input, weights, output, gemmlowp_output_stage, fc_info.activation_info));

const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped
false, // is_b_reshaped
@@ -144,7 +166,7 @@ CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> mem
void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info)
{
GEMMLowpOutputStageInfo gemmlowp_output_stage;
- construct_gemmlowp_output_stage(*input->info(), *weights->info(), *output->info(), gemmlowp_output_stage);
+ construct_gemmlowp_output_stage(*input->info(), *weights->info(), *output->info(), gemmlowp_output_stage, fc_info.activation_info);

const GEMMInfo &gemm_info = GEMMInfo(false, // is_a_reshaped
false, // is_b_reshaped
@@ -155,7 +177,7 @@ void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor
gemmlowp_output_stage, // gemmlowp_output_stage
fc_info.fp_mixed_precision, // fp_mixed_precision
true, // broadcast_bias
- ActivationLayerInfo()); // activation_info
+ fc_info.activation_info); // activation_info

if(_is_quantized)
{
@@ -313,6 +335,8 @@ Status CLFullyConnectedLayer::validate(const ITensorInfo *input, const ITensorIn
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights, output);
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 2);
+ ARM_COMPUTE_RETURN_ERROR_ON(fc_info.activation_info.enabled() && is_data_type_quantized(input->data_type()) && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::RELU
+ && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::BOUNDED_RELU && fc_info.activation_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU);

bool weights_reshaped = fc_info.transpose_weights ? fc_info.are_weights_reshaped : true;
bool is_fc_after_conv = true;
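The guard added in the last hunk means that, for quantized inputs, only RELU, BOUNDED_RELU and LU_BOUNDED_RELU can be fused; any other activation is rejected before a kernel is configured. Since validate() is static, callers can check support up front. A small sketch; can_fuse_activation is a hypothetical wrapper, not part of the library:

#include "arm_compute/core/Error.h"
#include "arm_compute/core/ITensorInfo.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

using namespace arm_compute;

// Returns true if this configuration, including the fused activation,
// would be accepted by CLFullyConnectedLayer.
bool can_fuse_activation(const ITensorInfo *input, const ITensorInfo *weights,
                         const ITensorInfo *bias, const ITensorInfo *output,
                         const FullyConnectedLayerInfo &fc_info)
{
    const Status status = CLFullyConnectedLayer::validate(input, weights, bias, output, fc_info);
    return status.error_code() == ErrorCode::OK;
}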