aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSheri Zhang <sheri.zhang@arm.com>2020-03-16 14:20:56 +0000
committerSheri Zhang <sheri.zhang@arm.com>2020-03-16 16:56:04 +0000
commita0352d3fd3f0fd5256efe98bf934228374bcf48d (patch)
tree8a49848b2f4b911d7fbd2e9c338aae7de1e491ff
parent05b243aff343fd6761bbadb2fcb4d2d98b0848c9 (diff)
downloadComputeLibrary-a0352d3fd3f0fd5256efe98bf934228374bcf48d.tar.gz
COMPMID-3272: Add support for QASYMM8_SIGNED in CPPPermuteKernel/CPPPermute
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com> Change-Id: I3856661076b7e39213988251986299ebaa6d9c68 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2886 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/CPP/kernels/CPPPermuteKernel.h6
-rw-r--r--arm_compute/runtime/CPP/functions/CPPPermute.h6
-rw-r--r--src/core/CPP/kernels/CPPPermuteKernel.cpp7
-rw-r--r--tests/validation/CPP/Permute.cpp14
4 files changed, 20 insertions, 13 deletions
diff --git a/arm_compute/core/CPP/kernels/CPPPermuteKernel.h b/arm_compute/core/CPP/kernels/CPPPermuteKernel.h
index dffc0dab78..e75152f4ea 100644
--- a/arm_compute/core/CPP/kernels/CPPPermuteKernel.h
+++ b/arm_compute/core/CPP/kernels/CPPPermuteKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -56,14 +56,14 @@ public:
/** Set the input and output of the kernel.
*
- * @param[in] input The input tensor to permute. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
+ * @param[in] input The input tensor to permute. Data types supported: U8/S8/QASYMM8/QASYMM8_SIGNED/U16/S16/F16/U32/S32/F32
* @param[out] output The output tensor. Data types supported: Same as @p input
* @param[in] perm Permutation vector
*/
void configure(const ITensor *input, ITensor *output, const PermutationVector &perm);
/** Static function to check if given info will lead to a valid configuration of @ref CPPPermuteKernel
*
- * @param[in] input The input tensor to permute. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
+ * @param[in] input The input tensor to permute. Data types supported: U8/S8/QASYMM8/QASYMM8_SIGNED/U16/S16/F16/U32/S32/F32
* @param[in] output The output tensor. Data types supported: Same as @p input
* @param[in] perm Permutation vector
*
diff --git a/arm_compute/runtime/CPP/functions/CPPPermute.h b/arm_compute/runtime/CPP/functions/CPPPermute.h
index 1b604e4b26..5a6d8ea106 100644
--- a/arm_compute/runtime/CPP/functions/CPPPermute.h
+++ b/arm_compute/runtime/CPP/functions/CPPPermute.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,14 +38,14 @@ class CPPPermute : public ICPPSimpleFunction
public:
/** Configure the permute CPP kernel
*
- * @param[in] input The input tensor to permute. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
+ * @param[in] input The input tensor to permute. Data types supported: All
* @param[out] output The output tensor. Data types supported: Same as @p input
* @param[in] perm Permutation vector
*/
void configure(const ITensor *input, ITensor *output, const PermutationVector &perm);
/** Static function to check if given info will lead to a valid configuration of @ref CPPPermute
*
- * @param[in] input The input tensor to permute. Data types supported: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
+ * @param[in] input The input tensor to permute. Data types supported: All
* @param[in] output The output tensor. Data types supported: Same as @p input
* @param[in] perm Permutation vector
*
diff --git a/src/core/CPP/kernels/CPPPermuteKernel.cpp b/src/core/CPP/kernels/CPPPermuteKernel.cpp
index d9fe5b0c0a..9d89836589 100644
--- a/src/core/CPP/kernels/CPPPermuteKernel.cpp
+++ b/src/core/CPP/kernels/CPPPermuteKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,10 +40,7 @@ namespace
{
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const PermutationVector &perm)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::U8, DataType::S8, DataType::QASYMM8,
- DataType::U16, DataType::S16,
- DataType::U32, DataType::S32,
- DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_type() == DataType::UNKNOWN);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(perm.num_dimensions() > 4, "Only up to 4D permutation vectors are supported");
const TensorShape output_shape = misc::shape_calculator::compute_permutation_output_shape(*input, perm);
diff --git a/tests/validation/CPP/Permute.cpp b/tests/validation/CPP/Permute.cpp
index 3d28df17b0..aab63e652e 100644
--- a/tests/validation/CPP/Permute.cpp
+++ b/tests/validation/CPP/Permute.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,7 +58,7 @@ const auto PermuteParametersLarge = datasets::Large4DShapes() * PermuteVectors;
TEST_SUITE(CPP)
TEST_SUITE(Permute)
-DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Small4DShapes(), framework::dataset::make("DataType", { DataType::S8, DataType::U8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F16, DataType::F32 })),
+DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(datasets::Small4DShapes(), framework::dataset::make("DataType", { DataType::S8, DataType::U8, DataType::S16, DataType::U16, DataType::U32, DataType::S32, DataType::F16, DataType::F32, DataType::QASYMM8_SIGNED })),
shape, data_type)
{
// Define permutation vector
@@ -133,6 +133,16 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CPPPermuteFixture<uint32_t>, framework::Dataset
}
TEST_SUITE_END()
+TEST_SUITE(QASYMM8_SINGED)
+FIXTURE_DATA_TEST_CASE(RunSmall, CPPPermuteFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ PermuteParametersSmall * framework::dataset::make("DataType", DataType::QASYMM8_SIGNED))
+{
+ // Validate output
+ validate(Accessor(_target), _reference);
+}
+
+TEST_SUITE_END() // QASYMM8_SINGED
+
TEST_SUITE_END()
TEST_SUITE_END()
} // namespace validation