aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorManuel Bottini <manuel.bottini@arm.com>2020-09-15 13:03:34 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-10-14 09:49:04 +0000
commit87350f47084d2b69daa11c3b1c7eb47e02260063 (patch)
tree9324b91dd5b154209c3af24ceec02286537ddf36 /tests
parentcbede286da8711031fb6fc56bb2e2c246a4c5455 (diff)
downloadComputeLibrary-87350f47084d2b69daa11c3b1c7eb47e02260063.tar.gz
COMPMID-3144: Remove padding from NEDirectConvolutionLayerKernel
Change-Id: I22b907eebfbe037e6e1c7bf604172f4709a9cbed Signed-off-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4082 Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/NEON/DirectConvolutionLayer.cpp52
-rw-r--r--tests/validation/fixtures/DirectConvolutionLayerFixture.h3
2 files changed, 45 insertions, 10 deletions
diff --git a/tests/validation/NEON/DirectConvolutionLayer.cpp b/tests/validation/NEON/DirectConvolutionLayer.cpp
index 7277592736..afd9e3952f 100644
--- a/tests/validation/NEON/DirectConvolutionLayer.cpp
+++ b/tests/validation/NEON/DirectConvolutionLayer.cpp
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
#include "arm_compute/runtime/Tensor.h"
@@ -78,12 +79,12 @@ const auto data_f16 = combine(datasets::SmallDirectConvolutionShapes(),
combine(framework::dataset::make("StrideY", { 1, 2, 3 }),
data_pad_f16)));
-const auto data = combine(datasets::SmallDirectConvolutionShapes(),
- combine(framework::dataset::make("StrideX", { 1 }),
- combine(framework::dataset::make("StrideY", { 1 }),
- combine(framework::dataset::make("PadX", { 1 }),
- combine(framework::dataset::make("PadY", { 1 }),
- framework::dataset::make("KernelSize", 3))))));
+const auto data_prec = combine(datasets::SmallDirectConvolutionShapes(),
+ combine(framework::dataset::make("StrideX", { 1 }),
+ combine(framework::dataset::make("StrideY", { 1 }),
+ combine(framework::dataset::make("PadX", { 1 }),
+ combine(framework::dataset::make("PadY", { 1 }),
+ framework::dataset::make("KernelSize", 3))))));
const auto data9x9 = combine(datasets::SmallDirectConvolutionShapes(),
combine(framework::dataset::make("StrideX", { 1 }),
@@ -95,7 +96,7 @@ const auto data9x9 = combine(datasets::SmallDirectConvolutionShapes(),
const auto data_f32_nightly = combine(data_f32, framework::dataset::make("NumKernels", { 1, 4 }));
const auto data_f16_nightly = combine(data_f16, framework::dataset::make("NumKernels", { 1, 4 }));
-const auto data_precommit = combine(data, framework::dataset::make("NumKernels", { 1 }));
+const auto data_precommit = combine(data_prec, framework::dataset::make("NumKernels", { 1 }));
const auto data_precommit9x9 = combine(data9x9, framework::dataset::make("NumKernels", { 4 }));
/* The following tests is from real use-case that made DirectConvolution
@@ -195,6 +196,43 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(
// clang-format on
// *INDENT-ON*
+DATA_TEST_CASE(NoPaddingNHWCKernel, framework::DatasetMode::ALL, combine(combine(combine(data_precommit,
+ framework::dataset::make("DataType", DataType::F32)),
+ ActivationFunctionsDataset),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+
+ shape, stride_x, stride_y, pad_x, pad_y, kernel_size, num_kernels, data_type, act_info, data_layout)
+{
+ TensorShape input_shape = TensorShape(shape);
+ TensorShape weights_shape(kernel_size, kernel_size, input_shape.z(), num_kernels);
+ const PadStrideInfo info(stride_x, stride_y, pad_x, pad_y, DimensionRoundingType::FLOOR);
+
+ TensorInfo input_info = TensorInfo(input_shape, 1, data_type);
+ TensorInfo weights_info = TensorInfo(weights_shape, 1, data_type);
+
+ TensorShape output_shape = compute_deep_convolution_shape(input_info, weights_info, info);
+
+ if(data_layout == DataLayout::NHWC)
+ {
+ permute(input_shape, PermutationVector(2U, 0U, 1U));
+ permute(weights_shape, PermutationVector(2U, 0U, 1U));
+ permute(output_shape, PermutationVector(2U, 0U, 1U));
+ }
+
+ // Create tensors
+ Tensor src = create_tensor<Tensor>(input_shape, data_type, 1, QuantizationInfo(), data_layout);
+ Tensor weights = create_tensor<Tensor>(weights_shape, data_type, 1, QuantizationInfo(), data_layout);
+ Tensor dst = create_tensor<Tensor>(output_shape, data_type, 1, QuantizationInfo(), data_layout);
+
+ // Create and configure function
+ NEDirectConvolutionLayer conv;
+ conv.configure(&src, &weights, nullptr, &dst, info, act_info);
+
+ validate(src.info()->padding(), PaddingSize(0, 0, 0, 0));
+ validate(weights.info()->padding(), PaddingSize(0, 0, 0, 0));
+ validate(dst.info()->padding(), PaddingSize(0, 0, 0, 0));
+}
+
template <typename T>
using NEDirectConvolutionLayerFixture = DirectConvolutionValidationFixture<Tensor, Accessor, NEDirectConvolutionLayer, T>;
diff --git a/tests/validation/fixtures/DirectConvolutionLayerFixture.h b/tests/validation/fixtures/DirectConvolutionLayerFixture.h
index 3da5158e97..e37063e2e5 100644
--- a/tests/validation/fixtures/DirectConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DirectConvolutionLayerFixture.h
@@ -51,13 +51,10 @@ class DirectConvolutionValidationGenericFixture : public framework::Fixture
public:
using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
-public:
template <typename...>
void setup(TensorShape input_shape, int stride_x, int stride_y, int pad_x, int pad_y, unsigned int kernel_size, unsigned int num_kernels,
DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo act_info, DataLayout data_layout)
{
- ARM_COMPUTE_ERROR_ON(data_layout == DataLayout::UNKNOWN);
-
_quantization_info = quantization_info;
_data_type = data_type;