From 05fb448bf48e31d723dfd9f4bbf3899ff65f0fba Mon Sep 17 00:00:00 2001 From: giuros01 Date: Tue, 26 Mar 2019 17:44:40 +0000 Subject: COMPMID-1963: Implement FFT (2D) on NEON Change-Id: I3b564be8d7949e00c6544071ef62dd51de838c96 Signed-off-by: giuros01 Reviewed-on: https://review.mlplatform.org/c/1048 Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas --- tests/validation/NEON/FFT.cpp | 74 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/validation/NEON/FFT.cpp b/tests/validation/NEON/FFT.cpp index 14fb7e2518..598c8bb10d 100644 --- a/tests/validation/NEON/FFT.cpp +++ b/tests/validation/NEON/FFT.cpp @@ -23,6 +23,7 @@ */ #include "arm_compute/core/Types.h" #include "arm_compute/runtime/NEON/functions/NEFFT1D.h" +#include "arm_compute/runtime/NEON/functions/NEFFT2D.h" #include "arm_compute/runtime/Tensor.h" #include "tests/NEON/Accessor.h" #include "tests/framework/Asserts.h" @@ -49,6 +50,13 @@ const auto shapes_1d = framework::dataset::make("TensorShape", { TensorShape(2U TensorShape(96U, 2U, 2U) }); +const auto shapes_2d = framework::dataset::make("TensorShape", { TensorShape(2U, 2U, 3U), TensorShape(3U, 6U, 3U), + TensorShape(4U, 5U, 3U), TensorShape(5U, 7U, 3U), + TensorShape(7U, 25U, 3U), TensorShape(8U, 2U, 3U), + TensorShape(9U, 16U, 3U), TensorShape(25U, 32U, 3U), + TensorShape(192U, 128U, 2U) + }); + const auto ActivationFunctionsSmallDataset = framework::dataset::make("ActivationInfo", { ActivationLayerInfo(), @@ -127,8 +135,72 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEFFT1DFixture, framework::DatasetMode:: } TEST_SUITE_END() // FP32 TEST_SUITE_END() // Float - TEST_SUITE_END() // FFT1D + +TEST_SUITE(FFT2D) + +DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(shapes_2d, data_types), + shape, data_type) +{ + // Create tensors + Tensor src = create_tensor(shape, data_type, 2); + Tensor dst = create_tensor(shape, data_type, 2); + + ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); + + // Create and configure function + NEFFT2D fft2d; + fft2d.configure(&src, &dst, FFT2DInfo()); + + // Validate valid region + const ValidRegion valid_region = shape_to_valid_region(shape); + validate(src.info()->valid_region(), valid_region); + validate(dst.info()->valid_region(), valid_region); + + // Validate padding + validate(src.info()->padding(), PaddingSize()); + validate(dst.info()->padding(), PaddingSize()); +} + +// *INDENT-OFF* +// clang-format off +DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip( + framework::dataset::make("InputInfo", { TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F32), // Mismatching data types + TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F32), // Mismatching shapes + TensorInfo(TensorShape(32U, 25U, 2U), 3, DataType::F32), // Invalid channels + TensorInfo(TensorShape(32U, 13U, 2U), 2, DataType::F32), // Undecomposable FFT + TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F32), + }), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F16), + TensorInfo(TensorShape(16U, 25U, 2U), 2, DataType::F32), + TensorInfo(TensorShape(32U, 25U, 2U), 1, DataType::F32), + TensorInfo(TensorShape(32U, 13U, 2U), 2, DataType::F32), + TensorInfo(TensorShape(32U, 25U, 2U), 2, DataType::F32), + })), + framework::dataset::make("Expected", { false, false, false, false, true })), + input_info, output_info, expected) +{ + const Status s = NEFFT2D::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), FFT2DInfo()); + ARM_COMPUTE_EXPECT(bool(s) == expected, framework::LogLevel::ERRORS); +} +// clang-format on +// *INDENT-ON* + +template +using NEFFT2DFixture = FFTValidationFixture; + +TEST_SUITE(Float) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, NEFFT2DFixture, framework::DatasetMode::ALL, combine(shapes_2d, framework::dataset::make("DataType", DataType::F32))) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_f32, tolerance_num); +} +TEST_SUITE_END() // FP32 +TEST_SUITE_END() // Float +TEST_SUITE_END() // FFT2D + TEST_SUITE_END() // NEON } // namespace validation } // namespace test -- cgit v1.2.1