author    SiCongLi <sicong.li@arm.com>    2021-10-18 09:38:33 +0100
committer SiCong Li <sicong.li@arm.com>   2021-11-01 15:18:12 +0000
commit    579ca84bd8ef5a91eded65c4dc5e0b9f7de8bef1 (patch)
tree      0c2ceba8ad5b2c944bce00055fe1ec7ac84b49f3 /src
parent    48717a3d38fef8d316cd4b9fd9a3bc1a43db736b (diff)
download  ComputeLibrary-579ca84bd8ef5a91eded65c4dc5e0b9f7de8bef1.tar.gz
Add PostOp support to GEMM and CLGEMM operators and functions Part 2
* Implement PostOp interface changes
* Remove spaces around "=" in TypePrinter

Partially resolves COMPMID-4435

Signed-off-by: SiCongLi <sicong.li@arm.com>
Change-Id: If1e2280554030a0f635e73339a2e86987f6dc41b
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6484
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
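For context, a minimal caller-side usage sketch of the post-op plumbing this patch wires through (not part of this patch; the concrete op types PostOpEltwiseAdd/PostOpAct, the push_back_op() helper, the PostOps.h header, and the prev_dst_pos value are assumed from the Part 1 interface patch of this series):

    #include "arm_compute/core/Types.h"
    #include "arm_compute/core/experimental/PostOps.h"            // assumed: concrete op types from Part 1
    #include "arm_compute/runtime/CL/CLTensor.h"
    #include "arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h"

    using namespace arm_compute;

    CLTensor src{}, weights{}, biases{}, dst{}, addend{};         // allocation/initialisation omitted

    // dst = relu(conv(src, weights) + biases + addend), with the addition and the
    // activation fused into the GEMM kernel as post ops.
    experimental::PostOpList<ICLTensor *> post_ops{};
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ICLTensor *>>(&addend, 1 /* prev_dst_pos, assumed */, ConvertPolicy::SATURATE);
    post_ops.push_back_op<experimental::PostOpAct<ICLTensor *>>(ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU));

    CLGEMMConvolutionLayer conv{};
    conv.configure(&src, &weights, &biases, &dst, PadStrideInfo(1, 1, 0, 0), WeightsInfo(),
                   Size2D(1U, 1U), ActivationLayerInfo(), 1 /* num_groups */, post_ops);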
Diffstat (limited to 'src')
-rw-r--r--  src/gpu/cl/operators/ClConv2d.cpp                        4
-rw-r--r--  src/gpu/cl/operators/ClGemm.cpp                         15
-rw-r--r--  src/gpu/cl/operators/ClGemm.h                            4
-rw-r--r--  src/gpu/cl/operators/ClGemmConv2d.cpp                   36
-rw-r--r--  src/gpu/cl/operators/ClGemmConv2d.h                      6
-rw-r--r--  src/runtime/CL/functions/CLBatchNormalizationLayer.cpp   3
-rw-r--r--  src/runtime/CL/functions/CLConvolutionLayer.cpp         39
-rw-r--r--  src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp     31
8 files changed, 99 insertions(+), 39 deletions(-)
diff --git a/src/gpu/cl/operators/ClConv2d.cpp b/src/gpu/cl/operators/ClConv2d.cpp
index 7fe0de7a6f..d633c8f738 100644
--- a/src/gpu/cl/operators/ClConv2d.cpp
+++ b/src/gpu/cl/operators/ClConv2d.cpp
@@ -92,6 +92,7 @@ void ClConv2d::configure(const CLCompileContext &compile_context, ITensorInfo *s
case ConvolutionMethod::WINOGRAD:
{
ARM_COMPUTE_ERROR_ON(conv2d_info.num_groups != 1);
+ ARM_COMPUTE_ERROR_ON(conv2d_info.post_ops.size() > 0);
auto f = std::make_unique<ClWinogradConv2d>();
f->configure(compile_context, src, weights, biases, dst, conv2d_info.conv_info, conv2d_info.act_info, conv2d_info.enable_fast_math);
_operator = std::move(f);
@@ -100,6 +101,7 @@ void ClConv2d::configure(const CLCompileContext &compile_context, ITensorInfo *s
case ConvolutionMethod::DIRECT:
{
ARM_COMPUTE_ERROR_ON(conv2d_info.num_groups != 1);
+ ARM_COMPUTE_ERROR_ON(conv2d_info.post_ops.size() > 0);
auto f = std::make_unique<ClDirectConv2d>();
f->configure(compile_context, src, weights, biases, dst, conv2d_info.conv_info, conv2d_info.act_info);
_operator = std::move(f);
@@ -133,6 +135,7 @@ Status ClConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, co
{
//Validate Winograd
ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv2d_info.num_groups != 1, "Grouping (num_groups != 1) with ClWinogradConv2d is not supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv2d_info.post_ops.size() > 0, "ClWinogradConv2d does not support PostOps");
ARM_COMPUTE_RETURN_ON_ERROR(ClWinogradConv2d::validate(src, weights, biases, dst, conv2d_info.conv_info, conv2d_info.act_info, conv2d_info.enable_fast_math));
break;
}
@@ -140,6 +143,7 @@ Status ClConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights, co
{
// Validate direct convolution layer
ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv2d_info.num_groups != 1, "Grouping (num_groups != 1) with ClDirectConv2d is not supported");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv2d_info.post_ops.size() > 0, "ClDirectConv2d does not support PostOps");
ARM_COMPUTE_RETURN_ON_ERROR(ClDirectConv2d::validate(src, weights, biases, dst, conv2d_info.conv_info, conv2d_info.act_info));
break;
}
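The guards added above mean post-op fusion is only available on the GEMM path; Winograd and Direct (and FFT, further below) reject any non-empty PostOpList. A caller-side sketch of probing support before configuring (variable names are illustrative, not from this patch):

    // Probe first: validate() returns an error Status if the selected method
    // cannot fuse the given post ops.
    const Status st = CLConvolutionLayer::validate(src_info, wei_info, bias_info, dst_info,
                                                   conv_info, WeightsInfo(), Size2D(1U, 1U),
                                                   ActivationLayerInfo(), false /* enable_fast_math */,
                                                   1 /* num_groups */, info_post_ops);
    if(!bool(st))
    {
        // Fall back to an unfused configuration and run the post ops as separate operators.
    }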
diff --git a/src/gpu/cl/operators/ClGemm.cpp b/src/gpu/cl/operators/ClGemm.cpp
index d2d0f8f91d..e05256ee2f 100644
--- a/src/gpu/cl/operators/ClGemm.cpp
+++ b/src/gpu/cl/operators/ClGemm.cpp
@@ -38,6 +38,7 @@
#include "arm_compute/runtime/CL/CLScheduler.h"
#include "arm_compute/runtime/ITensorAllocator.h"
+#include "arm_compute/core/experimental/IPostOp.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/core/utils/helpers/float_ops.h"
@@ -64,7 +65,7 @@ namespace
{
inline bool validate_gemm_kernel(CLGEMMKernelType kernel_type)
{
- return kernel_type == CLGEMMKernelType::NATIVE? false : true;
+ return kernel_type == CLGEMMKernelType::NATIVE ? false : true;
}
//Automatically select between mlgo (prioritized) and default heuristics for gemm kernel type
inline CLGEMMKernelType auto_select_gemm_kernel(auto_heuristics::CommonQuery query, bool reshape_b_only_on_first_run, bool constant_weights)
@@ -203,6 +204,7 @@ ClGemm::ClGemm()
void ClGemm::configure_native(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
const GEMMInfo &gemm_info)
{
+ ARM_COMPUTE_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel");
DataType data_type = a->data_type();
bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
@@ -252,6 +254,7 @@ void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensor
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = broadcast_bias;
kernel_info.activation_info = gemm_info.activation_info();
+ kernel_info.post_ops = gemm_info.post_ops();
// Set the target for the kernels
_reshape_lhs_kernel->set_target(gpu_target);
@@ -278,6 +281,7 @@ void ClGemm::configure_reshaped(const CLCompileContext &compile_context, ITensor
void ClGemm::configure_reshaped_only_rhs(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta,
const GEMMInfo &gemm_info)
{
+ ARM_COMPUTE_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel");
DataType data_type = a->data_type();
bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d();
const unsigned int m = reinterpret_input_as_3d ? (a->dimension(1) * a->dimension(2)) : a->dimension(1);
@@ -330,6 +334,7 @@ Status ClGemm::validate_native(const ITensorInfo *a, const ITensorInfo *b, const
{
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_UNUSED(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel");
// Get the GPU target
const GPUTarget gpu_target = CLScheduler::get().target();
@@ -386,6 +391,7 @@ Status ClGemm::validate_reshaped(const ITensorInfo *a, const ITensorInfo *b, con
kernel_info.reinterpret_input_as_3d = false;
kernel_info.broadcast_bias = broadcast_bias;
kernel_info.activation_info = gemm_info.activation_info();
+ kernel_info.post_ops = gemm_info.post_ops();
GEMMLHSMatrixInfo lhs_info;
GEMMRHSMatrixInfo rhs_info;
@@ -412,6 +418,7 @@ Status ClGemm::validate_reshaped_only_rhs(const ITensorInfo *a, const ITensorInf
{
ARM_COMPUTE_UNUSED(alpha);
ARM_COMPUTE_UNUSED(output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.post_ops().size() > 0, "PostOps are not supported in this kernel");
TensorInfo tmp_b_info{};
@@ -588,8 +595,10 @@ void ClGemm::run(ITensorPack &tensors)
ITensorPack reshape_rhs_pack{ { ACL_SRC, rhs }, { ACL_DST, rhs_reshaped.get() } };
CLScheduler::get().enqueue_op(*_reshape_rhs_kernel, reshape_rhs_pack, false);
}
-
- ITensorPack gemm_reshaped_pack{ { ACL_SRC_0, lhs_reshaped.get() }, { ACL_SRC_1, rhs_reshaped.get() }, { ACL_SRC_2, src2 }, { ACL_DST, dst } };
+ // Copy original tensor pack and overwrite lhs and rhs with reshaped counterparts
+ ITensorPack gemm_reshaped_pack(tensors);
+ gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get());
+ gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get());
if(_gemm_kernel_type == CLGEMMKernelType::RESHAPED)
{
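The run() change above is worth spelling out: instead of hand-building a three-tensor pack, the incoming pack is copied so that any post-op argument tensors it carries (alongside SRC_2 and DST) reach the reshaped-GEMM kernel, and only the LHS/RHS slots are overridden. An annotated restatement of the same lines, no new behaviour:

    ITensorPack gemm_reshaped_pack(tensors);                            // copy keeps SRC_2, DST and any post-op tensors
    gemm_reshaped_pack.add_const_tensor(ACL_SRC_0, lhs_reshaped.get()); // override LHS with its reshaped version
    gemm_reshaped_pack.add_const_tensor(ACL_SRC_1, rhs_reshaped.get()); // override RHS with its reshaped version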
diff --git a/src/gpu/cl/operators/ClGemm.h b/src/gpu/cl/operators/ClGemm.h
index fd53648b3c..e084e53fe4 100644
--- a/src/gpu/cl/operators/ClGemm.h
+++ b/src/gpu/cl/operators/ClGemm.h
@@ -81,8 +81,8 @@ public:
* @param[in] alpha Weight of the matrix product
* @param[in] beta Weight of matrix C
* @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
- * if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping
- * in case matrix A and matrix B have been already transformed.
+ * if the reshape of matrix B should happen only for the first run. GEMMInfo also contains information about the reshaping
+ * in case matrix A and matrix B have been already transformed.
*/
void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, float alpha, float beta, const GEMMInfo &gemm_info);
/** Static function to check if given info will lead to a valid configuration
diff --git a/src/gpu/cl/operators/ClGemmConv2d.cpp b/src/gpu/cl/operators/ClGemmConv2d.cpp
index 785f1f1c9c..7db5fa0052 100644
--- a/src/gpu/cl/operators/ClGemmConv2d.cpp
+++ b/src/gpu/cl/operators/ClGemmConv2d.cpp
@@ -54,14 +54,14 @@ namespace opencl
{
ClGemmConv2d::ClGemmConv2d()
: _weights_reshape_kernel(nullptr), _im2col_kernel(nullptr), _mm_gemm(nullptr), _mm_gemmlowp(nullptr), _col2im_kernel(nullptr), _activation_kernel(nullptr), _im2col_output(), _weights_reshaped(),
- _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _append_bias(false), _is_prepared(false), _aux_mem(AuxTensorIdx::Count)
+ _gemm_output(), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _fuse_activation(true), _append_bias(false), _is_prepared(false), _use_post_ops(false), _aux_mem(AuxTensorIdx::Count)
{
}
ClGemmConv2d::~ClGemmConv2d() = default;
void ClGemmConv2d::configure_mm(const ClCompileContext &compile_context, const ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst,
const GEMMLowpOutputStageInfo &gemmlowp_output_stage,
- int gemm_3d_depth, const ActivationLayerInfo &act_info)
+ int gemm_3d_depth, const ActivationLayerInfo &act_info, const experimental::PostOpList<ITensorInfo *> &post_ops)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights);
ARM_COMPUTE_ERROR_THROW_ON(validate_mm(src, weights, biases, dst, gemmlowp_output_stage, gemm_3d_depth, _skip_im2col, act_info));
@@ -76,11 +76,14 @@ void ClGemmConv2d::configure_mm(const ClCompileContext &compile_context, const I
false, // fast_math
false, // fp_mixed_precision
true, // broadcast_bias
- act_info); // activation_info
+ act_info, // activation_info
+ post_ops // post ops
+ );
TensorInfo tmp_src{ *src };
if(_is_quantized)
{
+ ARM_COMPUTE_ERROR_ON_MSG(post_ops.size() > 0, "ClGemmConv2d quantized types do not support post ops");
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
// Extract and negate input and weights offset
const QuantizationInfo input_quantization_info = src->quantization_info();
@@ -115,7 +118,7 @@ void ClGemmConv2d::configure_mm(const ClCompileContext &compile_context, const I
}
Status ClGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst,
- const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info)
+ const GEMMLowpOutputStageInfo &gemmlowp_output_stage, int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info, const experimental::PostOpList<ITensorInfo *> &post_ops)
{
const bool is_quantized = is_data_type_quantized_asymmetric(src->data_type());
@@ -129,10 +132,13 @@ Status ClGemmConv2d::validate_mm(const ITensorInfo *src, const ITensorInfo *weig
false, // fast_math
false, // fp_mixed_precision
true, // broadcast_bias
- act_info); // activation_info
+ act_info, // activation_info
+ post_ops // post ops
+ );
if(is_quantized)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(post_ops.size() > 0, "ClGemmConv2d quantized types do not support post ops");
// Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
// Extract and negate input and weights offset
const QuantizationInfo input_quantization_info = src->quantization_info();
@@ -183,6 +189,7 @@ void ClGemmConv2d::configure(const CLCompileContext &compile_context, ITensorInf
// Only for quantize there are few cases where we cannot fuse the activation function in GEMM
_fuse_activation = true;
+ _use_post_ops = conv2d_info.post_ops.size() > 0;
const ITensorInfo *gemm_input_to_use = src;
ITensorInfo *gemm_output_to_use = dst;
@@ -311,10 +318,11 @@ void ClGemmConv2d::configure(const CLCompileContext &compile_context, ITensorInf
// In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0;
- configure_mm(compile_context, gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, conv2d_info.act_info);
+ configure_mm(compile_context, gemm_input_to_use, &_weights_reshaped, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, conv2d_info.act_info, conv2d_info.post_ops);
if(!_skip_col2im)
{
+ ARM_COMPUTE_ERROR_ON_MSG(conv2d_info.post_ops.size() > 0, "ClGemmConv2d does not support post ops with col2im operation"); // Post ops must be performed after every other op
// Set the GPU target for col2im
_col2im_kernel = std::make_unique<opencl::kernels::ClCol2ImKernel>();
_col2im_kernel->set_target(CLScheduler::get().target());
@@ -326,7 +334,8 @@ void ClGemmConv2d::configure(const CLCompileContext &compile_context, ITensorInf
ARM_COMPUTE_ERROR_ON_MSG((dst->dimension(idx_width) != conv_w) || (dst->dimension(idx_height) != conv_h),
"Output shape does not match the expected one");
- if(!_fuse_activation)
+ // Disable running of activation kernel if post ops are used
+ if(!_fuse_activation && !_use_post_ops)
{
_activation_kernel = std::make_unique<opencl::kernels::ClActivationKernel>();
_activation_kernel->configure(compile_context, dst, nullptr, conv2d_info.act_info);
@@ -376,6 +385,7 @@ Status ClGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights
&& conv2d_info.conv_info.stride().second == 1);
const bool skip_col2im = data_layout == DataLayout::NHWC;
bool fuse_activation = true;
+ bool use_post_ops = conv2d_info.post_ops.size() > 0;
ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(idx_channel) * conv2d_info.num_groups) != src->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
@@ -507,16 +517,19 @@ Status ClGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights
// In case of NHWC, we need to run GEMM3D (gemm_3d_depth != 0) in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = (data_layout == DataLayout::NHWC) ? conv_h : 0;
- ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, conv2d_info.act_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases_to_use, gemm_output_to_use, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, conv2d_info.act_info,
+ conv2d_info.post_ops));
// Validate Col2Im
if(!skip_col2im)
{
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv2d_info.post_ops.size() > 0, "ClGemmConv2d does not support post ops with col2im operation"); // Post ops must be performed after every other op
ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClCol2ImKernel::validate(gemm_output_to_use, dst, Size2D(conv_w, conv_h), conv2d_info.num_groups));
}
- //Validate Activation Layer
- if(!fuse_activation)
+ // Validate Activation Layer
+ // Disable running (thus validation) of activation kernel if post ops are used
+ if(!fuse_activation && !use_post_ops)
{
ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClActivationKernel::validate(dst, nullptr, conv2d_info.act_info));
}
@@ -585,7 +598,8 @@ void ClGemmConv2d::run(ITensorPack &tensors)
}
//Run Activation Layer if we cannot fuse in GEMM
- if(!_fuse_activation)
+ // Disable running of activation kernel if post ops are used
+ if(!_fuse_activation && !_use_post_ops)
{
ITensorPack pack =
{
diff --git a/src/gpu/cl/operators/ClGemmConv2d.h b/src/gpu/cl/operators/ClGemmConv2d.h
index 9a5e381dd7..afde7c511d 100644
--- a/src/gpu/cl/operators/ClGemmConv2d.h
+++ b/src/gpu/cl/operators/ClGemmConv2d.h
@@ -26,6 +26,7 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/experimental/IPostOp.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
#include "src/gpu/cl/ClCompileContext.h"
#include "src/gpu/cl/IClOperator.h"
@@ -132,7 +133,7 @@ private:
*/
void configure_mm(const CLCompileContext &compile_context, const ITensorInfo *src, ITensorInfo *weights, ITensorInfo *biases, ITensorInfo *dst,
const GEMMLowpOutputStageInfo &gemmlowp_output_stage,
- int gemm_3d_depth, const ActivationLayerInfo &act_info);
+ int gemm_3d_depth, const ActivationLayerInfo &act_info, const experimental::PostOpList<ITensorInfo *> &post_ops = experimental::PostOpList<ITensorInfo *> {});
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMConvolutionLayer matrix multiply routines
*
* @param[in] src Input tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
@@ -149,7 +150,7 @@ private:
* @return a status
*/
static Status validate_mm(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *dst, const GEMMLowpOutputStageInfo &gemmlowp_output_stage,
- int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info);
+ int gemm_3d_depth, bool skip_im2col, const ActivationLayerInfo &act_info, const experimental::PostOpList<ITensorInfo *> &post_ops = experimental::PostOpList<ITensorInfo *> {});
enum AuxTensorIdx
{
@@ -177,6 +178,7 @@ private:
bool _fuse_activation;
bool _append_bias;
bool _is_prepared;
+ bool _use_post_ops;
experimental::MemoryRequirements _aux_mem;
};
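Note that both private methods gain a post_ops parameter defaulted to an empty PostOpList, so pre-existing internal call sites compile unchanged. Illustratively (argument names are placeholders), these two calls are equivalent:

    validate_mm(src, weights, biases, dst, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, act_info);
    validate_mm(src, weights, biases, dst, gemmlowp_output_stage, gemm_3d_depth, skip_im2col, act_info,
                experimental::PostOpList<ITensorInfo *>{});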
diff --git a/src/runtime/CL/functions/CLBatchNormalizationLayer.cpp b/src/runtime/CL/functions/CLBatchNormalizationLayer.cpp
index 234a0df2aa..e8affc0853 100644
--- a/src/runtime/CL/functions/CLBatchNormalizationLayer.cpp
+++ b/src/runtime/CL/functions/CLBatchNormalizationLayer.cpp
@@ -29,10 +29,11 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
-#include "src/common/utils/Log.h"
#include "src/core/CL/kernels/CLBatchNormalizationLayerKernel.h"
+#include "src/common/utils/Log.h"
+
namespace arm_compute
{
CLBatchNormalizationLayer::CLBatchNormalizationLayer()
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp
index eaca6ee504..d75f54f19c 100644
--- a/src/runtime/CL/functions/CLConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp
@@ -60,21 +60,26 @@ CLConvolutionLayer::CLConvolutionLayer(std::shared_ptr<IMemoryManager> memory_ma
CLConvolutionLayer::~CLConvolutionLayer() = default;
void CLConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
- const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
+ const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups, const experimental::PostOpList<ICLTensor *> &post_ops)
{
- configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups);
+ configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups, post_ops);
}
void CLConvolutionLayer::configure(const CLCompileContext &compile_context, ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info,
const WeightsInfo &weights_info,
- const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
+ const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups, const experimental::PostOpList<ICLTensor *> &post_ops)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_ERROR_THROW_ON(CLConvolutionLayer::validate(input->info(), weights->info(), ((biases != nullptr) ? biases->info() : nullptr), output->info(), conv_info, weights_info, dilation, act_info,
enable_fast_math, num_groups));
- ARM_COMPUTE_LOG_PARAMS(input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups);
+ ARM_COMPUTE_LOG_PARAMS(input, weights, biases, output, conv_info, weights_info, dilation, act_info, enable_fast_math, num_groups, post_ops);
- const Conv2dInfo conv2d_info = Conv2dInfo(conv_info, dilation, act_info, enable_fast_math, num_groups);
+ // Convert post op arguments to ITensorInfo
+ auto transformed_post_ops = experimental::transform_post_op_list_arguments<ICLTensor *, ITensorInfo *>(post_ops, [](auto tensor)
+ {
+ return tensor->info();
+ });
+ const Conv2dInfo conv2d_info = Conv2dInfo(conv_info, dilation, act_info, enable_fast_math, num_groups, transformed_post_ops);
switch(opencl::ClConv2d::get_convolution_method(input->info(), weights->info(), output->info(), conv2d_info,
weights_info, CLScheduler::get().target()))
@@ -90,6 +95,7 @@ void CLConvolutionLayer::configure(const CLCompileContext &compile_context, ICLT
}
case ConvolutionMethod::FFT:
{
+ ARM_COMPUTE_ERROR_ON_MSG(post_ops.size() > 0, "CLFFTConvolutionLayer does not support post ops");
auto f = std::make_unique<CLFFTConvolutionLayer>(_impl->memory_manager);
f->configure(compile_context, input, weights, biases, output, conv_info, act_info, enable_fast_math);
_impl->func = std::move(f);
@@ -102,22 +108,30 @@ void CLConvolutionLayer::configure(const CLCompileContext &compile_context, ICLT
if(_impl->op)
{
- _impl->memory_group = MemoryGroup(std::move(_impl->memory_manager));
- _impl->aux_mem_req = _impl->op->workspace();
- _impl->run_pack = { { ACL_SRC_0, input }, { ACL_SRC_1, weights }, { ACL_SRC_2, biases }, { ACL_DST, output } };
- _impl->prep_pack = { { ACL_SRC_1, weights }, { ACL_SRC_2, biases } };
- _impl->workspace = manage_workspace<CLTensor>(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->prep_pack);
+ _impl->memory_group = MemoryGroup(std::move(_impl->memory_manager));
+ _impl->aux_mem_req = _impl->op->workspace();
+ _impl->run_pack = { { ACL_SRC_0, input }, { ACL_SRC_1, weights }, { ACL_SRC_2, biases }, { ACL_DST, output } };
+ size_t post_op_tensor_index = 0;
+ for(const auto &op : post_ops.get_list())
+ {
+ for(auto &tensor : op->arguments())
+ {
+ _impl->run_pack.add_const_tensor(experimental::get_post_op_arg_type(post_op_tensor_index++), *tensor);
+ }
+ }
+ _impl->prep_pack = { { ACL_SRC_1, weights }, { ACL_SRC_2, biases } };
+ _impl->workspace = manage_workspace<CLTensor>(_impl->aux_mem_req, _impl->memory_group, _impl->run_pack, _impl->prep_pack);
}
}
Status CLConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
- const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups)
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups, const experimental::PostOpList<ITensorInfo *> &post_ops)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_RETURN_ERROR_ON_MSG((num_groups != 1) && (input->data_layout() != DataLayout::NCHW), "Grouping (num_groups != 1) with NHWC data layout is not supported");
const GPUTarget gpu_target = CLScheduler::get().target();
- const Conv2dInfo conv2d_info = Conv2dInfo(conv_info, dilation, act_info, enable_fast_math, num_groups);
+ const Conv2dInfo conv2d_info = Conv2dInfo(conv_info, dilation, act_info, enable_fast_math, num_groups, post_ops);
switch(opencl::ClConv2d::get_convolution_method(input, weights, output, conv2d_info, weights_info, gpu_target))
{
@@ -131,6 +145,7 @@ Status CLConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo
case ConvolutionMethod::FFT:
{
// Validate FFT-based convolution layer
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(post_ops.size() > 0, "CLFFTConvolutionLayer does not support post ops");
ARM_COMPUTE_RETURN_ON_ERROR(CLFFTConvolutionLayer::validate(input, weights, nullptr, output, conv_info, act_info, enable_fast_math));
break;
}
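The configure() overload above converts the ICLTensor-based post-op list into an ITensorInfo-based one before building Conv2dInfo, since the operator layer works on tensor infos. A sketch of reusing the same transform on the caller side to reach the ITensorInfo-based validate() overload (names are illustrative):

    auto info_post_ops = experimental::transform_post_op_list_arguments<ICLTensor *, ITensorInfo *>(
        post_ops, [](ICLTensor *tensor) { return tensor->info(); });
    const Status st = CLConvolutionLayer::validate(input->info(), weights->info(),
                                                   biases != nullptr ? biases->info() : nullptr,
                                                   output->info(), conv_info, weights_info, dilation,
                                                   act_info, enable_fast_math, num_groups, info_post_ops);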
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index 837527bac3..1eabee65f8 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -31,6 +31,7 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "src/core/experimental/PostOp.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/gpu/cl/operators/ClGemmConv2d.h"
#include "support/Cast.h"
@@ -68,19 +69,24 @@ CLGEMMConvolutionLayer::CLGEMMConvolutionLayer(std::shared_ptr<IMemoryManager> m
CLGEMMConvolutionLayer::~CLGEMMConvolutionLayer() = default;
void CLGEMMConvolutionLayer::configure(const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
- const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
+ const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups, const experimental::PostOpList<ICLTensor *> &post_ops)
{
- configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, conv_info, weights_info, dilation, act_info, num_groups);
+ configure(CLKernelLibrary::get().get_compile_context(), input, weights, biases, output, conv_info, weights_info, dilation, act_info, num_groups, post_ops);
}
void CLGEMMConvolutionLayer::configure(const CLCompileContext &compile_context, const ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output,
const PadStrideInfo &conv_info,
- const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups, const experimental::PostOpList<ICLTensor *> &post_ops)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
- _impl->weights = weights;
- _impl->op = std::make_unique<opencl::ClGemmConv2d>();
- const Conv2dInfo conv2d_info = Conv2dInfo(conv_info, dilation, act_info, false, num_groups);
+ _impl->weights = weights;
+ _impl->op = std::make_unique<opencl::ClGemmConv2d>();
+ // Convert post op arguments to ITensorInfo
+ auto transformed_post_ops = experimental::transform_post_op_list_arguments<ICLTensor *, ITensorInfo *>(post_ops, [](auto tensor)
+ {
+ return tensor->info();
+ });
+ const Conv2dInfo conv2d_info = Conv2dInfo(conv_info, dilation, act_info, false, num_groups, transformed_post_ops);
_impl->op->configure(compile_context, input->info(), weights->info(), (biases != nullptr ? biases->info() : nullptr), output->info(), conv2d_info, weights_info);
_impl->run_pack =
@@ -90,6 +96,15 @@ void CLGEMMConvolutionLayer::configure(const CLCompileContext &compile_context,
{ TensorType::ACL_SRC_2, biases },
{ TensorType::ACL_DST, output }
};
+ // Add post op tensors
+ size_t post_op_tensor_index = 0;
+ for(const auto &op : post_ops.get_list())
+ {
+ for(auto &tensor : op->arguments())
+ {
+ _impl->run_pack.add_const_tensor(experimental::get_post_op_arg_type(post_op_tensor_index++), *tensor);
+ }
+ }
_impl->prep_pack =
{
{ TensorType::ACL_SRC_1, weights },
@@ -100,9 +115,9 @@ void CLGEMMConvolutionLayer::configure(const CLCompileContext &compile_context,
}
Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
- const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
+ const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups, const experimental::PostOpList<ITensorInfo *> &post_ops)
{
- const Conv2dInfo conv2d_info = Conv2dInfo(conv_info, dilation, act_info, false, num_groups);
+ const Conv2dInfo conv2d_info = Conv2dInfo(conv_info, dilation, act_info, false, num_groups, post_ops);
return opencl::ClGemmConv2d::validate(input, weights, biases, output, conv2d_info, weights_info);
}