author     SiCongLi <sicong.li@arm.com>  2021-11-03 19:01:22 +0000
committer  SiCong Li <sicong.li@arm.com>  2021-11-04 14:03:19 +0000
commit     d928735fee6baefdb74325c05d8152dd13044f32 (patch)
tree       6fb702e36da2863639149995e2df1cfe70905fc7
parent     6049edadf0c89a026b3fcd1927ee7531d3c40278 (diff)
download   ComputeLibrary-d928735fee6baefdb74325c05d8152dd13044f32.tar.gz
Add validate tests for CLConvolutionLayer and CLGEMMConvolutionLayer with post ops
* Add validate tests
* Restrict post ops support in ClGemmConv2d to only those that do not need
  im2col or col2im. In practice this means we only support post ops in
  conv1x1 with stride = 1, dilation = 1 and data layout = NHWC

Resolves COMPMID-4435

Change-Id: I1fdf0c5d565a4624857250075ac76db35c2f383b
Signed-off-by: SiCongLi <sicong.li@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6573
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
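A minimal sketch of the configuration this change enables, mirroring the test helper and the SupportedPostOps test added in tests/validation/CL/ConvolutionLayer.cpp below (the function name and shapes here are illustrative, borrowed from those tests):

```cpp
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h"
#include "src/core/experimental/PostOp.h"

using namespace arm_compute;

bool conv1x1_nhwc_post_op_is_valid()
{
    // Post ops are only fused for conv1x1 with stride = 1, dilation = 1, NHWC
    const auto data_layout   = DataLayout::NHWC;
    const auto conv_info     = PadStrideInfo(1, 1, 0, 0);
    const auto input_shape   = TensorShape(16U, 14U, 12U, 2U); // NHWC: C, W, H, N
    const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);  // 24 kernels of 1x1x16
    const auto output_shape  = misc::shape_calculator::compute_deep_convolution_shape(
        input_shape, data_layout, weights_shape, conv_info);

    TensorInfo input_info(input_shape, 1, DataType::F32, data_layout);
    TensorInfo weights_info(weights_shape, 1, DataType::F32, data_layout);
    TensorInfo output_info(output_shape, 1, DataType::F32, data_layout);

    // Elementwise addend, broadcast in the second dimension as in the tests
    TensorShape addend_shape(output_shape);
    addend_shape[1] = 1;
    TensorInfo addend_info(addend_shape, 1, DataType::F32);

    experimental::PostOpList<ITensorInfo *> post_ops{};
    post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
        &addend_info, 1 /* prev_dst_pos */, ConvertPolicy::SATURATE);

    const Status status = CLGEMMConvolutionLayer::validate(
        &input_info, &weights_info, nullptr, &output_info, conv_info,
        WeightsInfo(), Size2D(1U, 1U), ActivationLayerInfo(),
        1 /* num_groups */, post_ops);
    return bool(status); // expected: true for this configuration
}
```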
-rw-r--r--  arm_compute/core/experimental/IPostOp.h  9
-rw-r--r--  arm_compute/core/utils/misc/ShapeCalculator.h  31
-rw-r--r--  src/core/CL/CLUtils.cpp  6
-rw-r--r--  src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl  2
-rw-r--r--  src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl  8
-rw-r--r--  src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl  8
-rw-r--r--  src/gpu/cl/operators/ClGemmConv2d.cpp  4
-rw-r--r--  tests/validation/CL/ConvolutionLayer.cpp  244
-rw-r--r--  utils/TypePrinter.h  6
9 files changed, 295 insertions, 23 deletions
diff --git a/arm_compute/core/experimental/IPostOp.h b/arm_compute/core/experimental/IPostOp.h
index 178c83aa75..567a4023c0 100644
--- a/arm_compute/core/experimental/IPostOp.h
+++ b/arm_compute/core/experimental/IPostOp.h
@@ -71,6 +71,15 @@ using PostOpTypeSequence = std::vector<PostOpType>;
* * post_op_arg1 = [1, 1, 34] is allowed: broadcast in dims 0 and 1
* * post_op_arg1 = [14, 15, 34] is NOT allowed: broadcast widens the dst tensor
*
+ * @note: On data layout
+ * All post ops are data layout agnostic. This means post ops do not have an inherent notion of "width", "height" and so on.
+ * Should we want to perform a post op with 2 tensors of different data layouts (where the layouts are significant to both),
+ * we need to perform the necessary permutation beforehand to unify their data layouts before they can be fused with a post op.
+ *
+ * Note that although post ops themselves should be able to support any data layout, the main op they fuse into may impose
+ * additional restrictions in the presence of post ops. For example, the implementation of a gemm op may only allow
+ * the NHWC data layout if post ops are provided. Such restrictions are main op implementation specific.
+ *
* @note: PostOps do not own any resources pointed to by TensorRelatedT if it's a pointer type
* @note: If TensorRelatedT points to a resource, IPostOp assumes that resource is valid throughout its lifetime
* and the lifetime of its copies. This is almost guaranteed as IPostOp is only meant to be used at configure time
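To make the broadcasting rules documented above concrete, a small sketch (the dst shape here is illustrative, not taken from the header):

```cpp
#include "arm_compute/core/TensorShape.h"

using namespace arm_compute;

// dst produced by the main op (illustrative shape)
const TensorShape dst_shape(14U, 4U, 34U);

// Allowed: shapes that broadcast onto dst without widening it
const TensorShape arg_scalar(1U, 1U, 1U); // broadcast in every dim
const TensorShape arg_plane(1U, 1U, 34U); // broadcast in dims 0 and 1

// NOT allowed: any dim larger than the matching dst dim widens dst
const TensorShape arg_widen(14U, 15U, 34U); // 15 > 4 in dim 1 -> rejected
```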
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index f18f5b7a42..3e8b024f82 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -703,20 +703,18 @@ inline TensorShape compute_winograd_output_transform_shape(const ITensorInfo &in
/** Calculate the deep convolution output shape of a tensor
*
- * @param[in] input Input tensor info
- * @param[in] weights Weights tensor info
- * @param[in] conv_info Contains padding and stride information
+ * @param[in] input_shape Input tensor shape
+ * @param[in] input_data_layout Input data layout
+ * @param[in] weights_shape Weights tensor shape
+ * @param[in] conv_info Contains padding and stride information
*
* @return the calculated shape
*/
-inline TensorShape compute_deep_convolution_shape(const ITensorInfo &input, const ITensorInfo &weights, PadStrideInfo conv_info)
+inline TensorShape compute_deep_convolution_shape(const TensorShape &input_shape, DataLayout input_data_layout, const TensorShape &weights_shape, const PadStrideInfo &conv_info)
{
- const TensorShape input_shape{ input.tensor_shape() };
- const TensorShape weights_shape{ weights.tensor_shape() };
-
- const size_t idx_width = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::WIDTH);
- const size_t idx_height = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::HEIGHT);
- const size_t idx_channel = get_data_layout_dimension_index(input.data_layout(), DataLayoutDimension::CHANNEL);
+ const size_t idx_width = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::WIDTH);
+ const size_t idx_height = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::HEIGHT);
+ const size_t idx_channel = get_data_layout_dimension_index(input_data_layout, DataLayoutDimension::CHANNEL);
const unsigned int input_width = input_shape[idx_width];
const unsigned int input_height = input_shape[idx_height];
@@ -735,6 +733,19 @@ inline TensorShape compute_deep_convolution_shape(const ITensorInfo &input, cons
return output_shape;
}
+/** Calculate the deep convolution output shape of a tensor
+ *
+ * @param[in] input Input tensor info
+ * @param[in] weights Weights tensor info
+ * @param[in] conv_info Contains padding and stride information
+ *
+ * @return the calculated shape
+ */
+inline TensorShape compute_deep_convolution_shape(const ITensorInfo &input, const ITensorInfo &weights, const PadStrideInfo &conv_info)
+{
+ return compute_deep_convolution_shape(input.tensor_shape(), input.data_layout(), weights.tensor_shape(), conv_info);
+}
+
/** Calculate the min/max shape output shape of a tensor
*
* @param[in] input Input tensor info
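A usage sketch of the new shape-based overload (shapes borrowed from the new tests; the resulting shape follows NHWC ordering [C, W, H, N]):

```cpp
#include "arm_compute/core/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"

using namespace arm_compute;

const TensorShape   input_shape(16U, 14U, 12U, 2U);  // NHWC: C=16, W=14, H=12, N=2
const TensorShape   weights_shape(16U, 1U, 1U, 24U); // 24 kernels of 1x1x16
const PadStrideInfo conv_info(1, 1, 0, 0);           // stride 1, no padding

// No ITensorInfo needed: pass shapes and the data layout directly
const TensorShape out_shape = misc::shape_calculator::compute_deep_convolution_shape(
    input_shape, DataLayout::NHWC, weights_shape, conv_info);
// out_shape == [24, 14, 12, 2]
```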
diff --git a/src/core/CL/CLUtils.cpp b/src/core/CL/CLUtils.cpp
index 88b31c8349..8dab8aa876 100644
--- a/src/core/CL/CLUtils.cpp
+++ b/src/core/CL/CLUtils.cpp
@@ -96,9 +96,9 @@ bool PostOpCLKernelUtils::are_post_op_shapes_compliant(const ITensorInfo *dst, c
return false;
}
// NOTE: Kernel limitation: currently only the following broadcasting types are supported:
- // 1. Post op arg is scalar, broadcast in both X and Y
- // 2. Post op arg is of shape: Y=1, X=N, broadcast only in Y
- // This means this case: Post op arg is of shape: Y=M, X=1, broadcast only in X, is NOT supported
+ // 1. Post op arg is scalar, broadcast in both first and second dims
+ // 2. Post op arg is of shape: second dim=1, first dim=N, broadcast only in second dim
+ // This means this case: Post op arg is of shape: second dim=M, first dim=1, broadcast only in first dim, is NOT supported
if(dst->dimension(0) > 1 && dst->dimension(1) > 1 && (*tensor)->dimension(0) == 1 && (*tensor)->dimension(1) > 1)
{
return false;
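Concretely, under the dimension convention used by this check (a sketch; shapes are illustrative):

```cpp
// With dst of shape [8, 4] (dst->dimension(0) == 8, dst->dimension(1) == 4):
const TensorShape scalar_arg(1U, 1U); // supported: broadcast in both dims
const TensorShape row_arg(8U, 1U);    // supported: broadcast only in the second dim
const TensorShape col_arg(1U, 4U);    // NOT supported: broadcast only in the first dim
```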
diff --git a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl
index bbe97b2781..4665d612f5 100644
--- a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl
+++ b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_native.cl
@@ -133,7 +133,7 @@ __kernel void gemm_mm_native_post_act_eltwise_op_act(IMAGE_DECLARATION(lhs),
IMAGE_DECLARATION(bias),
#endif // defined(BETA)
IMAGE_DECLARATION(dst),
- // Post-Op arguments
+ // Post Op arguments
IMAGE_DECLARATION(eltwise_operand),
uint lhs_stride_z,
uint rhs_stride_z,
diff --git a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
index 9e9a73ccf6..32186c359b 100644
--- a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
+++ b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped.cl
@@ -233,7 +233,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t_post_act_eltwise_op_act(IMAGE_DECLAR
IMAGE_DECLARATION(bias),
#endif // defined(BETA)
IMAGE_DECLARATION(dst),
- // Post-Op arguments
+ // Post Op arguments
IMAGE_DECLARATION(eltwise_operand),
uint k,
uint lhs_stride_z,
@@ -453,7 +453,7 @@ __kernel void gemm_mm_reshaped_lhs_nt_rhs_t_texture_post_act_eltwise_op_act(IMAG
IMAGE_DECLARATION(bias),
#endif // defined(BETA)
IMAGE_DECLARATION(dst),
- // Post-Op arguments
+ // Post Op arguments
IMAGE_DECLARATION(eltwise_operand),
uint k,
uint lhs_stride_z,
@@ -781,7 +781,7 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt_post_act_eltwise_op_act(IMAGE_DECLAR
IMAGE_DECLARATION(bias),
#endif // defined(BETA)
IMAGE_DECLARATION(dst),
- // Post-Op arguments
+ // Post Op arguments
IMAGE_DECLARATION(eltwise_operand),
uint k,
uint lhs_stride_z,
@@ -1110,7 +1110,7 @@ __kernel void gemm_mm_reshaped_lhs_t_rhs_nt_texture_post_act_eltwise_op_act(IMAG
IMAGE_DECLARATION(bias),
#endif // defined(BETA)
IMAGE_DECLARATION(dst),
- // Post-Op arguments
+ // Post Op arguments
IMAGE_DECLARATION(eltwise_operand),
uint k,
uint lhs_stride_z,
diff --git a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
index fe2d103de5..e96aba613b 100644
--- a/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
+++ b/src/core/CL/cl_kernels/common/experimental/gemm_fused_post_ops/act_eltwise_op_act/gemm_mm_reshaped_only_rhs.cl
@@ -177,7 +177,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t_post_act_eltwise_op_act(IMAGE_DECLARAT
IMAGE_DECLARATION(bias),
#endif // defined(BETA)
IMAGE_DECLARATION(dst),
- // Post-Op arguments
+ // Post Op arguments
IMAGE_DECLARATION(eltwise_operand),
uint lhs_stride_z,
uint rhs_stride_z,
@@ -437,7 +437,7 @@ __kernel void gemm_mm_reshaped_only_rhs_t_texture_post_act_eltwise_op_act(IMAGE_
IMAGE_DECLARATION(bias),
#endif // defined(BETA)
IMAGE_DECLARATION(dst),
- // Post-Op arguments
+ // Post Op arguments
IMAGE_DECLARATION(eltwise_operand),
uint lhs_stride_z,
uint rhs_stride_z,
@@ -831,7 +831,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_post_act_eltwise_op_act(IMAGE_DECLARA
IMAGE_DECLARATION(bias),
#endif // defined(BETA)
IMAGE_DECLARATION(dst),
- // Post-Op arguments
+ // Post Op arguments
IMAGE_DECLARATION(eltwise_operand),
uint lhs_stride_z,
uint rhs_stride_z,
@@ -1116,7 +1116,7 @@ __kernel void gemm_mm_reshaped_only_rhs_nt_texture_post_act_eltwise_op_act(IMAGE
IMAGE_DECLARATION(bias),
#endif // defined(BETA)
IMAGE_DECLARATION(dst),
- // Post-Op arguments
+ // Post Op arguments
IMAGE_DECLARATION(eltwise_operand),
uint lhs_stride_z,
uint rhs_stride_z,
diff --git a/src/gpu/cl/operators/ClGemmConv2d.cpp b/src/gpu/cl/operators/ClGemmConv2d.cpp
index 7db5fa0052..682477e4ea 100644
--- a/src/gpu/cl/operators/ClGemmConv2d.cpp
+++ b/src/gpu/cl/operators/ClGemmConv2d.cpp
@@ -389,6 +389,9 @@ Status ClGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights
ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(idx_channel) * conv2d_info.num_groups) != src->dimension(idx_channel));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!skip_im2col
+ && conv2d_info.post_ops.size() > 0,
+ "ClGemmConv2d does not support post ops with col2im or im2col operation"); // Post ops must be performed after every other op
// Validate biases
if(biases != nullptr)
@@ -523,7 +526,6 @@ Status ClGemmConv2d::validate(const ITensorInfo *src, const ITensorInfo *weights
// Validate Col2Im
if(!skip_col2im)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv2d_info.post_ops.size() > 0, "ClGemmConv2d does not support post ops with col2im operation"); // Post ops must be performed after every other op
ARM_COMPUTE_RETURN_ON_ERROR(kernels::ClCol2ImKernel::validate(gemm_output_to_use, dst, Size2D(conv_w, conv_h), conv2d_info.num_groups));
}
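For reference, a paraphrased sketch (not the exact library code) of when im2col is skipped, and hence when the new check above lets post ops through:

```cpp
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensorInfo.h"
#include "arm_compute/core/Types.h"

using namespace arm_compute;

bool can_fuse_post_ops(const ITensorInfo &weights, const PadStrideInfo &conv_info,
                       DataLayout data_layout, const Size2D &dilation)
{
    const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const size_t idx_h = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    // conv1x1 with unit stride and dilation in NHWC: no im2col/col2im required
    return data_layout == DataLayout::NHWC
           && weights.dimension(idx_w) == 1 && weights.dimension(idx_h) == 1
           && conv_info.stride().first == 1 && conv_info.stride().second == 1
           && dilation.x() == 1 && dilation.y() == 1;
}
```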
diff --git a/tests/validation/CL/ConvolutionLayer.cpp b/tests/validation/CL/ConvolutionLayer.cpp
index ae2949c767..ff28ac0985 100644
--- a/tests/validation/CL/ConvolutionLayer.cpp
+++ b/tests/validation/CL/ConvolutionLayer.cpp
@@ -22,10 +22,12 @@
* SOFTWARE.
*/
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLTensor.h"
#include "arm_compute/runtime/CL/CLTensorAllocator.h"
#include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h"
#include "arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h"
+#include "src/core/experimental/PostOp.h"
#include "tests/CL/CLAccessor.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/LargeConvolutionLayerDataset.h"
@@ -88,6 +90,29 @@ const auto ActivationFunctionsSmallDataset = framework::dataset::make("Activatio
ActivationLayerInfo(),
ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 0.5f)
});
+
+bool is_post_op_list_valid_in_gemmconv(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &output_shape, DataType data_type, DataLayout data_layout,
+ const PadStrideInfo &conv_info, const experimental::PostOpList<ITensorInfo *> &post_ops)
+{
+ const int idx_width = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);
+
+ const auto dilation = Size2D(1U, 1U);
+ const unsigned int num_groups = 1U;
+
+ TensorInfo input_info(input_shape, 1, data_type, data_layout);
+ TensorInfo weights_info(weights_shape, 1, data_type, data_layout);
+
+ TensorInfo output_info(output_shape, 1, data_type, data_layout);
+
+ WeightsInfo w_info(false, weights_info.dimension(idx_width), weights_info.dimension(idx_height), weights_info.dimension(idx_kernels));
+
+ const auto status = CLGEMMConvolutionLayer::validate(&input_info.clone()->set_is_resizable(true),
+ &weights_info.clone()->set_is_resizable(true), nullptr, &output_info.clone()->set_is_resizable(true),
+ conv_info, w_info, dilation, ActivationLayerInfo(), num_groups, post_ops);
+ return bool(status);
+}
} // namespace
TEST_SUITE(CL)
@@ -179,6 +204,72 @@ DATA_TEST_CASE(ValidateConvolutionMethod, framework::DatasetMode::ALL, zip(zip(z
enable_fast_math);
ARM_COMPUTE_EXPECT(is_valid == expected, framework::LogLevel::ERRORS);
}
+
+DATA_TEST_CASE(ValidatePostOpSupportInConvolutionMethod, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputInfo", { TensorInfo(TensorShape(2U, 17U, 31U), 1, DataType::F32, DataLayout::NHWC), // Select GEMM
+ TensorInfo(TensorShape(17U, 31U, 32U), 1, DataType::F32, DataLayout::NCHW), // Select WINOGRAD
+ TensorInfo(TensorShape(27U, 27U, 48U), 1, DataType::F32, DataLayout::NCHW), // Select Direct
+ TensorInfo(TensorShape(27U, 27U, 48U), 1, DataType::F32, DataLayout::NCHW), // Select FFT
+ }),
+ framework::dataset::make("WeightsInfo", { TensorInfo(TensorShape(2U, 1U, 1U, 19U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(5U, 5U, 32U, 19U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(5U, 5U, 48U, 128U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(11U, 11U, 48U, 24), 1, DataType::F32, DataLayout::NCHW),
+ })),
+ framework::dataset::make("OutputInfo", { TensorInfo(TensorShape(19U, 17U, 31U), 1, DataType::F32, DataLayout::NHWC),
+ TensorInfo(TensorShape(17U, 31U, 19U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(27U, 27U, 128U), 1, DataType::F32, DataLayout::NCHW),
+ TensorInfo(TensorShape(27U, 27U, 24U), 1, DataType::F32, DataLayout::NCHW),
+ })),
+ framework::dataset::make("ConvInfo", { PadStrideInfo(1U, 1U, 0U, 0U),
+ PadStrideInfo(1U, 1U, 2U, 2U),
+ PadStrideInfo(1U, 1U, 2U, 2U),
+ PadStrideInfo(1U, 1U, 5U, 5U),
+ })),
+ framework::dataset::make("EnableFastMath", { false, true, false, false})),
+ framework::dataset::make("ExpectedMethod",{ ConvolutionMethod::GEMM,
+ ConvolutionMethod::WINOGRAD,
+ ConvolutionMethod::DIRECT,
+ ConvolutionMethod::FFT,
+ })),
+ framework::dataset::make("PostOpSupported",{ true, false, false, false
+ })),
+ input_info, weights_info, output_info, conv_info, enable_fast_math, expected_method, post_op_supported)
+{
+ const int idx_width = get_data_layout_dimension_index(input_info.data_layout(), DataLayoutDimension::WIDTH);
+ const int idx_height = get_data_layout_dimension_index(input_info.data_layout(), DataLayoutDimension::HEIGHT);
+ const int idx_kernels = get_data_layout_dimension_index(input_info.data_layout(), DataLayoutDimension::BATCHES);
+
+ const auto dilation = Size2D(1U, 1U);
+ const unsigned int num_groups = 1U;
+
+ WeightsInfo w_info(false, weights_info.dimension(idx_width), weights_info.dimension(idx_height), weights_info.dimension(idx_kernels));
+
+ experimental::PostOpList<ITensorInfo*> post_ops{};
+ post_ops.push_back_op<experimental::PostOpAct<ITensorInfo*>>(ActivationLayerInfo{ActivationLayerInfo::ActivationFunction::LINEAR, 0.5F, 0.0F});
+
+ ConvolutionMethod actual_method = CLConvolutionLayer::get_convolution_method(&input_info.clone()->set_is_resizable(true),
+ &weights_info.clone()->set_is_resizable(true),
+ &output_info.clone()->set_is_resizable(true), conv_info,
+ WeightsInfo(),
+ ActivationLayerInfo(),
+ GPUTarget::BIFROST,
+ dilation,
+ enable_fast_math);
+ ARM_COMPUTE_EXPECT(actual_method == expected_method, framework::LogLevel::ERRORS);
+ const auto is_valid = CLConvolutionLayer::validate(&input_info.clone()->set_is_resizable(true),
+ &weights_info.clone()->set_is_resizable(true),
+ nullptr,
+ &output_info.clone()->set_is_resizable(true),
+ conv_info,
+ w_info,
+ dilation,
+ ActivationLayerInfo(),
+ enable_fast_math,
+ num_groups,
+ post_ops);
+ ARM_COMPUTE_EXPECT(bool(is_valid) == post_op_supported, framework::LogLevel::ERRORS);
+}
// clang-format on
// *INDENT-ON*
TEST_SUITE_END() // ConvolutionLayer
@@ -191,6 +282,159 @@ using CLGEMMConvolutionLayerMixedDataLayoutFixture = ConvolutionValidationFixtur
template <typename T>
using CLConvolutionValidationWithPaddingFixture = ConvolutionValidationWithPaddingFixture<CLTensor, CLAccessor, CLGEMMConvolutionLayer, T>;
+TEST_SUITE(ValidateFusedPostOpsConfigs)
+TEST_SUITE(Invalid)
+TEST_CASE(UnsupportedPostOpSequence, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const auto data_layout = DataLayout::NHWC;
+ const auto conv_info = PadStrideInfo(1, 1, 0, 0);
+ const auto input_shape = TensorShape(16U, 14U, 12U, 2U);
+ const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+ const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+ const TensorShape post_op_arg0_shape(output_shape);
+ TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
+ auto post_op_arg1_info = post_op_arg_info.clone();
+
+ // Unsupported sequence of post ops
+ experimental::PostOpList<ITensorInfo *> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+ &post_op_arg_info,
+ 1,
+ ConvertPolicy::SATURATE);
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+ post_op_arg1_info.get(),
+ 0,
+ ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OnlyNHWCIsSupported, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const auto data_layout = DataLayout::NCHW;
+ const auto conv_info = PadStrideInfo(1, 1, 0, 0);
+ const auto input_shape = TensorShape(14U, 12U, 16U, 2U);
+ const auto weights_shape = TensorShape(1U, 1U, 16U, 24U);
+
+ const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+ const TensorShape post_op_arg0_shape(output_shape);
+ TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+ experimental::PostOpList<ITensorInfo *> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+ &post_op_arg_info,
+ 1,
+ ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OnlyFloatingTypeIsSupported, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::QASYMM8;
+ const auto data_layout = DataLayout::NHWC;
+ const auto conv_info = PadStrideInfo(1, 1, 0, 0);
+ const auto input_shape = TensorShape(16U, 14U, 12U, 2U);
+ const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+ const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+ const TensorShape post_op_arg0_shape(output_shape);
+ TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+ experimental::PostOpList<ITensorInfo *> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+ &post_op_arg_info,
+ 1,
+ ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OnlyConv1x1Stride1IsSupported_UnsupportedKernelSize, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const auto data_layout = DataLayout::NHWC;
+ const auto conv_info = PadStrideInfo(1, 1, 0, 0);
+ const auto input_shape = TensorShape(16U, 14U, 12U, 2U);
+ const auto weights_shape = TensorShape(16U, 3U, 3U, 24U);
+
+ const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+ const TensorShape post_op_arg0_shape(output_shape);
+ TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+ experimental::PostOpList<ITensorInfo *> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+ &post_op_arg_info,
+ 1,
+ ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_CASE(OnlyConv1x1Stride1IsSupported_UnsupportedStride, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const auto data_layout = DataLayout::NHWC;
+ const auto conv_info = PadStrideInfo(3, 3, 0, 0);
+ const auto input_shape = TensorShape(16U, 14U, 12U, 2U);
+ const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+ const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+ const TensorShape post_op_arg0_shape(output_shape);
+ TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+ experimental::PostOpList<ITensorInfo *> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+ &post_op_arg_info,
+ 1,
+ ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == false, framework::LogLevel::ERRORS);
+}
+TEST_SUITE_END() // Invalid
+TEST_SUITE(Valid)
+TEST_CASE(EmptyPostOpList, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const auto data_layout = DataLayout::NHWC;
+ const auto conv_info = PadStrideInfo(1, 1, 0, 0);
+ const auto input_shape = TensorShape(16U, 14U, 12U, 2U);
+ const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+ const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+ experimental::PostOpList<ITensorInfo *> post_ops{};
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_CASE(SupportedPostOps, framework::DatasetMode::ALL)
+{
+ const auto data_type = DataType::F32;
+ const auto data_layout = DataLayout::NHWC;
+ const auto conv_info = PadStrideInfo(1, 1, 0, 0);
+ const auto input_shape = TensorShape(16U, 14U, 12U, 2U);
+ const auto weights_shape = TensorShape(16U, 1U, 1U, 24U);
+
+ const auto output_shape = misc::shape_calculator::compute_deep_convolution_shape(input_shape, data_layout, weights_shape, conv_info);
+
+ TensorShape post_op_arg0_shape(output_shape);
+ post_op_arg0_shape[1] = 1; // Broadcast in "Y" (second) dimension
+ TensorInfo post_op_arg_info(post_op_arg0_shape, 1, data_type);
+
+ experimental::PostOpList<ITensorInfo *> post_ops{};
+ post_ops.push_back_op<experimental::PostOpEltwiseAdd<ITensorInfo *>>(
+ &post_op_arg_info,
+ 1,
+ ConvertPolicy::SATURATE);
+
+ ARM_COMPUTE_EXPECT(is_post_op_list_valid_in_gemmconv(input_shape, weights_shape, output_shape, data_type, data_layout, conv_info, post_ops) == true, framework::LogLevel::ERRORS);
+}
+TEST_SUITE_END() // Valid
+TEST_SUITE_END() // ValidateFusedPostOpsConfigs
TEST_SUITE(Float)
TEST_SUITE(FP16)
diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h
index 785b41fc62..9858478c29 100644
--- a/utils/TypePrinter.h
+++ b/utils/TypePrinter.h
@@ -2180,6 +2180,12 @@ inline ::std::ostream &operator<<(::std::ostream &os, const ConvolutionMethod &c
case ConvolutionMethod::WINOGRAD:
os << "WINOGRAD";
break;
+ case ConvolutionMethod::FFT:
+ os << "FFT";
+ break;
+ case ConvolutionMethod::GEMM_CONV2D:
+ os << "GEMM_CONV2D";
+ break;
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}
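Usage sketch for the extended printer:

```cpp
#include <iostream>
#include "arm_compute/core/Types.h"
#include "utils/TypePrinter.h"

int main()
{
    using arm_compute::ConvolutionMethod;
    std::cout << ConvolutionMethod::FFT << "\n";         // prints "FFT"
    std::cout << ConvolutionMethod::GEMM_CONV2D << "\n"; // prints "GEMM_CONV2D"
}
```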