author     Manuel Bottini <manuel.bottini@arm.com>   2020-04-08 10:15:51 +0100
committer  Manuel Bottini <manuel.bottini@arm.com>   2020-04-23 17:53:59 +0000
commit     2b84be544e4a27f7e8e80827e9c85c8f0d58b4ce (patch)
tree       078051a911f9b8883a3f11955cfd3b7ba0d7d9f3 /src/runtime/CL/functions/CLFullyConnectedLayer.cpp
parent     0de45d0a8009e19331c4e29d617fa183167c513a (diff)
download   ComputeLibrary-2b84be544e4a27f7e8e80827e9c85c8f0d58b4ce.tar.gz
COMPMID-3280: Make all ML primitives for CL use the new interface - Part 2
- CLFunctions have been updated

Change-Id: Ie3256a6c775bc12f3126482bd8e8a46da54b267c
Signed-off-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3053
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/CL/functions/CLFullyConnectedLayer.cpp')
-rw-r--r--  src/runtime/CL/functions/CLFullyConnectedLayer.cpp  44
1 file changed, 29 insertions(+), 15 deletions(-)
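
The whole patch follows one recurring pattern: each public configure() overload keeps its legacy signature for backward compatibility and forwards to a new overload that takes an explicit CLCompileContext, which is then threaded through to every kernel or function configured internally. A minimal sketch of the pattern, reproduced from the weights-reshape hunks below:

void CLFullyConnectedLayerReshapeWeights::configure(const ICLTensor *input, ICLTensor *output)
{
    // Legacy entry point: fall back to the library's default compile context
    configure(CLKernelLibrary::get().get_compile_context(), input, output);
}

void CLFullyConnectedLayerReshapeWeights::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output)
{
    // The context is forwarded to the kernel, so the kernel is compiled
    // against the caller-supplied context instead of the implicit global one
    auto k = arm_compute::support::cpp14::make_unique<CLTransposeKernel>();
    k->configure(compile_context, input, output);
    _kernel = std::move(k);
}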
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index 0c0fbd5c9d..ecbac6f703 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -147,8 +147,13 @@ Status validate_mm(const ITensorInfo &input, const ITensorInfo &weights, const I
void CLFullyConnectedLayerReshapeWeights::configure(const ICLTensor *input, ICLTensor *output)
{
+ configure(CLKernelLibrary::get().get_compile_context(), input, output);
+}
+
+void CLFullyConnectedLayerReshapeWeights::configure(const CLCompileContext &compile_context, const ICLTensor *input, ICLTensor *output)
+{
auto k = arm_compute::support::cpp14::make_unique<CLTransposeKernel>();
- k->configure(input, output);
+ k->configure(compile_context, input, output);
_kernel = std::move(k);
}
@@ -163,7 +168,8 @@ CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> mem
_are_weights_reshaped(true), _is_fc_after_conv(true), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
{
}

-void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info)
+void CLFullyConnectedLayer::configure_mm(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
+ const FullyConnectedLayerInfo &fc_info)
{
GEMMLowpOutputStageInfo gemmlowp_output_stage;
construct_gemmlowp_output_stage(*input->info(), *weights->info(), *output->info(), gemmlowp_output_stage, fc_info.activation_info);
@@ -190,7 +196,7 @@ void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor
weights->info()->set_quantization_info(QuantizationInfo(weights_quantization_info.uniform().scale, -weights_quantization_info.uniform().offset));
// Configure gemmlowp function
- _mm_gemmlowp.configure(input, weights, bias, output, gemm_info);
+ _mm_gemmlowp.configure(compile_context, input, weights, bias, output, gemm_info);
// Revert back QuantizationInfo as input and weights could be used in other fully connected layers
input->info()->set_quantization_info(input_quantization_info);
@@ -199,11 +205,12 @@ void CLFullyConnectedLayer::configure_mm(const ICLTensor *input, const ICLTensor
else
{
// Configure matrix multiply kernel
- _mm_gemm.configure(input, weights, bias, output, 1.f, 1.f, gemm_info);
+ _mm_gemm.configure(compile_context, input, weights, bias, output, 1.f, 1.f, gemm_info);
}
}

-void CLFullyConnectedLayer::configure_conv_fc(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info)
+void CLFullyConnectedLayer::configure_conv_fc(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
+ const FullyConnectedLayerInfo &fc_info)
{
ARM_COMPUTE_ERROR_ON((weights->info()->dimension(1) != (input->info()->dimension(0) * input->info()->dimension(1) * input->info()->dimension(2))));
@@ -215,26 +222,33 @@ void CLFullyConnectedLayer::configure_conv_fc(const ICLTensor *input, const ICLT
// Configure flatten kernel
_memory_group.manage(&_flatten_output);
- _flatten_layer.configure(input, &_flatten_output);
+ _flatten_layer.configure(compile_context, input, &_flatten_output);
// Configure matrix multiply kernel
- configure_mm(&_flatten_output, weights, bias, output, fc_info);
+ configure_mm(compile_context, &_flatten_output, weights, bias, output, fc_info);
// Allocate the output tensor for flatten once all the configure methods have been called
_flatten_output.allocator()->allocate();
}

-void CLFullyConnectedLayer::configure_fc_fc(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output, const FullyConnectedLayerInfo &fc_info)
+void CLFullyConnectedLayer::configure_fc_fc(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *bias, ICLTensor *output,
+ const FullyConnectedLayerInfo &fc_info)
{
ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != weights->info()->dimension(1));
// Configure matrix multiply kernel
- configure_mm(input, weights, bias, output, fc_info);
+ configure_mm(compile_context, input, weights, bias, output, fc_info);
}

void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output,
FullyConnectedLayerInfo fc_info)
{
+ configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, fc_info);
+}
+
+void CLFullyConnectedLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output,
+ FullyConnectedLayerInfo fc_info)
+{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
// Perform validate step
@@ -282,13 +296,13 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w
{
if(_weights_manager && _weights_manager->are_weights_managed(weights))
{
- _reshape_weights_managed_function.configure(weights);
+ _reshape_weights_managed_function.configure(compile_context, weights);
weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_reshape_weights_managed_function));
}
else
{
// Reshape the weights
- _reshape_weights_function.configure(weights, &_reshape_weights_output);
+ _reshape_weights_function.configure(compile_context, weights, &_reshape_weights_output);
weights_to_use = &_reshape_weights_output;
}
}
@@ -298,7 +312,7 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w
{
if(_weights_manager && _weights_manager->are_weights_managed(weights_to_use))
{
- _convert_weights_managed.configure(weights_to_use,
+ _convert_weights_managed.configure(compile_context, weights_to_use,
input->info()->tensor_shape(),
fc_info.weights_trained_layout);
weights_to_use = utils::cast::polymorphic_downcast<ICLTensor *>(_weights_manager->acquire(weights, &_convert_weights_managed));
@@ -306,7 +320,7 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w
else
{
// Convert weights
- _convert_weights.configure(weights_to_use,
+ _convert_weights.configure(compile_context, weights_to_use,
&_converted_weights_output,
input->info()->tensor_shape(),
fc_info.weights_trained_layout);
@@ -319,12 +333,12 @@ void CLFullyConnectedLayer::configure(const ICLTensor *input, const ICLTensor *w
if(_is_fc_after_conv)
{
// Fully Connected layer after a Convolution Layer without batches
- configure_conv_fc(input, weights_to_use, biases, output, fc_info);
+ configure_conv_fc(compile_context, input, weights_to_use, biases, output, fc_info);
}
else
{
// Fully Connected layer after a Fully Connected Layer without batches
- configure_fc_fc(input, weights_to_use, biases, output, fc_info);
+ configure_fc_fc(compile_context, input, weights_to_use, biases, output, fc_info);
}
}
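
For context, a minimal usage sketch of the updated function. This is illustrative only and not part of the patch: the tensor names, shapes and the final sync are assumptions. The compile context can now be passed explicitly, while the old call without it keeps working:

#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"

using namespace arm_compute;

int main()
{
    // Initialise the default CL context, queue and kernel library
    CLScheduler::get().default_init();

    // Illustrative shapes: 128 inputs, 64 outputs, no batches
    CLTensor src, weights, bias, dst;
    src.allocator()->init(TensorInfo(TensorShape(128U), 1, DataType::F32));
    weights.allocator()->init(TensorInfo(TensorShape(128U, 64U), 1, DataType::F32));
    bias.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F32));
    dst.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F32));

    CLFullyConnectedLayer fc;

    // New interface: the compile context is passed explicitly ...
    fc.configure(CLKernelLibrary::get().get_compile_context(), &src, &weights, &bias, &dst);
    // ... the legacy overload fc.configure(&src, &weights, &bias, &dst)
    // still works and simply forwards to the call above.

    src.allocator()->allocate();
    weights.allocator()->allocate();
    bias.allocator()->allocate();
    dst.allocator()->allocate();

    fc.run();
    CLScheduler::get().sync();
    return 0;
}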