aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/Utils.h17
-rw-r--r--src/core/CPP/kernels/CPPPermuteKernel.cpp11
-rw-r--r--src/core/NEON/kernels/NEPermuteKernel.cpp140
-rw-r--r--tests/validation/CL/Permute.cpp14
-rw-r--r--tests/validation/CPP/Permute.cpp39
-rw-r--r--tests/validation/NEON/Permute.cpp61
-rw-r--r--tests/validation/fixtures/PermuteFixture.h18
7 files changed, 224 insertions, 76 deletions
diff --git a/arm_compute/core/Utils.h b/arm_compute/core/Utils.h
index 7c5d87f475..696845d3ff 100644
--- a/arm_compute/core/Utils.h
+++ b/arm_compute/core/Utils.h
@@ -805,6 +805,23 @@ inline DataType data_type_for_convolution_matrix(const int16_t *conv, size_t siz
}
}
+/** Permutes the given dimensions according the permutation vector
+ *
+ * @param[in,out] dimensions Dimensions to be permuted.
+ * @param[in] perm Vector describing the permutation.
+ *
+ */
+template <typename T>
+inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &perm)
+{
+ const auto old_dim = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
+ for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
+ {
+ T dimension_val = old_dim[i];
+ dimensions.set(perm[i], dimension_val);
+ }
+}
+
/** Calculate padding requirements in case of SAME padding
*
* @param[in] input_shape Input shape
diff --git a/src/core/CPP/kernels/CPPPermuteKernel.cpp b/src/core/CPP/kernels/CPPPermuteKernel.cpp
index 17eaec2670..d9fe5b0c0a 100644
--- a/src/core/CPP/kernels/CPPPermuteKernel.cpp
+++ b/src/core/CPP/kernels/CPPPermuteKernel.cpp
@@ -58,17 +58,6 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
return Status{};
}
-template <typename T>
-inline void permute_strides(Dimensions<T> &dimensions, const PermutationVector &perm)
-{
- const auto old_dim = utility::make_array<Dimensions<T>::num_max_dimensions>(dimensions.begin(), dimensions.end());
- for(unsigned int i = 0; i < perm.num_dimensions(); ++i)
- {
- T dimension_val = old_dim[i];
- dimensions.set(perm[i], dimension_val);
- }
-}
-
} // namespace
template <typename T>
diff --git a/src/core/NEON/kernels/NEPermuteKernel.cpp b/src/core/NEON/kernels/NEPermuteKernel.cpp
index 29e6d501a6..5a2f258d4e 100644
--- a/src/core/NEON/kernels/NEPermuteKernel.cpp
+++ b/src/core/NEON/kernels/NEPermuteKernel.cpp
@@ -43,6 +43,48 @@ using namespace arm_compute;
namespace
{
+inline bool is_permutation_supported(const PermutationVector &v)
+{
+ static const std::array<PermutationVector, 6> permutations3 =
+ {
+ PermutationVector(2U, 0U, 1U),
+ PermutationVector(1U, 2U, 0U),
+ PermutationVector(0U, 1U, 2U),
+ PermutationVector(0U, 2U, 1U),
+ PermutationVector(1U, 0U, 2U),
+ PermutationVector(2U, 1U, 0U),
+ };
+ static const std::array<PermutationVector, 24> permutations4 =
+ {
+ PermutationVector(0U, 1U, 2U, 3U),
+ PermutationVector(1U, 0U, 2U, 3U),
+ PermutationVector(2U, 0U, 1U, 3U),
+ PermutationVector(0U, 2U, 1U, 3U),
+ PermutationVector(1U, 2U, 0U, 3U),
+ PermutationVector(2U, 1U, 0U, 3U),
+ PermutationVector(2U, 1U, 3U, 0U),
+ PermutationVector(1U, 2U, 3U, 0U),
+ PermutationVector(3U, 2U, 1U, 0U),
+ PermutationVector(2U, 3U, 1U, 0U),
+ PermutationVector(1U, 3U, 2U, 0U),
+ PermutationVector(3U, 1U, 2U, 0U),
+ PermutationVector(3U, 0U, 2U, 1U),
+ PermutationVector(0U, 3U, 2U, 1U),
+ PermutationVector(2U, 3U, 0U, 1U),
+ PermutationVector(3U, 2U, 0U, 1U),
+ PermutationVector(0U, 2U, 3U, 1U),
+ PermutationVector(2U, 0U, 3U, 1U),
+ PermutationVector(1U, 0U, 3U, 2U),
+ PermutationVector(0U, 1U, 3U, 2U),
+ PermutationVector(3U, 1U, 0U, 2U),
+ PermutationVector(1U, 3U, 0U, 2U),
+ PermutationVector(0U, 3U, 1U, 2U),
+ PermutationVector(3U, 0U, 1U, 2U)
+ };
+
+ return (permutations3.end() != std::find(permutations3.begin(), permutations3.end(), v)) || (permutations4.end() != std::find(permutations4.begin(), permutations4.end(), v));
+}
+
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
{
//Note: ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input) is not needed here as this kernel doesn't use NEON FP16 instructions.
@@ -50,9 +92,8 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
DataType::U16, DataType::S16,
DataType::U32, DataType::S32,
DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((perm != PermutationVector{ 2U, 0U, 1U })
- && (perm != PermutationVector{ 1U, 2U, 0U }),
- "Only [2, 0, 1] and [1, 2, 0] permutation is supported");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!is_permutation_supported(perm), "PermutationVector not supported.");
const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input, perm);
@@ -70,12 +111,20 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
template <typename T>
void NEPermuteKernel::run_permute(const Window &window)
{
+ const DataLayout input_layout = _input->info()->data_layout();
+
// Input window
Window window_in = window;
- window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), window.x().end() - window.x().start()));
- window_in.set(Window::DimY, Window::Dimension(window.y().start(), window.y().end(), window.y().end() - window.y().start()));
- window_in.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), window.z().end() - window.z().start()));
- window_in.set(3, Window::Dimension(window[3].start(), window[3].end(), window[3].end() - window[3].start()));
+
+ // we only support these two configs in arm_compute/core/NEON/kernels/convolution/common/shims.hpp, for all others
+ // we have to fall back to C++
+ if((input_layout == DataLayout::NCHW && _perm == PermutationVector{ 2U, 0U, 1U }) || (input_layout == DataLayout::NHWC && _perm == PermutationVector{ 1U, 2U, 0U }))
+ {
+ window_in.set(Window::DimX, Window::Dimension(window.x().start(), window.x().end(), window.x().end() - window.x().start()));
+ window_in.set(Window::DimY, Window::Dimension(window.y().start(), window.y().end(), window.y().end() - window.y().start()));
+ window_in.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), window.z().end() - window.z().start()));
+ window_in.set(3, Window::Dimension(window[3].start(), window[3].end(), window[3].end() - window[3].start()));
+ }
// Output window
Window window_out(window);
@@ -89,23 +138,53 @@ void NEPermuteKernel::run_permute(const Window &window)
Iterator in(_input, window_in);
Iterator out(_output, window_out);
- // CHW -> HWC
- if(_perm == PermutationVector{ 2U, 0U, 1U })
+ int in_row_stride = 0;
+ int in_col_stride = 0;
+ int in_channel_stride = 0;
+ int in_batch_stride = 0;
+ int n_cols = 0;
+ int n_rows = 0;
+ int n_channels = 0;
+ int n_batches = 0;
+
+ switch(input_layout)
{
- const int in_row_stride = _input->info()->strides_in_bytes().y() / sizeof(T);
- const int in_channel_stride = _input->info()->strides_in_bytes().z() / sizeof(T);
- const int in_batch_stride = _input->info()->strides_in_bytes()[3] / sizeof(T);
+ case DataLayout::NCHW:
+ {
+ in_row_stride = _input->info()->strides_in_bytes().y() / sizeof(T);
+ in_channel_stride = _input->info()->strides_in_bytes().z() / sizeof(T);
+ in_batch_stride = _input->info()->strides_in_bytes()[3] / sizeof(T);
+ n_cols = _input->info()->tensor_shape().x();
+ n_rows = window_in.y().step();
+ n_channels = _input->info()->tensor_shape().z();
+ n_batches = _input->info()->tensor_shape()[3];
+ break;
+ }
+ case DataLayout::NHWC:
+ {
+ in_col_stride = _input->info()->strides_in_bytes().y() / sizeof(T);
+ in_row_stride = _input->info()->strides_in_bytes().z() / sizeof(T);
+ in_batch_stride = _input->info()->strides_in_bytes()[3] / sizeof(T);
+ n_channels = _input->info()->tensor_shape().x();
+ n_cols = window_in.y().step();
+ n_rows = _input->info()->tensor_shape().z();
+ n_batches = _input->info()->tensor_shape()[3];
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Invalid input data layout.");
+ break;
+ }
+ }
+ // CHW -> HWC
+ if(input_layout == DataLayout::NCHW && _perm == PermutationVector{ 2U, 0U, 1U })
+ {
const int out_channel_stride = _output->info()->strides_in_bytes().x() / sizeof(T);
const int out_col_stride = _output->info()->strides_in_bytes().y() / sizeof(T);
const int out_row_stride = _output->info()->strides_in_bytes().z() / sizeof(T);
const int out_batch_stride = _output->info()->strides_in_bytes()[3] / sizeof(T);
-
- const int n_cols = _input->info()->tensor_shape().x();
- const int n_rows = window_in.y().step();
- const int n_channels = _input->info()->tensor_shape().z();
- const int n_batches = _input->info()->tensor_shape()[3];
-
execute_window_loop(window_in, [&](const Coordinates & id)
{
const int idx = id[0] * out_col_stride + id[1] * out_row_stride + id[2] * out_channel_stride;
@@ -117,22 +196,12 @@ void NEPermuteKernel::run_permute(const Window &window)
in, out);
}
// HWC -> CHW
- else if(_perm == PermutationVector{ 1U, 2U, 0U })
+ else if(input_layout == DataLayout::NHWC && _perm == PermutationVector{ 1U, 2U, 0U })
{
- const int in_col_stride = _input->info()->strides_in_bytes().y() / sizeof(T);
- const int in_row_stride = _input->info()->strides_in_bytes().z() / sizeof(T);
- const int in_batch_stride = _input->info()->strides_in_bytes()[3] / sizeof(T);
-
const int out_col_stride = _output->info()->strides_in_bytes().x() / sizeof(T);
const int out_row_stride = _output->info()->strides_in_bytes().y() / sizeof(T);
const int out_channel_stride = _output->info()->strides_in_bytes().z() / sizeof(T);
const int out_batch_stride = _output->info()->strides_in_bytes()[3] / sizeof(T);
-
- const int n_channels = _input->info()->tensor_shape().x();
- const int n_cols = window_in.y().step();
- const int n_rows = _input->info()->tensor_shape().z();
- const int n_batches = _input->info()->tensor_shape()[3];
-
execute_window_loop(window_in, [&](const Coordinates & id)
{
const int idx = id[0] * out_channel_stride + id[1] * out_col_stride + id[2] * out_row_stride;
@@ -145,7 +214,18 @@ void NEPermuteKernel::run_permute(const Window &window)
}
else
{
- ARM_COMPUTE_ERROR("Unsupported permutation vector");
+ // All other cases fall back to C++
+ // Permute strides
+ Strides strides = _output->info()->strides_in_bytes();
+ Strides perm_strides = strides;
+ permute_strides(perm_strides, _perm);
+ const int perm_stride_3 = _input->info()->num_dimensions() >= 4 ? perm_strides[3] : 0;
+ execute_window_loop(window, [&](const Coordinates & id)
+ {
+ const int idx = id[0] * perm_strides[0] + id[1] * perm_strides[1] + id[2] * perm_strides[2] + id[3] * perm_stride_3;
+ *(reinterpret_cast<T *>(out.ptr() + idx)) = *(reinterpret_cast<const T *>(in.ptr()));
+ },
+ in, out);
}
}
diff --git a/tests/validation/CL/Permute.cpp b/tests/validation/CL/Permute.cpp
index 1371e717e7..a75b8cf9cd 100644
--- a/tests/validation/CL/Permute.cpp
+++ b/tests/validation/CL/Permute.cpp
@@ -42,10 +42,16 @@ namespace validation
{
namespace
{
-const auto PermuteParametersSmall = combine(datasets::Small4DShapes(),
- framework::dataset::make("PermutationVector", { PermutationVector(2U, 0U, 1U), PermutationVector(1U, 2U, 0U), PermutationVector(3U, 2U, 0U, 1U) }));
-const auto PermuteParametersLarge = combine(datasets::Large4DShapes(),
- framework::dataset::make("PermutationVector", { PermutationVector(2U, 0U, 1U), PermutationVector(1U, 2U, 0U), PermutationVector(3U, 2U, 0U, 1U) }));
+const auto PermuteVectors = framework::dataset::make("PermutationVector",
+{
+ PermutationVector(2U, 0U, 1U),
+ PermutationVector(1U, 2U, 0U),
+ PermutationVector(3U, 2U, 0U, 1U)
+});
+const auto PermuteInputLayout = framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC });
+const auto PermuteParametersSmall = concat(concat(datasets::Small2DShapes(), datasets::Small3DShapes()), datasets::Small4DShapes()) * PermuteInputLayout * PermuteVectors;
+const auto PermuteParametersLarge = datasets::Large4DShapes() * PermuteInputLayout * PermuteVectors;
+
} // namespace
TEST_SUITE(CL)
TEST_SUITE(Permute)
diff --git a/tests/validation/CPP/Permute.cpp b/tests/validation/CPP/Permute.cpp
index 0a97041f81..2ba10ec651 100644
--- a/tests/validation/CPP/Permute.cpp
+++ b/tests/validation/CPP/Permute.cpp
@@ -42,10 +42,19 @@ namespace validation
{
namespace
{
-const auto PermuteParametersSmall = combine(concat(concat(datasets::Small2DShapes(), datasets::Small3DShapes()), datasets::Small4DShapes()),
- framework::dataset::make("PermutationVector", { PermutationVector(2U, 0U, 1U), PermutationVector(1U, 2U, 0U), PermutationVector(3U, 2U, 0U, 1U) }));
-const auto PermuteParametersLarge = combine(datasets::Large4DShapes(),
- framework::dataset::make("PermutationVector", { PermutationVector(2U, 0U, 1U), PermutationVector(1U, 2U, 0U), PermutationVector(3U, 2U, 0U, 1U) }));
+const auto PermuteVectors = framework::dataset::make("PermutationVector",
+{
+ PermutationVector(2U, 0U, 1U),
+ PermutationVector(1U, 2U, 0U),
+ PermutationVector(0U, 1U, 2U),
+ PermutationVector(0U, 2U, 1U),
+ PermutationVector(1U, 0U, 2U),
+ PermutationVector(2U, 1U, 0U),
+});
+const auto PermuteInputLayout = framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC });
+const auto PermuteParametersSmall = concat(concat(datasets::Small2DShapes(), datasets::Small3DShapes()), datasets::Small4DShapes()) * PermuteInputLayout * PermuteVectors;
+const auto PermuteParametersLarge = datasets::Large4DShapes() * PermuteInputLayout * PermuteVectors;
+
} // namespace
TEST_SUITE(CPP)
TEST_SUITE(Permute)
@@ -77,25 +86,32 @@ template <typename T>
using CPPPermuteFixture = PermuteValidationFixture<Tensor, Accessor, CPPPermute, T>;
TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CPPPermuteFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(PermuteParametersSmall, framework::dataset::make("DataType", DataType::U8)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CPPPermuteFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ PermuteParametersSmall * framework::dataset::make("DataType", DataType::U8))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CPPPermuteFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(PermuteParametersLarge, framework::dataset::make("DataType", DataType::U8)))
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CPPPermuteFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
+ PermuteParametersLarge * framework::dataset::make("DataType", DataType::U8))
{
// Validate output
validate(Accessor(_target), _reference);
}
+
TEST_SUITE_END()
TEST_SUITE(U16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CPPPermuteFixture<uint16_t>, framework::DatasetMode::PRECOMMIT, combine(PermuteParametersSmall, framework::dataset::make("DataType", DataType::U16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CPPPermuteFixture<uint16_t>, framework::DatasetMode::PRECOMMIT,
+ PermuteParametersSmall * framework::dataset::make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CPPPermuteFixture<uint16_t>, framework::DatasetMode::NIGHTLY, combine(PermuteParametersLarge, framework::dataset::make("DataType", DataType::U16)))
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CPPPermuteFixture<uint16_t>, framework::DatasetMode::NIGHTLY,
+ PermuteParametersLarge * framework::dataset::make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -103,12 +119,15 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CPPPermuteFixture<uint16_t>, framework::Dataset
TEST_SUITE_END()
TEST_SUITE(U32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CPPPermuteFixture<uint32_t>, framework::DatasetMode::PRECOMMIT, combine(PermuteParametersSmall, framework::dataset::make("DataType", DataType::U32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, CPPPermuteFixture<uint32_t>, framework::DatasetMode::PRECOMMIT,
+ PermuteParametersSmall * framework::dataset::make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CPPPermuteFixture<uint32_t>, framework::DatasetMode::NIGHTLY, combine(PermuteParametersLarge, framework::dataset::make("DataType", DataType::U32)))
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CPPPermuteFixture<uint32_t>, framework::DatasetMode::NIGHTLY,
+ PermuteParametersLarge * framework::dataset::make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
diff --git a/tests/validation/NEON/Permute.cpp b/tests/validation/NEON/Permute.cpp
index 8c172ddded..a5a81b7ac3 100644
--- a/tests/validation/NEON/Permute.cpp
+++ b/tests/validation/NEON/Permute.cpp
@@ -42,10 +42,29 @@ namespace validation
{
namespace
{
-const auto PermuteParametersSmall = combine(concat(concat(datasets::Small2DShapes(), datasets::Small3DShapes()), datasets::Small4DShapes()),
- framework::dataset::make("PermutationVector", { PermutationVector(2U, 0U, 1U), PermutationVector(1U, 2U, 0U) }));
-const auto PermuteParametersLarge = combine(datasets::Large4DShapes(),
- framework::dataset::make("PermutationVector", { PermutationVector(2U, 0U, 1U), PermutationVector(1U, 2U, 0U) }));
+const auto PermuteVectors3 = framework::dataset::make("PermutationVector",
+{
+ PermutationVector(2U, 0U, 1U),
+ PermutationVector(1U, 2U, 0U),
+ PermutationVector(0U, 1U, 2U),
+ PermutationVector(0U, 2U, 1U),
+ PermutationVector(1U, 0U, 2U),
+ PermutationVector(2U, 1U, 0U),
+});
+const auto PermuteVectors4 = framework::dataset::make("PermutationVector",
+{
+ PermutationVector(3U, 2U, 0U, 1U),
+ PermutationVector(3U, 2U, 1U, 0U),
+ PermutationVector(2U, 3U, 1U, 0U),
+ PermutationVector(1U, 3U, 2U, 0U),
+ PermutationVector(3U, 1U, 2U, 0U),
+ PermutationVector(3U, 0U, 2U, 1U),
+ PermutationVector(0U, 3U, 2U, 1U)
+});
+const auto PermuteVectors = concat(PermuteVectors3, PermuteVectors4);
+const auto PermuteInputLayout = framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC });
+const auto PermuteParametersSmall = concat(concat(datasets::Small2DShapes(), datasets::Small3DShapes()), datasets::Small4DShapes()) * PermuteInputLayout * PermuteVectors;
+const auto PermuteParametersLarge = datasets::Large4DShapes() * PermuteInputLayout * PermuteVectors;
} // namespace
TEST_SUITE(NEON)
TEST_SUITE(Permute)
@@ -61,7 +80,11 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // valid
TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32), // valid
TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // permutation not supported
+ TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::S16), // permutation not supported
TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32), // permutation not supported
+ TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32), // permutation not supported
+ TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32) // permutation not supported
+
}),
framework::dataset::make("OutputInfo", {
TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
@@ -71,7 +94,11 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
+ TensorInfo(TensorShape(3U, 5U, 7U, 7U), 1, DataType::S16),
TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(37U, 2U, 13U, 27U), 1, DataType::F32),
+ TensorInfo(TensorShape(37U, 2U, 13U, 27U), 1, DataType::F32)
+
})),
framework::dataset::make("PermutationVector", {
PermutationVector(2U, 1U, 0U),
@@ -81,9 +108,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
PermutationVector(2U, 0U, 1U),
PermutationVector(1U, 2U, 0U),
PermutationVector(3U, 2U, 0U, 1U),
- PermutationVector(2U, 3U, 1U, 0U)
+ PermutationVector(3U, 2U, 0U, 1U),
+ PermutationVector(2U, 3U, 1U, 0U),
+ PermutationVector(2U, 3U, 1U, 0U),
+ PermutationVector(0U, 0U, 0U, 1000U)
})),
- framework::dataset::make("Expected", { false, false, false, false, true, true, false, false })),
+ framework::dataset::make("Expected", { true, false, false, false, true, true, false,true, false, true, false })),
input_info, output_info, perm_vect, expected)
{
ARM_COMPUTE_EXPECT(bool(NEPermute::validate(&input_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), perm_vect)) == expected, framework::LogLevel::ERRORS);
@@ -118,12 +148,15 @@ template <typename T>
using NEPermuteFixture = PermuteValidationFixture<Tensor, Accessor, NEPermute, T>;
TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPermuteFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(PermuteParametersSmall, framework::dataset::make("DataType", DataType::U8)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPermuteFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ PermuteParametersSmall * framework::dataset::make("DataType", DataType::U8))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPermuteFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(PermuteParametersLarge, framework::dataset::make("DataType", DataType::U8)))
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPermuteFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
+ PermuteParametersLarge * framework::dataset::make("DataType", DataType::U8))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -131,12 +164,14 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPermuteFixture<uint8_t>, framework::DatasetMo
TEST_SUITE_END()
TEST_SUITE(U16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPermuteFixture<uint16_t>, framework::DatasetMode::PRECOMMIT, combine(PermuteParametersSmall, framework::dataset::make("DataType", DataType::U16)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPermuteFixture<uint16_t>, framework::DatasetMode::PRECOMMIT,
+ PermuteParametersSmall * framework::dataset::make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPermuteFixture<uint16_t>, framework::DatasetMode::NIGHTLY, combine(PermuteParametersLarge, framework::dataset::make("DataType", DataType::U16)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPermuteFixture<uint16_t>, framework::DatasetMode::NIGHTLY,
+ PermuteParametersLarge * framework::dataset::make("DataType", DataType::U16))
{
// Validate output
validate(Accessor(_target), _reference);
@@ -144,12 +179,14 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEPermuteFixture<uint16_t>, framework::DatasetM
TEST_SUITE_END()
TEST_SUITE(U32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEPermuteFixture<uint32_t>, framework::DatasetMode::PRECOMMIT, combine(PermuteParametersSmall, framework::dataset::make("DataType", DataType::U32)))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEPermuteFixture<uint32_t>, framework::DatasetMode::PRECOMMIT,
+ PermuteParametersSmall * framework::dataset::make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEPermuteFixture<uint32_t>, framework::DatasetMode::NIGHTLY, combine(PermuteParametersLarge, framework::dataset::make("DataType", DataType::U32)))
+FIXTURE_DATA_TEST_CASE(RunLarge, NEPermuteFixture<uint32_t>, framework::DatasetMode::NIGHTLY,
+ PermuteParametersLarge * framework::dataset::make("DataType", DataType::U32))
{
// Validate output
validate(Accessor(_target), _reference);
diff --git a/tests/validation/fixtures/PermuteFixture.h b/tests/validation/fixtures/PermuteFixture.h
index 3aae384706..92d01a5654 100644
--- a/tests/validation/fixtures/PermuteFixture.h
+++ b/tests/validation/fixtures/PermuteFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,10 +46,10 @@ class PermuteValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, PermutationVector perm, DataType data_type)
+ void setup(TensorShape input_shape, DataLayout input_layout, PermutationVector perm, DataType data_type)
{
- _target = compute_target(shape, data_type, perm);
- _reference = compute_reference(shape, data_type, perm);
+ _target = compute_target(input_shape, input_layout, data_type, perm);
+ _reference = compute_reference(input_shape, data_type, perm);
}
protected:
@@ -59,14 +59,14 @@ protected:
library->fill_tensor_uniform(tensor, 0);
}
- TensorType compute_target(const TensorShape &shape, DataType data_type, PermutationVector perm)
+ TensorType compute_target(const TensorShape &input_shape, DataLayout input_layout, DataType data_type, PermutationVector perm)
{
// Permute shapes
- TensorShape output_shape = shape;
+ TensorShape output_shape = input_shape;
permute(output_shape, perm);
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, data_type);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type);
TensorType dst = create_tensor<TensorType>(output_shape, data_type);
// Create and configure function
@@ -92,10 +92,10 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type, PermutationVector perm)
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, DataType data_type, PermutationVector perm)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type };
+ SimpleTensor<T> src{ input_shape, data_type };
// Fill reference
fill(src);