aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSheri Zhang <sheri.zhang@arm.com>2020-03-16 14:07:51 +0000
committerSheri Zhang <sheri.zhang@arm.com>2020-03-16 16:35:33 +0000
commit05b243aff343fd6761bbadb2fcb4d2d98b0848c9 (patch)
treef0a99e11ae4990aa68881e558b0fd86ca8a1978e
parenta602f03f4c66e5ee2480f1a3fc66847968fc1076 (diff)
downloadComputeLibrary-05b243aff343fd6761bbadb2fcb4d2d98b0848c9.tar.gz
COMPMID-3271: Add support for QASYMM8_SIGNED in CPPTopKVKernel/CPPTopKV
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com> Change-Id: Ic34616fc3480ca85cc582e4e3db031d631ed5861 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2887 Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/core/CPP/kernels/CPPTopKVKernel.h6
-rw-r--r--arm_compute/runtime/CPP/functions/CPPTopKV.h6
-rw-r--r--src/core/CPP/kernels/CPPTopKVKernel.cpp7
-rw-r--r--tests/validation/CPP/TopKV.cpp57
4 files changed, 66 insertions, 10 deletions
diff --git a/arm_compute/core/CPP/kernels/CPPTopKVKernel.h b/arm_compute/core/CPP/kernels/CPPTopKVKernel.h
index fc62b42ab8..4b9bfdd3c9 100644
--- a/arm_compute/core/CPP/kernels/CPPTopKVKernel.h
+++ b/arm_compute/core/CPP/kernels/CPPTopKVKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -53,7 +53,7 @@ public:
/** Set the input and output of the kernel.
*
- * @param[in] predictions A batch_size x classes tensor. Data types supported: F16/S32/F32/QASYMM8
+ * @param[in] predictions A batch_size x classes tensor. Data types supported: F16/S32/F32/QASYMM8/QASYMM8_SIGNED
* @param[in] targets A batch_size 1D tensor of class ids. Data types supported: S32
* @param[out] output Computed precision at @p k as a bool 1D tensor. Data types supported: U8
* @param[in] k Number of top elements to look at for computing precision.
@@ -62,7 +62,7 @@ public:
/** Static function to check if given info will lead to a valid configuration of @ref CPPTopKVKernel
*
- * @param[in] predictions A batch_size x classes tensor info. Data types supported: F16/S32/F32/QASYMM8
+ * @param[in] predictions A batch_size x classes tensor info. Data types supported: F16/S32/F32/QASYMM8/QASYMM8_SIGNED
* @param[in] targets A batch_size 1D tensor info of class ids. Data types supported: S32
* @param[in] output Computed precision at @p k as a bool 1D tensor info. Data types supported: U8
* @param[in] k Number of top elements to look at for computing precision.
diff --git a/arm_compute/runtime/CPP/functions/CPPTopKV.h b/arm_compute/runtime/CPP/functions/CPPTopKV.h
index d41b7d300f..c94e277312 100644
--- a/arm_compute/runtime/CPP/functions/CPPTopKV.h
+++ b/arm_compute/runtime/CPP/functions/CPPTopKV.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -38,7 +38,7 @@ class CPPTopKV : public ICPPSimpleFunction
public:
/** Set the input and output of the kernel.
*
- * @param[in] predictions A batch_size x classes tensor. Data types supported: F16/S32/F32/QASYMM8
+ * @param[in] predictions A batch_size x classes tensor. Data types supported: F16/S32/F32/QASYMM8/QASYMM8_SIGNED
* @param[in] targets A batch_size 1D tensor of class ids. Data types supported: S32
* @param[out] output Computed precision at @p k as a bool 1D tensor. Data types supported: U8
* @param[in] k Number of top elements to look at for computing precision.
@@ -47,7 +47,7 @@ public:
/** Static function to check if given info will lead to a valid configuration of @ref CPPTopKVKernel
*
- * @param[in] predictions A batch_size x classes tensor info. Data types supported: F16/S32/F32/QASYMM8
+ * @param[in] predictions A batch_size x classes tensor info. Data types supported: F16/S32/F32/QASYMM8/QASYMM8_SIGNED
* @param[in] targets A batch_size 1D tensor info of class ids. Data types supported: S32
* @param[in] output Computed precision at @p k as a bool 1D tensor info. Data types supported: U8
* @param[in] k Number of top elements to look at for computing precision.
diff --git a/src/core/CPP/kernels/CPPTopKVKernel.cpp b/src/core/CPP/kernels/CPPTopKVKernel.cpp
index 533543a988..7f284d4e1e 100644
--- a/src/core/CPP/kernels/CPPTopKVKernel.cpp
+++ b/src/core/CPP/kernels/CPPTopKVKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,7 +54,7 @@ inline bool greater_than(T a, T b)
Status validate_arguments(const ITensorInfo *predictions, const ITensorInfo *targets, ITensorInfo *output, const unsigned int k)
{
ARM_COMPUTE_UNUSED(k);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(predictions, 1, DataType::QASYMM8, DataType::S32, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(predictions, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S32, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(targets, 1, DataType::U32);
ARM_COMPUTE_RETURN_ERROR_ON(predictions->num_dimensions() > 2);
@@ -145,6 +145,9 @@ void CPPTopKVKernel::run(const Window &window, const ThreadInfo &info)
case DataType::QASYMM8:
run_topkv<uint8_t>();
break;
+ case DataType::QASYMM8_SIGNED:
+ run_topkv<int8_t>();
+ break;
default:
ARM_COMPUTE_ERROR("Not supported");
}
diff --git a/tests/validation/CPP/TopKV.cpp b/tests/validation/CPP/TopKV.cpp
index 02178192aa..e528c622b0 100644
--- a/tests/validation/CPP/TopKV.cpp
+++ b/tests/validation/CPP/TopKV.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -134,7 +134,7 @@ TEST_CASE(Float, framework::DatasetMode::ALL)
validate(Accessor(output), expected_output);
}
-TEST_CASE(Quantized, framework::DatasetMode::ALL)
+TEST_CASE(QASYMM8, framework::DatasetMode::ALL)
{
const unsigned int k = 5;
@@ -187,6 +187,59 @@ TEST_CASE(Quantized, framework::DatasetMode::ALL)
validate(Accessor(output), expected_output);
}
+TEST_CASE(QASYMM8_SIGNED, framework::DatasetMode::ALL)
+{
+ const unsigned int k = 5;
+
+ Tensor predictions = create_tensor<Tensor>(TensorShape(10, 20), DataType::QASYMM8_SIGNED, 1, QuantizationInfo());
+ Tensor targets = create_tensor<Tensor>(TensorShape(20), DataType::U32);
+
+ predictions.allocator()->allocate();
+ targets.allocator()->allocate();
+
+ // Fill the tensors with random pre-generated values
+ fill_tensor(Accessor(predictions), std::vector<int8_t>
+ {
+ 123, -34, 69, 118, 20, -45, 99, -98, 127, 117, //-34
+ -99, 1, -128, 90, 60, 25, 102, 76, 24, -110, //25
+ 99, 119, 49, 43, -40, 60, 43, 84, 29, 67, //84
+ 33, 109, 74, 90, 90, 44, 98, 90, 35, 123, //74
+ 62, 118, 24, -32, 34, 21, 114, 113, 124, 20, //124
+ 74, 98, 48, 107, 127, 125, 6, -98, 127, 59, //98
+ 75, 83, 75, -118, 56, 101, 85, 97, 49, 127, //75
+ 72, -20, 40, 14, 28, -30, 109, 43, 127, -31, //-20
+ 78, 121, 109, 66, 29, 90, 70, 21, 38, 48, //109
+ 18, 10, 115, 124, 17, 123, 51, 54, 15, 17, //17
+ 66, 46, -66, 125, 104, 90, 123, 113, -54, -126, //125
+ 58, -85, 74, 39, 115, 39, 111, -27, 44, 51, //51
+ 71, 122, -34, -123, 94, 113, 125, 111, 38, 64, //94
+ -17, 40, 42, 68, 96, 68, 101, 40, 79, 71, //40
+ 89, 88, 54, 82, 127, 12, 112, 52, 125, 22, //22
+ -128, 56, 82, 31, 98, 94, 102, 105, 127, 123, //123
+ 112, 50, 61, 41, 39, 63, -77, 92, 26, 70, //39
+ 2, 90, 31, 99, -34, 114, 112, 126, 127, 87, //90
+ 125, 63, 56, 123, 50, -77, 97, -93, 1, 29, //56
+ 100, -35, 116, 64, 36, 92, 56, 82, -22, -118 //36
+ });
+
+ fill_tensor(Accessor(targets), std::vector<int> { 1, 5, 7, 2, 8, 1, 2, 1, 2, 4, 3, 9, 4, 1, 9, 9, 4, 1, 2, 4 });
+
+ // Determine the output through the CPP kernel
+ Tensor output;
+ CPPTopKV topkv;
+ topkv.configure(&predictions, &targets, &output, k);
+
+ output.allocator()->allocate();
+
+ // Run the kernel
+ topkv.run();
+
+ // Validate against the expected values
+ SimpleTensor<int8_t> expected_output(TensorShape(20), DataType::U8);
+ fill_tensor(expected_output, std::vector<int8_t> { 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0 });
+ validate(Accessor(output), expected_output);
+}
+
TEST_SUITE_END() // TopKV
TEST_SUITE_END() // CPP
} // namespace validation