aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2020-08-24 14:35:22 +0100
committerGeorgios Pinitas <georgios.pinitas@arm.com>2020-09-15 16:56:21 +0000
commit82c1a1fc63d6a49c0b4be39529412c7f7bc8ea64 (patch)
tree19b9ed859e1d5524c1fd1e8b807c8c2b9baace3d
parentf1109546ae79e56c8f6797248c5a15588a9a10eb (diff)
downloadComputeLibrary-82c1a1fc63d6a49c0b4be39529412c7f7bc8ea64.tar.gz
COMPMID-3752: NEPermuteKernel does not support permutations2
Solves also: - COMPMID-3766: CTS Failures in Transpose Neon + FP16 Change-Id: I9d323f45f49cc0bce9e6329790bcf2f0eeec8572 Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3949 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Manuel Bottini <manuel.bottini@arm.com> Reviewed-by: Pablo Marquez <pablo.tello@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--src/core/NEON/kernels/NEPermuteKernel.cpp22
-rw-r--r--tests/validation/NEON/Permute.cpp31
2 files changed, 32 insertions, 21 deletions
diff --git a/src/core/NEON/kernels/NEPermuteKernel.cpp b/src/core/NEON/kernels/NEPermuteKernel.cpp
index 737b10b03c..3f447f90b9 100644
--- a/src/core/NEON/kernels/NEPermuteKernel.cpp
+++ b/src/core/NEON/kernels/NEPermuteKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,15 +36,19 @@ namespace
#include "arm_compute/core/NEON/kernels/convolution/common/shims.hpp"
} // namespace
-#include <cstddef>
-#include <cstdint>
-
-using namespace arm_compute;
-
+namespace arm_compute
+{
namespace
{
inline bool is_permutation_supported(const PermutationVector &v)
{
+ static const std::array<PermutationVector, 2> permutations2 =
+ {
+ {
+ PermutationVector(0U, 1U),
+ PermutationVector(1U, 0U),
+ }
+ };
static const std::array<PermutationVector, 6> permutations3 =
{
{
@@ -86,7 +90,8 @@ inline bool is_permutation_supported(const PermutationVector &v)
}
};
- return (permutations3.end() != std::find(permutations3.begin(), permutations3.end(), v)) || (permutations4.end() != std::find(permutations4.begin(), permutations4.end(), v));
+ return (permutations2.end() != std::find(permutations2.begin(), permutations2.end(), v)) || (permutations3.end() != std::find(permutations3.begin(), permutations3.end(), v))
+ || (permutations4.end() != std::find(permutations4.begin(), permutations4.end(), v));
}
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
@@ -129,7 +134,7 @@ void NEPermuteKernel::run_permute(const Window &window)
// Output window
Window window_out(window);
const Window::Dimension zero_window = Window::Dimension(0, 0, 0);
- for(size_t d = 0; d <= _perm.num_dimensions(); ++d)
+ for(size_t d = 0; d <= _output->info()->num_dimensions(); ++d)
{
window_out.set(d, zero_window);
}
@@ -292,3 +297,4 @@ void NEPermuteKernel::run(const Window &window, const ThreadInfo &info)
(this->*_func)(window);
}
}
+} // namespace arm_compute
diff --git a/tests/validation/NEON/Permute.cpp b/tests/validation/NEON/Permute.cpp
index d405582192..9429f25618 100644
--- a/tests/validation/NEON/Permute.cpp
+++ b/tests/validation/NEON/Permute.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2019 Arm Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,6 +42,11 @@ namespace validation
{
namespace
{
+const auto PermuteVectors2 = framework::dataset::make("PermutationVector",
+{
+ PermutationVector(0U, 1U),
+ PermutationVector(1U, 0U),
+});
const auto PermuteVectors3 = framework::dataset::make("PermutationVector",
{
PermutationVector(2U, 0U, 1U),
@@ -61,7 +66,7 @@ const auto PermuteVectors4 = framework::dataset::make("PermutationVector",
PermutationVector(3U, 0U, 2U, 1U),
PermutationVector(0U, 3U, 2U, 1U)
});
-const auto PermuteVectors = concat(PermuteVectors3, PermuteVectors4);
+const auto PermuteVectors = concat(concat(PermuteVectors2, PermuteVectors3), PermuteVectors4);
const auto PermuteParametersSmall = concat(concat(datasets::Small2DShapes(), datasets::Small3DShapes()), datasets::Small4DShapes()) * PermuteVectors;
const auto PermuteParametersLarge = datasets::Large4DShapes() * PermuteVectors;
} // namespace
@@ -71,7 +76,7 @@ TEST_SUITE(Permute)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("InputInfo",{
+ framework::dataset::make("InputInfo",{
TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // permutation not supported
TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // permutation not supported
TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16), // permutation not supported
@@ -85,26 +90,26 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
TensorInfo(TensorShape(27U, 13U, 37U, 2U), 1, DataType::F32) // permutation not supported
}),
- framework::dataset::make("OutputInfo", {
- TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
- TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),
+ framework::dataset::make("OutputInfo", {
+ TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
+ TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),
TensorInfo(TensorShape(7U, 7U, 5U, 3U), 1, DataType::U16),
TensorInfo(TensorShape(5U, 7U), 1, DataType::U8),
- TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
- TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
- TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
- TensorInfo(TensorShape(3U, 5U, 7U, 7U), 1, DataType::S16),
- TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
+ TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(5U, 7U, 7U, 3U), 1, DataType::U16),
+ TensorInfo(TensorShape(3U, 5U, 7U, 7U), 1, DataType::S16),
+ TensorInfo(TensorShape(13U, 37U, 27U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(37U, 2U, 13U, 27U), 1, DataType::F32),
TensorInfo(TensorShape(37U, 2U, 13U, 27U), 1, DataType::F32)
})),
- framework::dataset::make("PermutationVector", {
+ framework::dataset::make("PermutationVector", {
PermutationVector(2U, 1U, 0U),
PermutationVector(2U, 2U, 1U),
PermutationVector(1U, 1U, 1U),
PermutationVector(2U, 0U, 1U),
- PermutationVector(2U, 0U, 1U),
+ PermutationVector(2U, 0U, 1U),
PermutationVector(1U, 2U, 0U),
PermutationVector(3U, 2U, 0U, 1U),
PermutationVector(3U, 2U, 0U, 1U),