aboutsummaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2019-09-05 12:30:22 +0100
committerPablo Marquez <pablo.tello@arm.com>2019-09-27 16:20:14 +0000
commit6b612f5fa1fee9528f2f87491fe7edb3887d9817 (patch)
tree579ef443d61ed1319e5d8f44d8a7a8ce83c82aad /tests
parent240b79de1c211ebb8d439b4a1c8c79777aa36f13 (diff)
downloadComputeLibrary-6b612f5fa1fee9528f2f87491fe7edb3887d9817.tar.gz
COMPMID-2310: CLGenerateProposalsLayer: support for QASYMM8
Change-Id: I48b77e09857cd43f9498d28e8f4bf346e3d7110d Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-on: https://review.mlplatform.org/c/1969 Reviewed-by: Pablo Marquez <pablo.tello@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--tests/validation/CL/GenerateProposalsLayer.cpp20
-rw-r--r--tests/validation/fixtures/ComputeAllAnchorsFixture.h39
-rw-r--r--tests/validation/reference/ComputeAllAnchors.cpp9
3 files changed, 59 insertions, 9 deletions
diff --git a/tests/validation/CL/GenerateProposalsLayer.cpp b/tests/validation/CL/GenerateProposalsLayer.cpp
index 4ebffd7e79..bfad8e8381 100644
--- a/tests/validation/CL/GenerateProposalsLayer.cpp
+++ b/tests/validation/CL/GenerateProposalsLayer.cpp
@@ -82,6 +82,8 @@ const auto ComputeAllInfoDataset = framework::dataset::make("ComputeAllInfo",
ComputeAnchorsInfo(100U, 100U, 1. / 4.f),
});
+
+constexpr AbsoluteTolerance<int16_t> tolerance_qsymm16(1);
} // namespace
TEST_SUITE(CL)
@@ -364,7 +366,7 @@ DATA_TEST_CASE(IntegrationTestCaseGenerateProposals, framework::DatasetMode::ALL
proposals_final.allocator()->allocate();
select_proposals.run();
- // Select the first N entries of the proposals
+ // Select the first N entries of the scores
CLTensor scores_final;
CLSlice select_scores;
select_scores.configure(&scores_out, &scores_final, Coordinates(0), Coordinates(N));
@@ -395,6 +397,22 @@ FIXTURE_DATA_TEST_CASE(ComputeAllAnchors, CLComputeAllAnchorsFixture<half>, fram
TEST_SUITE_END() // FP16
TEST_SUITE_END() // Float
+template <typename T>
+using CLComputeAllAnchorsQuantizedFixture = ComputeAllAnchorsQuantizedFixture<CLTensor, CLAccessor, CLComputeAllAnchors, T>;
+
+TEST_SUITE(Quantized)
+TEST_SUITE(QASYMM8)
+FIXTURE_DATA_TEST_CASE(ComputeAllAnchors, CLComputeAllAnchorsQuantizedFixture<int16_t>, framework::DatasetMode::ALL,
+ combine(combine(combine(framework::dataset::make("NumAnchors", { 2, 4, 8 }), ComputeAllInfoDataset),
+ framework::dataset::make("DataType", { DataType::QSYMM16 })),
+ framework::dataset::make("QuantInfo", { QuantizationInfo(0.125f, 0) })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference, tolerance_qsymm16);
+}
+TEST_SUITE_END() // QASYMM8
+TEST_SUITE_END() // Quantized
+
TEST_SUITE_END() // GenerateProposals
TEST_SUITE_END() // CL
diff --git a/tests/validation/fixtures/ComputeAllAnchorsFixture.h b/tests/validation/fixtures/ComputeAllAnchorsFixture.h
index 6f2db3e623..e837bd4838 100644
--- a/tests/validation/fixtures/ComputeAllAnchorsFixture.h
+++ b/tests/validation/fixtures/ComputeAllAnchorsFixture.h
@@ -41,14 +41,14 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class ComputeAllAnchorsFixture : public framework::Fixture
+class ComputeAllAnchorsGenericFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(size_t num_anchors, const ComputeAnchorsInfo &info, DataType data_type)
+ void setup(size_t num_anchors, const ComputeAnchorsInfo &info, DataType data_type, QuantizationInfo qinfo)
{
- _target = compute_target(num_anchors, data_type, info);
- _reference = compute_reference(num_anchors, data_type, info);
+ _target = compute_target(num_anchors, data_type, info, qinfo);
+ _reference = compute_reference(num_anchors, data_type, info, qinfo);
}
protected:
@@ -58,11 +58,11 @@ protected:
library->fill_tensor_uniform(tensor, 0, T(0), T(100));
}
- TensorType compute_target(size_t num_anchors, DataType data_type, const ComputeAnchorsInfo &info)
+ TensorType compute_target(size_t num_anchors, DataType data_type, const ComputeAnchorsInfo &info, QuantizationInfo qinfo)
{
// Create tensors
TensorShape anchors_shape(4, num_anchors);
- TensorType anchors = create_tensor<TensorType>(anchors_shape, data_type);
+ TensorType anchors = create_tensor<TensorType>(anchors_shape, data_type, 1, qinfo);
TensorType all_anchors;
// Create and configure function
@@ -88,10 +88,11 @@ protected:
SimpleTensor<T> compute_reference(size_t num_anchors,
DataType data_type,
- const ComputeAnchorsInfo &info)
+ const ComputeAnchorsInfo &info,
+ QuantizationInfo qinfo)
{
// Create reference tensor
- SimpleTensor<T> anchors(TensorShape(4, num_anchors), data_type);
+ SimpleTensor<T> anchors(TensorShape(4, num_anchors), data_type, 1, qinfo);
// Fill reference tensor
fill(anchors);
@@ -101,6 +102,28 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ComputeAllAnchorsFixture : public ComputeAllAnchorsGenericFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(size_t num_anchors, const ComputeAnchorsInfo &info, DataType data_type)
+ {
+ ComputeAllAnchorsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(num_anchors, info, data_type, QuantizationInfo());
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class ComputeAllAnchorsQuantizedFixture : public ComputeAllAnchorsGenericFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(size_t num_anchors, const ComputeAnchorsInfo &info, DataType data_type, QuantizationInfo qinfo)
+ {
+ ComputeAllAnchorsGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(num_anchors, info, data_type, qinfo);
+ }
+};
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/reference/ComputeAllAnchors.cpp b/tests/validation/reference/ComputeAllAnchors.cpp
index 3f0498015a..60be7ef8a8 100644
--- a/tests/validation/reference/ComputeAllAnchors.cpp
+++ b/tests/validation/reference/ComputeAllAnchors.cpp
@@ -73,6 +73,15 @@ SimpleTensor<T> compute_all_anchors(const SimpleTensor<T> &anchors, const Comput
}
template SimpleTensor<float> compute_all_anchors(const SimpleTensor<float> &anchors, const ComputeAnchorsInfo &info);
template SimpleTensor<half> compute_all_anchors(const SimpleTensor<half> &anchors, const ComputeAnchorsInfo &info);
+
+template <>
+SimpleTensor<int16_t> compute_all_anchors(const SimpleTensor<int16_t> &anchors, const ComputeAnchorsInfo &info)
+{
+ SimpleTensor<float> anchors_tmp = convert_from_symmetric(anchors);
+ SimpleTensor<float> all_anchors_tmp = compute_all_anchors(anchors_tmp, info);
+ SimpleTensor<int16_t> all_anchors = convert_to_symmetric<int16_t>(all_anchors_tmp, anchors.quantization_info());
+ return all_anchors;
+}
} // namespace reference
} // namespace validation
} // namespace test