author    Sheri Zhang <sheri.zhang@arm.com>  2020-03-16 14:31:53 +0000
committer Sheri Zhang <sheri.zhang@arm.com>  2020-03-19 13:31:07 +0000
commit    95f8089c63fbeab67bfe57b2232adbdccc7932c3
tree      10eedf0f63a6423d746e5cf318e8a64b5c717ffe
parent    fd7780d910f3bc4c85bc95b57ea3dd4375d95d41
COMPMID-3273: Add support for QASYMM8_SIGNED in CPPDetectionPostProcessLayer
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I8dad529892caf7389efb311e810c8a80ca3d03c2
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2888
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
 arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h |   8
 src/runtime/CPP/functions/CPPDetectionPostProcessLayer.cpp       |  97
 tests/validation/CPP/DetectionPostProcessLayer.cpp               | 190
 3 files changed, 216 insertions(+), 79 deletions(-)
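
For context: QASYMM8 and QASYMM8_SIGNED are both asymmetric 8-bit quantization schemes that relate a real value r to a stored integer q by r = scale * (q - offset); the only difference is the storage type (uint8_t vs int8_t) and hence the representable offset range. A minimal sketch of the two dequantize paths this patch dispatches between (helper names here are illustrative stand-ins, not the library's own symbols):

#include <cstdint>

// Illustrative stand-ins for arm_compute's dequantize_qasymm8 /
// dequantize_qasymm8_signed helpers: same formula, different storage type.
struct QInfo { float scale; int32_t offset; };

inline float dequantize_u8(uint8_t q, QInfo qi)
{
    return qi.scale * (static_cast<int32_t>(q) - qi.offset);
}

inline float dequantize_s8(int8_t q, QInfo qi)
{
    return qi.scale * (static_cast<int32_t>(q) - qi.offset);
}

// e.g. scores in [0, 1] stored as QASYMM8_SIGNED with scale = 1/255, offset = -128:
//   dequantize_s8(-128, { 1.f / 255, -128 }) == 0.0f
//   dequantize_s8( 127, { 1.f / 255, -128 }) == 1.0f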
diff --git a/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h b/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h
index 44ebf9d7f2..cb74ca9077 100644
--- a/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h
+++ b/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,7 +54,7 @@ public:
CPPDetectionPostProcessLayer &operator=(const CPPDetectionPostProcessLayer &) = delete;
/** Configure the detection output layer CPP function
*
- * @param[in] input_box_encoding The bounding box input tensor. Data types supported: F32, QASYMM8.
+ * @param[in] input_box_encoding The bounding box input tensor. Data types supported: F32/QASYMM8/QASYMM8_SIGNED.
* @param[in] input_score The class prediction input tensor. Data types supported: Same as @p input_box_encoding.
* @param[in] input_anchors The anchors input tensor. Data types supported: Same as @p input_box_encoding.
* @param[out] output_boxes The boxes output tensor. Data types supported: F32.
@@ -69,8 +69,8 @@ public:
ITensor *output_boxes, ITensor *output_classes, ITensor *output_scores, ITensor *num_detection, DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo());
/** Static function to check if given info will lead to a valid configuration of @ref CPPDetectionPostProcessLayer
*
- * @param[in] input_box_encoding The bounding box input tensor info. Data types supported: F32, QASYMM8.
- * @param[in] input_class_score The class prediction input tensor info. Data types supported: F32, QASYMM8.
+ * @param[in] input_box_encoding The bounding box input tensor info. Data types supported: F32/QASYMM8/QASYMM8_SIGNED.
+ * @param[in] input_class_score The class prediction input tensor info. Data types supported: Same as @p input_box_encoding.
* @param[in] input_anchors The anchors input tensor. Data types supported: Same as @p input_box_encoding.
* @param[out] output_boxes The output tensor. Data types supported: F32.
* @param[out] output_classes The output tensor. Data types supported: Same as @p output_boxes.
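
With the contract above, a caller can feed signed-quantized inputs directly. A minimal configuration sketch, with shapes and quantization parameters chosen purely for illustration (output tensor setup and allocation elided):

#include "arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

Tensor boxes, scores, anchors;                      // QASYMM8_SIGNED inputs
Tensor out_boxes, out_classes, out_scores, num_det; // F32 outputs

boxes.allocator()->init(TensorInfo(TensorShape(4U, 6U, 1U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f / 128, 0)));
scores.allocator()->init(TensorInfo(TensorShape(3U, 6U, 1U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f / 255, -128)));
anchors.allocator()->init(TensorInfo(TensorShape(4U, 6U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(0.5f, -128)));

CPPDetectionPostProcessLayer detect;
detect.configure(&boxes, &scores, &anchors,
                 &out_boxes, &out_classes, &out_scores, &num_det,
                 DetectionPostProcessLayerInfo());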
diff --git a/src/runtime/CPP/functions/CPPDetectionPostProcessLayer.cpp b/src/runtime/CPP/functions/CPPDetectionPostProcessLayer.cpp
index 9856f527ee..b3fc9c776d 100644
--- a/src/runtime/CPP/functions/CPPDetectionPostProcessLayer.cpp
+++ b/src/runtime/CPP/functions/CPPDetectionPostProcessLayer.cpp
@@ -40,7 +40,7 @@ Status validate_arguments(const ITensorInfo *input_box_encoding, const ITensorIn
DetectionPostProcessLayerInfo info, const unsigned int kBatchSize, const unsigned int kNumCoordBox)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input_box_encoding, input_class_score, input_anchors);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_box_encoding, 1, DataType::F32, DataType::QASYMM8);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_box_encoding, 1, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_box_encoding, input_anchors);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input_box_encoding->num_dimensions() > 3, "The location input tensor shape should be [4, N, kBatchSize].");
if(input_box_encoding->num_dimensions() > 2)
@@ -90,6 +90,24 @@ Status validate_arguments(const ITensorInfo *input_box_encoding, const ITensorIn
return Status{};
}
+inline void DecodeBoxCorner(BBox &box_centersize, BBox &anchor, Iterator &decoded_it, DetectionPostProcessLayerInfo info)
+{
+ const float half_factor = 0.5f;
+
+ // BBox is equivalent to CenterSizeEncoding [y,x,h,w]
+ const float y_center = box_centersize[0] / info.scale_value_y() * anchor[2] + anchor[0];
+ const float x_center = box_centersize[1] / info.scale_value_x() * anchor[3] + anchor[1];
+ const float half_h = half_factor * static_cast<float>(std::exp(box_centersize[2] / info.scale_value_h())) * anchor[2];
+ const float half_w = half_factor * static_cast<float>(std::exp(box_centersize[3] / info.scale_value_w())) * anchor[3];
+
+ // Box Corner encoding boxes are saved as [xmin, ymin, xmax, ymax]
+ auto decoded_ptr = reinterpret_cast<float *>(decoded_it.ptr());
+ *(decoded_ptr) = x_center - half_w; // xmin
+ *(1 + decoded_ptr) = y_center - half_h; // ymin
+ *(2 + decoded_ptr) = x_center + half_w; // xmax
+ *(3 + decoded_ptr) = y_center + half_h; // ymax
+}
+
/** Decode a bbox according to the anchors and scale info.
*
* @param[in] input_box_encoding The input prior bounding boxes.
@@ -101,8 +119,8 @@ void DecodeCenterSizeBoxes(const ITensor *input_box_encoding, const ITensor *inp
{
const QuantizationInfo &qi_box = input_box_encoding->info()->quantization_info();
const QuantizationInfo &qi_anchors = input_anchors->info()->quantization_info();
- BBox box_centersize;
- BBox anchor;
+ BBox box_centersize{ {} };
+ BBox anchor{ {} };
Window win;
win.use_tensor_dimensions(input_box_encoding->info()->tensor_shape());
@@ -112,11 +130,9 @@ void DecodeCenterSizeBoxes(const ITensor *input_box_encoding, const ITensor *inp
Iterator anchor_it(input_anchors, win);
Iterator decoded_it(decoded_boxes, win);
- const float half_factor = 0.5f;
-
- execute_window_loop(win, [&](const Coordinates &)
+ if(input_box_encoding->info()->data_type() == DataType::QASYMM8)
{
- if(is_data_type_quantized(input_box_encoding->info()->data_type()))
+ execute_window_loop(win, [&](const Coordinates &)
{
const auto box_ptr = reinterpret_cast<const qasymm8_t *>(box_it.ptr());
const auto anchor_ptr = reinterpret_cast<const qasymm8_t *>(anchor_it.ptr());
@@ -126,29 +142,38 @@ void DecodeCenterSizeBoxes(const ITensor *input_box_encoding, const ITensor *inp
anchor = BBox({ dequantize_qasymm8(*anchor_ptr, qi_anchors), dequantize_qasymm8(*(anchor_ptr + 1), qi_anchors),
dequantize_qasymm8(*(2 + anchor_ptr), qi_anchors), dequantize_qasymm8(*(3 + anchor_ptr), qi_anchors)
});
- }
- else
+ DecodeBoxCorner(box_centersize, anchor, decoded_it, info);
+ },
+ box_it, anchor_it, decoded_it);
+ }
+ else if(input_box_encoding->info()->data_type() == DataType::QASYMM8_SIGNED)
+ {
+ execute_window_loop(win, [&](const Coordinates &)
+ {
+ const auto box_ptr = reinterpret_cast<const qasymm8_signed_t *>(box_it.ptr());
+ const auto anchor_ptr = reinterpret_cast<const qasymm8_signed_t *>(anchor_it.ptr());
+ box_centersize = BBox({ dequantize_qasymm8_signed(*box_ptr, qi_box), dequantize_qasymm8_signed(*(box_ptr + 1), qi_box),
+ dequantize_qasymm8_signed(*(2 + box_ptr), qi_box), dequantize_qasymm8_signed(*(3 + box_ptr), qi_box)
+ });
+ anchor = BBox({ dequantize_qasymm8_signed(*anchor_ptr, qi_anchors), dequantize_qasymm8_signed(*(anchor_ptr + 1), qi_anchors),
+ dequantize_qasymm8_signed(*(2 + anchor_ptr), qi_anchors), dequantize_qasymm8_signed(*(3 + anchor_ptr), qi_anchors)
+ });
+ DecodeBoxCorner(box_centersize, anchor, decoded_it, info);
+ },
+ box_it, anchor_it, decoded_it);
+ }
+ else
+ {
+ execute_window_loop(win, [&](const Coordinates &)
{
const auto box_ptr = reinterpret_cast<const float *>(box_it.ptr());
const auto anchor_ptr = reinterpret_cast<const float *>(anchor_it.ptr());
box_centersize = BBox({ *box_ptr, *(box_ptr + 1), *(2 + box_ptr), *(3 + box_ptr) });
anchor = BBox({ *anchor_ptr, *(anchor_ptr + 1), *(2 + anchor_ptr), *(3 + anchor_ptr) });
- }
-
- // BBox is equavalent to CenterSizeEncoding [y,x,h,w]
- const float y_center = box_centersize[0] / info.scale_value_y() * anchor[2] + anchor[0];
- const float x_center = box_centersize[1] / info.scale_value_x() * anchor[3] + anchor[1];
- const float half_h = half_factor * static_cast<float>(std::exp(box_centersize[2] / info.scale_value_h())) * anchor[2];
- const float half_w = half_factor * static_cast<float>(std::exp(box_centersize[3] / info.scale_value_w())) * anchor[3];
-
- // Box Corner encoding boxes are saved as [xmin, ymin, xmax, ymax]
- auto decoded_ptr = reinterpret_cast<float *>(decoded_it.ptr());
- *(decoded_ptr) = x_center - half_w; // xmin
- *(1 + decoded_ptr) = y_center - half_h; // ymin
- *(2 + decoded_ptr) = x_center + half_w; // xmax
- *(3 + decoded_ptr) = y_center + half_h; // ymax
- },
- box_it, anchor_it, decoded_it);
+ DecodeBoxCorner(box_centersize, anchor, decoded_it, info);
+ },
+ box_it, anchor_it, decoded_it);
+ }
}
void SaveOutputs(const Tensor *decoded_boxes, const std::vector<int> &result_idx_boxes_after_nms, const std::vector<float> &result_scores_after_nms, const std::vector<int> &result_classes_after_nms,
@@ -263,12 +288,26 @@ void CPPDetectionPostProcessLayer::run()
// Decode scores if necessary
if(_dequantize_scores)
{
- for(unsigned int idx_c = 0; idx_c < _num_classes_with_background; ++idx_c)
+ if(_input_box_encoding->info()->data_type() == DataType::QASYMM8)
{
- for(unsigned int idx_b = 0; idx_b < _num_boxes; ++idx_b)
+ for(unsigned int idx_c = 0; idx_c < _num_classes_with_background; ++idx_c)
{
- *(reinterpret_cast<float *>(_decoded_scores.ptr_to_element(Coordinates(idx_c, idx_b)))) =
- dequantize_qasymm8(*(reinterpret_cast<qasymm8_t *>(_input_scores->ptr_to_element(Coordinates(idx_c, idx_b)))), _input_scores->info()->quantization_info());
+ for(unsigned int idx_b = 0; idx_b < _num_boxes; ++idx_b)
+ {
+ *(reinterpret_cast<float *>(_decoded_scores.ptr_to_element(Coordinates(idx_c, idx_b)))) =
+ dequantize_qasymm8(*(reinterpret_cast<qasymm8_t *>(_input_scores->ptr_to_element(Coordinates(idx_c, idx_b)))), _input_scores->info()->quantization_info());
+ }
+ }
+ }
+ else if(_input_box_encoding->info()->data_type() == DataType::QASYMM8_SIGNED)
+ {
+ for(unsigned int idx_c = 0; idx_c < _num_classes_with_background; ++idx_c)
+ {
+ for(unsigned int idx_b = 0; idx_b < _num_boxes; ++idx_b)
+ {
+ *(reinterpret_cast<float *>(_decoded_scores.ptr_to_element(Coordinates(idx_c, idx_b)))) =
+ dequantize_qasymm8_signed(*(reinterpret_cast<qasymm8_signed_t *>(_input_scores->ptr_to_element(Coordinates(idx_c, idx_b)))), _input_scores->info()->quantization_info());
+ }
}
}
}
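
After dequantization, all three branches feed the same DecodeBoxCorner arithmetic. Stripped of the Iterator plumbing, the transform is the standard SSD center-size decode; a standalone restatement (not library code) of what the hunks above compute:

#include <array>
#include <cmath>

using BBoxF = std::array<float, 4>;

// box and anchor are center-size encoded as [y, x, h, w];
// the result is corner encoded as [xmin, ymin, xmax, ymax].
BBoxF decode_box_corner(const BBoxF &box, const BBoxF &anchor,
                        float scale_y, float scale_x, float scale_h, float scale_w)
{
    const float y_center = box[0] / scale_y * anchor[2] + anchor[0];
    const float x_center = box[1] / scale_x * anchor[3] + anchor[1];
    const float half_h   = 0.5f * std::exp(box[2] / scale_h) * anchor[2];
    const float half_w   = 0.5f * std::exp(box[3] / scale_w) * anchor[3];
    return { x_center - half_w, y_center - half_h,
             x_center + half_w, y_center + half_h };
}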
diff --git a/tests/validation/CPP/DetectionPostProcessLayer.cpp b/tests/validation/CPP/DetectionPostProcessLayer.cpp
index f4528fb593..934ffea545 100644
--- a/tests/validation/CPP/DetectionPostProcessLayer.cpp
+++ b/tests/validation/CPP/DetectionPostProcessLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 ARM Limited.
+ * Copyright (c) 2019-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,52 +46,64 @@ inline void fill_tensor(U &&tensor, const std::vector<T> &v)
{
std::memcpy(tensor.data(), v.data(), sizeof(T) * v.size());
}
-template <typename U, typename T>
+template <typename D, typename U, typename T>
inline void quantize_and_fill_tensor(U &&tensor, const std::vector<T> &v)
{
- QuantizationInfo qi = tensor.quantization_info();
- std::vector<uint8_t> quantized;
+ QuantizationInfo qi = tensor.quantization_info();
+ std::vector<D> quantized;
quantized.reserve(v.size());
for(auto elem : v)
{
- quantized.emplace_back(quantize_qasymm8(elem, qi));
+ quantized.emplace_back(Qasymm8QuantizationHelper<D>::quantize(elem, qi));
+ }
+ std::memcpy(tensor.data(), quantized.data(), sizeof(D) * quantized.size());
+}
+template <typename T>
+inline int calc_qinfo(const float min, const float max, float &scale)
+{
+ const auto qmin = std::numeric_limits<T>::min();
+ const auto qmax = std::numeric_limits<T>::max();
+ const float f_qmin = qmin;
+ const float f_qmax = qmax;
+
+ scale = (max - min) / (f_qmax - f_qmin);
+ const float offset_from_min = f_qmin - min / scale;
+ const float offset_from_max = f_qmax - max / scale;
+
+ const float offset_from_min_error = std::abs(f_qmin) + std::abs(min / scale);
+ const float offset_from_max_error = std::abs(f_qmax) + std::abs(max / scale);
+ const float f_offset = offset_from_min_error < offset_from_max_error ? offset_from_min : offset_from_max;
+ T tmp_offset;
+ if(f_offset < f_qmin)
+ {
+ tmp_offset = qmin;
+ }
+ else if(f_offset > f_qmax)
+ {
+ tmp_offset = qmax;
+ }
+ else
+ {
+ tmp_offset = static_cast<T>(arm_compute::support::cpp11::round(f_offset));
}
- std::memcpy(tensor.data(), quantized.data(), sizeof(uint8_t) * quantized.size());
+ return static_cast<int>(tmp_offset);
}
-inline QuantizationInfo qinfo_scaleoffset_from_minmax(const float min, const float max)
+inline QuantizationInfo qinfo_scaleoffset_from_minmax(DataType data_type, const float min, const float max)
{
- int offset = 0;
- float scale = 0;
- const uint8_t qmin = std::numeric_limits<uint8_t>::min();
- const uint8_t qmax = std::numeric_limits<uint8_t>::max();
- const float f_qmin = qmin;
- const float f_qmax = qmax;
+ int offset = 0;
+ float scale = 0;
// Continue only if [min,max] is a valid range and not a point
if(min != max)
{
- scale = (max - min) / (f_qmax - f_qmin);
- const float offset_from_min = f_qmin - min / scale;
- const float offset_from_max = f_qmax - max / scale;
-
- const float offset_from_min_error = std::abs(f_qmin) + std::abs(min / scale);
- const float offset_from_max_error = std::abs(f_qmax) + std::abs(max / scale);
- const float f_offset = offset_from_min_error < offset_from_max_error ? offset_from_min : offset_from_max;
-
- uint8_t uint8_offset = 0;
- if(f_offset < f_qmin)
- {
- uint8_offset = qmin;
- }
- else if(f_offset > f_qmax)
+ if(data_type == DataType::QASYMM8_SIGNED)
{
- uint8_offset = qmax;
+ offset = calc_qinfo<int8_t>(min, max, scale);
}
else
{
- uint8_offset = static_cast<uint8_t>(arm_compute::support::cpp11::round(f_offset));
+ offset = calc_qinfo<uint8_t>(min, max, scale);
}
- offset = uint8_offset;
}
return QuantizationInfo(scale, offset);
}
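
Worked by hand for the score range [0, 1] used in base_test_case below, calc_qinfo<int8_t> yields:

// qmin = -128, qmax = 127
// scale           = (1 - 0) / (127 - (-128)) = 1/255
// offset_from_min = -128 - 0 / scale = -128   (error = 128)
// offset_from_max =  127 - 1 / scale = -128   (error = 382)
// smaller error wins -> offset = -128, i.e. QuantizationInfo(1/255, -128)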
@@ -100,9 +112,9 @@ inline void base_test_case(DetectionPostProcessLayerInfo info, DataType data_typ
const SimpleTensor<float> &expected_output_classes, const SimpleTensor<float> &expected_output_scores, const SimpleTensor<float> &expected_num_detection,
AbsoluteTolerance<float> tolerance_boxes = AbsoluteTolerance<float>(0.1f), AbsoluteTolerance<float> tolerance_others = AbsoluteTolerance<float>(0.1f))
{
- Tensor box_encoding = create_tensor<Tensor>(TensorShape(4U, 6U, 1U), data_type, 1, qinfo_scaleoffset_from_minmax(-1.0f, 1.0f));
- Tensor class_prediction = create_tensor<Tensor>(TensorShape(3U, 6U, 1U), data_type, 1, qinfo_scaleoffset_from_minmax(0.0f, 1.0f));
- Tensor anchors = create_tensor<Tensor>(TensorShape(4U, 6U), data_type, 1, qinfo_scaleoffset_from_minmax(0.0f, 100.5f));
+ Tensor box_encoding = create_tensor<Tensor>(TensorShape(4U, 6U, 1U), data_type, 1, qinfo_scaleoffset_from_minmax(data_type, -1.0f, 1.0f));
+ Tensor class_prediction = create_tensor<Tensor>(TensorShape(3U, 6U, 1U), data_type, 1, qinfo_scaleoffset_from_minmax(data_type, 0.0f, 1.0f));
+ Tensor anchors = create_tensor<Tensor>(TensorShape(4U, 6U), data_type, 1, qinfo_scaleoffset_from_minmax(data_type, 0.0f, 100.5f));
box_encoding.allocator()->allocate();
class_prediction.allocator()->allocate();
@@ -137,17 +149,31 @@ inline void base_test_case(DetectionPostProcessLayerInfo info, DataType data_typ
};
// Fill the tensors with random pre-generated values
- if(data_type == DataType::F32)
+ switch(data_type)
{
- fill_tensor(Accessor(box_encoding), box_encoding_vector);
- fill_tensor(Accessor(class_prediction), class_prediction_vector);
- fill_tensor(Accessor(anchors), anchors_vector);
- }
- else
- {
- quantize_and_fill_tensor(Accessor(box_encoding), box_encoding_vector);
- quantize_and_fill_tensor(Accessor(class_prediction), class_prediction_vector);
- quantize_and_fill_tensor(Accessor(anchors), anchors_vector);
+ case DataType::F32:
+ {
+ fill_tensor(Accessor(box_encoding), box_encoding_vector);
+ fill_tensor(Accessor(class_prediction), class_prediction_vector);
+ fill_tensor(Accessor(anchors), anchors_vector);
+ }
+ break;
+ case DataType::QASYMM8:
+ {
+ quantize_and_fill_tensor<uint8_t>(Accessor(box_encoding), box_encoding_vector);
+ quantize_and_fill_tensor<uint8_t>(Accessor(class_prediction), class_prediction_vector);
+ quantize_and_fill_tensor<uint8_t>(Accessor(anchors), anchors_vector);
+ }
+ break;
+ case DataType::QASYMM8_SIGNED:
+ {
+ quantize_and_fill_tensor<int8_t>(Accessor(box_encoding), box_encoding_vector);
+ quantize_and_fill_tensor<int8_t>(Accessor(class_prediction), class_prediction_vector);
+ quantize_and_fill_tensor<int8_t>(Accessor(anchors), anchors_vector);
+ }
+ break;
+ default:
+ return;
}
// Determine the output through the CPP kernel
@@ -189,19 +215,22 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zi
TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::S8), // Unsupported data type
TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32), // Wrong Detection Info
TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32), // Wrong boxes dimensions
- TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8)}), // Wrong score dimension
+ TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8), // Wrong score dimension
+ TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8_SIGNED)}), // Wrong score dimension
framework::dataset::make("ClassPredsInfo",{ TensorInfo(TensorShape(3U ,10U), 1, DataType::F32),
TensorInfo(TensorShape(3U ,10U), 1, DataType::F32),
TensorInfo(TensorShape(3U ,10U), 1, DataType::F32),
TensorInfo(TensorShape(3U ,10U), 1, DataType::F32),
TensorInfo(TensorShape(3U ,10U), 1, DataType::F32),
- TensorInfo(TensorShape(3U ,10U), 1, DataType::QASYMM8)})),
+ TensorInfo(TensorShape(3U ,10U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(3U ,10U), 1, DataType::QASYMM8_SIGNED)})),
framework::dataset::make("AnchorsInfo",{ TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32),
TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32),
TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32),
TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32),
TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32),
- TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8)})),
+ TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8_SIGNED)})),
framework::dataset::make("OutputBoxInfo", { TensorInfo(TensorShape(4U, 3U, 1U), 1, DataType::F32),
TensorInfo(TensorShape(4U, 3U, 1U), 1, DataType::F32),
TensorInfo(TensorShape(4U, 3U, 1U), 1, DataType::S8),
@@ -383,6 +412,75 @@ TEST_CASE(Quantized_regular, framework::DatasetMode::ALL)
TEST_SUITE_END() // QASYMM8
+TEST_SUITE(QASYMM8_SIGNED)
+TEST_CASE(Quantized_general, framework::DatasetMode::ALL)
+{
+ DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo(3 /*max_detections*/, 1 /*max_classes_per_detection*/, 0.0 /*nms_score_threshold*/,
+ 0.5 /*nms_iou_threshold*/, 2 /*num_classes*/, { 11.0, 11.0, 6.0, 6.0 } /*scale*/);
+
+ // Fill expected detection boxes
+ SimpleTensor<float> expected_output_boxes(TensorShape(4U, 3U), DataType::F32);
+ fill_tensor(expected_output_boxes, std::vector<float> { -0.15, 9.85, 0.95, 10.95, -0.15, -0.15, 0.95, 0.95, -0.15, 99.85, 0.95, 100.95 });
+ // Fill expected detection classes
+ SimpleTensor<float> expected_output_classes(TensorShape(3U), DataType::F32);
+ fill_tensor(expected_output_classes, std::vector<float> { 1.0f, 0.0f, 0.0f });
+ // Fill expected detection scores
+ SimpleTensor<float> expected_output_scores(TensorShape(3U), DataType::F32);
+ fill_tensor(expected_output_scores, std::vector<float> { 0.97f, 0.95f, 0.31f });
+ // Fill expected num detections
+ SimpleTensor<float> expected_num_detection(TensorShape(1U), DataType::F32);
+ fill_tensor(expected_num_detection, std::vector<float> { 3.f });
+ // Run test
+ base_test_case(info, DataType::QASYMM8_SIGNED, expected_output_boxes, expected_output_classes, expected_output_scores, expected_num_detection, AbsoluteTolerance<float>(0.3f));
+}
+
+TEST_CASE(Quantized_fast, framework::DatasetMode::ALL)
+{
+ DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo(3 /*max_detections*/, 1 /*max_classes_per_detection*/, 0.0 /*nms_score_threshold*/,
+ 0.5 /*nms_iou_threshold*/, 2 /*num_classes*/, { 11.0, 11.0, 6.0, 6.0 } /*scale*/,
+ false /*use_regular_nms*/, 1 /*detections_per_class*/);
+
+ // Fill expected detection boxes
+ SimpleTensor<float> expected_output_boxes(TensorShape(4U, 3U), DataType::F32);
+ fill_tensor(expected_output_boxes, std::vector<float> { -0.15, 9.85, 0.95, 10.95, -0.15, -0.15, 0.95, 0.95, -0.15, 99.85, 0.95, 100.95 });
+ // Fill expected detection classes
+ SimpleTensor<float> expected_output_classes(TensorShape(3U), DataType::F32);
+ fill_tensor(expected_output_classes, std::vector<float> { 1.0f, 0.0f, 0.0f });
+ // Fill expected detection scores
+ SimpleTensor<float> expected_output_scores(TensorShape(3U), DataType::F32);
+ fill_tensor(expected_output_scores, std::vector<float> { 0.97f, 0.95f, 0.31f });
+ // Fill expected num detections
+ SimpleTensor<float> expected_num_detection(TensorShape(1U), DataType::F32);
+ fill_tensor(expected_num_detection, std::vector<float> { 3.f });
+
+ // Run base test
+ base_test_case(info, DataType::QASYMM8_SIGNED, expected_output_boxes, expected_output_classes, expected_output_scores, expected_num_detection, AbsoluteTolerance<float>(0.3f));
+}
+
+TEST_CASE(Quantized_regular, framework::DatasetMode::ALL)
+{
+ DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo(3 /*max_detections*/, 1 /*max_classes_per_detection*/, 0.0 /*nms_score_threshold*/,
+ 0.5 /*nms_iou_threshold*/, 2 /*num_classes*/, { 11.0, 11.0, 6.0, 6.0 } /*scale*/,
+ true /*use_regular_nms*/, 1 /*detections_per_class*/);
+ // Fill expected detection boxes
+ SimpleTensor<float> expected_output_boxes(TensorShape(4U, 3U), DataType::F32);
+ fill_tensor(expected_output_boxes, std::vector<float> { -0.15, 9.85, 0.95, 10.95, -0.15, 9.85, 0.95, 10.95, 0.0f, 0.0f, 0.0f, 0.0f });
+ // Fill expected detection classes
+ SimpleTensor<float> expected_output_classes(TensorShape(3U), DataType::F32);
+ fill_tensor(expected_output_classes, std::vector<float> { 1.0f, 0.0f, 0.0f });
+ // Fill expected detection scores
+ SimpleTensor<float> expected_output_scores(TensorShape(3U), DataType::F32);
+ fill_tensor(expected_output_scores, std::vector<float> { 0.95f, 0.91f, 0.0f });
+ // Fill expected num detections
+ SimpleTensor<float> expected_num_detection(TensorShape(1U), DataType::F32);
+ fill_tensor(expected_num_detection, std::vector<float> { 2.f });
+
+ // Run test
+ base_test_case(info, DataType::QASYMM8_SIGNED, expected_output_boxes, expected_output_classes, expected_output_scores, expected_num_detection, AbsoluteTolerance<float>(0.3f));
+}
+
+TEST_SUITE_END() // QASYMM8_SIGNED
+
TEST_SUITE_END() // DetectionPostProcessLayer
TEST_SUITE_END() // CPP
} // namespace validation