From 95f8089c63fbeab67bfe57b2232adbdccc7932c3 Mon Sep 17 00:00:00 2001 From: Sheri Zhang Date: Mon, 16 Mar 2020 14:31:53 +0000 Subject: COMPMID-3273: Add support for QASYMM8_SIGNED in CPPDetectionPostProcessLayer Signed-off-by: Sheri Zhang Change-Id: I8dad529892caf7389efb311e810c8a80ca3d03c2 Signed-off-by: Sheri Zhang Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2888 Comments-Addressed: Arm Jenkins Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- .../CPP/functions/CPPDetectionPostProcessLayer.h | 8 +- .../CPP/functions/CPPDetectionPostProcessLayer.cpp | 97 +++++++---- tests/validation/CPP/DetectionPostProcessLayer.cpp | 190 ++++++++++++++++----- 3 files changed, 216 insertions(+), 79 deletions(-) diff --git a/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h b/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h index 44ebf9d7f2..cb74ca9077 100644 --- a/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h +++ b/arm_compute/runtime/CPP/functions/CPPDetectionPostProcessLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -54,7 +54,7 @@ public: CPPDetectionPostProcessLayer &operator=(const CPPDetectionPostProcessLayer &) = delete; /** Configure the detection output layer CPP function * - * @param[in] input_box_encoding The bounding box input tensor. Data types supported: F32, QASYMM8. + * @param[in] input_box_encoding The bounding box input tensor. Data types supported: F32/QASYMM8/QASYMM8_SIGNED. * @param[in] input_score The class prediction input tensor. Data types supported: Same as @p input_box_encoding. * @param[in] input_anchors The anchors input tensor. Data types supported: Same as @p input_box_encoding. * @param[out] output_boxes The boxes output tensor. Data types supported: F32. @@ -69,8 +69,8 @@ public: ITensor *output_boxes, ITensor *output_classes, ITensor *output_scores, ITensor *num_detection, DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo()); /** Static function to check if given info will lead to a valid configuration of @ref CPPDetectionPostProcessLayer * - * @param[in] input_box_encoding The bounding box input tensor info. Data types supported: F32, QASYMM8. - * @param[in] input_class_score The class prediction input tensor info. Data types supported: F32, QASYMM8. + * @param[in] input_box_encoding The bounding box input tensor info. Data types supported: F32/QASYMM8/QASYMM8_SIGNED. + * @param[in] input_class_score The class prediction input tensor info. Data types supported: Same as @p input_box_encoding. * @param[in] input_anchors The anchors input tensor. Data types supported: F32, QASYMM8. * @param[out] output_boxes The output tensor. Data types supported: F32. * @param[out] output_classes The output tensor. Data types supported: Same as @p output_boxes. diff --git a/src/runtime/CPP/functions/CPPDetectionPostProcessLayer.cpp b/src/runtime/CPP/functions/CPPDetectionPostProcessLayer.cpp index 9856f527ee..b3fc9c776d 100644 --- a/src/runtime/CPP/functions/CPPDetectionPostProcessLayer.cpp +++ b/src/runtime/CPP/functions/CPPDetectionPostProcessLayer.cpp @@ -40,7 +40,7 @@ Status validate_arguments(const ITensorInfo *input_box_encoding, const ITensorIn DetectionPostProcessLayerInfo info, const unsigned int kBatchSize, const unsigned int kNumCoordBox) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input_box_encoding, input_class_score, input_anchors); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_box_encoding, 1, DataType::F32, DataType::QASYMM8); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input_box_encoding, 1, DataType::F32, DataType::QASYMM8, DataType::QASYMM8_SIGNED); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input_box_encoding, input_anchors); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input_box_encoding->num_dimensions() > 3, "The location input tensor shape should be [4, N, kBatchSize]."); if(input_box_encoding->num_dimensions() > 2) @@ -90,6 +90,24 @@ Status validate_arguments(const ITensorInfo *input_box_encoding, const ITensorIn return Status{}; } +inline void DecodeBoxCorner(BBox &box_centersize, BBox &anchor, Iterator &decoded_it, DetectionPostProcessLayerInfo info) +{ + const float half_factor = 0.5f; + + // BBox is equavalent to CenterSizeEncoding [y,x,h,w] + const float y_center = box_centersize[0] / info.scale_value_y() * anchor[2] + anchor[0]; + const float x_center = box_centersize[1] / info.scale_value_x() * anchor[3] + anchor[1]; + const float half_h = half_factor * static_cast(std::exp(box_centersize[2] / info.scale_value_h())) * anchor[2]; + const float half_w = half_factor * static_cast(std::exp(box_centersize[3] / info.scale_value_w())) * anchor[3]; + + // Box Corner encoding boxes are saved as [xmin, ymin, xmax, ymax] + auto decoded_ptr = reinterpret_cast(decoded_it.ptr()); + *(decoded_ptr) = x_center - half_w; // xmin + *(1 + decoded_ptr) = y_center - half_h; // ymin + *(2 + decoded_ptr) = x_center + half_w; // xmax + *(3 + decoded_ptr) = y_center + half_h; // ymax +} + /** Decode a bbox according to a anchors and scale info. * * @param[in] input_box_encoding The input prior bounding boxes. @@ -101,8 +119,8 @@ void DecodeCenterSizeBoxes(const ITensor *input_box_encoding, const ITensor *inp { const QuantizationInfo &qi_box = input_box_encoding->info()->quantization_info(); const QuantizationInfo &qi_anchors = input_anchors->info()->quantization_info(); - BBox box_centersize; - BBox anchor; + BBox box_centersize{ {} }; + BBox anchor{ {} }; Window win; win.use_tensor_dimensions(input_box_encoding->info()->tensor_shape()); @@ -112,11 +130,9 @@ void DecodeCenterSizeBoxes(const ITensor *input_box_encoding, const ITensor *inp Iterator anchor_it(input_anchors, win); Iterator decoded_it(decoded_boxes, win); - const float half_factor = 0.5f; - - execute_window_loop(win, [&](const Coordinates &) + if(input_box_encoding->info()->data_type() == DataType::QASYMM8) { - if(is_data_type_quantized(input_box_encoding->info()->data_type())) + execute_window_loop(win, [&](const Coordinates &) { const auto box_ptr = reinterpret_cast(box_it.ptr()); const auto anchor_ptr = reinterpret_cast(anchor_it.ptr()); @@ -126,29 +142,38 @@ void DecodeCenterSizeBoxes(const ITensor *input_box_encoding, const ITensor *inp anchor = BBox({ dequantize_qasymm8(*anchor_ptr, qi_anchors), dequantize_qasymm8(*(anchor_ptr + 1), qi_anchors), dequantize_qasymm8(*(2 + anchor_ptr), qi_anchors), dequantize_qasymm8(*(3 + anchor_ptr), qi_anchors) }); - } - else + DecodeBoxCorner(box_centersize, anchor, decoded_it, info); + }, + box_it, anchor_it, decoded_it); + } + else if(input_box_encoding->info()->data_type() == DataType::QASYMM8_SIGNED) + { + execute_window_loop(win, [&](const Coordinates &) + { + const auto box_ptr = reinterpret_cast(box_it.ptr()); + const auto anchor_ptr = reinterpret_cast(anchor_it.ptr()); + box_centersize = BBox({ dequantize_qasymm8_signed(*box_ptr, qi_box), dequantize_qasymm8_signed(*(box_ptr + 1), qi_box), + dequantize_qasymm8_signed(*(2 + box_ptr), qi_box), dequantize_qasymm8_signed(*(3 + box_ptr), qi_box) + }); + anchor = BBox({ dequantize_qasymm8_signed(*anchor_ptr, qi_anchors), dequantize_qasymm8_signed(*(anchor_ptr + 1), qi_anchors), + dequantize_qasymm8_signed(*(2 + anchor_ptr), qi_anchors), dequantize_qasymm8_signed(*(3 + anchor_ptr), qi_anchors) + }); + DecodeBoxCorner(box_centersize, anchor, decoded_it, info); + }, + box_it, anchor_it, decoded_it); + } + else + { + execute_window_loop(win, [&](const Coordinates &) { const auto box_ptr = reinterpret_cast(box_it.ptr()); const auto anchor_ptr = reinterpret_cast(anchor_it.ptr()); box_centersize = BBox({ *box_ptr, *(box_ptr + 1), *(2 + box_ptr), *(3 + box_ptr) }); anchor = BBox({ *anchor_ptr, *(anchor_ptr + 1), *(2 + anchor_ptr), *(3 + anchor_ptr) }); - } - - // BBox is equavalent to CenterSizeEncoding [y,x,h,w] - const float y_center = box_centersize[0] / info.scale_value_y() * anchor[2] + anchor[0]; - const float x_center = box_centersize[1] / info.scale_value_x() * anchor[3] + anchor[1]; - const float half_h = half_factor * static_cast(std::exp(box_centersize[2] / info.scale_value_h())) * anchor[2]; - const float half_w = half_factor * static_cast(std::exp(box_centersize[3] / info.scale_value_w())) * anchor[3]; - - // Box Corner encoding boxes are saved as [xmin, ymin, xmax, ymax] - auto decoded_ptr = reinterpret_cast(decoded_it.ptr()); - *(decoded_ptr) = x_center - half_w; // xmin - *(1 + decoded_ptr) = y_center - half_h; // ymin - *(2 + decoded_ptr) = x_center + half_w; // xmax - *(3 + decoded_ptr) = y_center + half_h; // ymax - }, - box_it, anchor_it, decoded_it); + DecodeBoxCorner(box_centersize, anchor, decoded_it, info); + }, + box_it, anchor_it, decoded_it); + } } void SaveOutputs(const Tensor *decoded_boxes, const std::vector &result_idx_boxes_after_nms, const std::vector &result_scores_after_nms, const std::vector &result_classes_after_nms, @@ -263,12 +288,26 @@ void CPPDetectionPostProcessLayer::run() // Decode scores if necessary if(_dequantize_scores) { - for(unsigned int idx_c = 0; idx_c < _num_classes_with_background; ++idx_c) + if(_input_box_encoding->info()->data_type() == DataType::QASYMM8) { - for(unsigned int idx_b = 0; idx_b < _num_boxes; ++idx_b) + for(unsigned int idx_c = 0; idx_c < _num_classes_with_background; ++idx_c) { - *(reinterpret_cast(_decoded_scores.ptr_to_element(Coordinates(idx_c, idx_b)))) = - dequantize_qasymm8(*(reinterpret_cast(_input_scores->ptr_to_element(Coordinates(idx_c, idx_b)))), _input_scores->info()->quantization_info()); + for(unsigned int idx_b = 0; idx_b < _num_boxes; ++idx_b) + { + *(reinterpret_cast(_decoded_scores.ptr_to_element(Coordinates(idx_c, idx_b)))) = + dequantize_qasymm8(*(reinterpret_cast(_input_scores->ptr_to_element(Coordinates(idx_c, idx_b)))), _input_scores->info()->quantization_info()); + } + } + } + else if(_input_box_encoding->info()->data_type() == DataType::QASYMM8_SIGNED) + { + for(unsigned int idx_c = 0; idx_c < _num_classes_with_background; ++idx_c) + { + for(unsigned int idx_b = 0; idx_b < _num_boxes; ++idx_b) + { + *(reinterpret_cast(_decoded_scores.ptr_to_element(Coordinates(idx_c, idx_b)))) = + dequantize_qasymm8_signed(*(reinterpret_cast(_input_scores->ptr_to_element(Coordinates(idx_c, idx_b)))), _input_scores->info()->quantization_info()); + } } } } diff --git a/tests/validation/CPP/DetectionPostProcessLayer.cpp b/tests/validation/CPP/DetectionPostProcessLayer.cpp index f4528fb593..934ffea545 100644 --- a/tests/validation/CPP/DetectionPostProcessLayer.cpp +++ b/tests/validation/CPP/DetectionPostProcessLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019 ARM Limited. + * Copyright (c) 2019-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -46,52 +46,64 @@ inline void fill_tensor(U &&tensor, const std::vector &v) { std::memcpy(tensor.data(), v.data(), sizeof(T) * v.size()); } -template +template inline void quantize_and_fill_tensor(U &&tensor, const std::vector &v) { - QuantizationInfo qi = tensor.quantization_info(); - std::vector quantized; + QuantizationInfo qi = tensor.quantization_info(); + std::vector quantized; quantized.reserve(v.size()); for(auto elem : v) { - quantized.emplace_back(quantize_qasymm8(elem, qi)); + quantized.emplace_back(Qasymm8QuantizationHelper::quantize(elem, qi)); + } + std::memcpy(tensor.data(), quantized.data(), sizeof(D) * quantized.size()); +} +template +inline int calc_qinfo(const float min, const float max, float &scale) +{ + const auto qmin = std::numeric_limits::min(); + const auto qmax = std::numeric_limits::max(); + const float f_qmin = qmin; + const float f_qmax = qmax; + + scale = (max - min) / (f_qmax - f_qmin); + const float offset_from_min = f_qmin - min / scale; + const float offset_from_max = f_qmax - max / scale; + + const float offset_from_min_error = std::abs(f_qmin) + std::abs(min / scale); + const float offset_from_max_error = std::abs(f_qmax) + std::abs(max / scale); + const float f_offset = offset_from_min_error < offset_from_max_error ? offset_from_min : offset_from_max; + T tmp_offset; + if(f_offset < f_qmin) + { + tmp_offset = qmin; + } + else if(f_offset > f_qmax) + { + tmp_offset = qmax; + } + else + { + tmp_offset = static_cast(arm_compute::support::cpp11::round(f_offset)); } - std::memcpy(tensor.data(), quantized.data(), sizeof(uint8_t) * quantized.size()); + return static_cast(tmp_offset); } -inline QuantizationInfo qinfo_scaleoffset_from_minmax(const float min, const float max) +inline QuantizationInfo qinfo_scaleoffset_from_minmax(DataType data_type, const float min, const float max) { - int offset = 0; - float scale = 0; - const uint8_t qmin = std::numeric_limits::min(); - const uint8_t qmax = std::numeric_limits::max(); - const float f_qmin = qmin; - const float f_qmax = qmax; + int offset = 0; + float scale = 0; // Continue only if [min,max] is a valid range and not a point if(min != max) { - scale = (max - min) / (f_qmax - f_qmin); - const float offset_from_min = f_qmin - min / scale; - const float offset_from_max = f_qmax - max / scale; - - const float offset_from_min_error = std::abs(f_qmin) + std::abs(min / scale); - const float offset_from_max_error = std::abs(f_qmax) + std::abs(max / scale); - const float f_offset = offset_from_min_error < offset_from_max_error ? offset_from_min : offset_from_max; - - uint8_t uint8_offset = 0; - if(f_offset < f_qmin) - { - uint8_offset = qmin; - } - else if(f_offset > f_qmax) + if(data_type == DataType::QASYMM8_SIGNED) { - uint8_offset = qmax; + offset = calc_qinfo(min, max, scale); } else { - uint8_offset = static_cast(arm_compute::support::cpp11::round(f_offset)); + offset = calc_qinfo(min, max, scale); } - offset = uint8_offset; } return QuantizationInfo(scale, offset); } @@ -100,9 +112,9 @@ inline void base_test_case(DetectionPostProcessLayerInfo info, DataType data_typ const SimpleTensor &expected_output_classes, const SimpleTensor &expected_output_scores, const SimpleTensor &expected_num_detection, AbsoluteTolerance tolerance_boxes = AbsoluteTolerance(0.1f), AbsoluteTolerance tolerance_others = AbsoluteTolerance(0.1f)) { - Tensor box_encoding = create_tensor(TensorShape(4U, 6U, 1U), data_type, 1, qinfo_scaleoffset_from_minmax(-1.0f, 1.0f)); - Tensor class_prediction = create_tensor(TensorShape(3U, 6U, 1U), data_type, 1, qinfo_scaleoffset_from_minmax(0.0f, 1.0f)); - Tensor anchors = create_tensor(TensorShape(4U, 6U), data_type, 1, qinfo_scaleoffset_from_minmax(0.0f, 100.5f)); + Tensor box_encoding = create_tensor(TensorShape(4U, 6U, 1U), data_type, 1, qinfo_scaleoffset_from_minmax(data_type, -1.0f, 1.0f)); + Tensor class_prediction = create_tensor(TensorShape(3U, 6U, 1U), data_type, 1, qinfo_scaleoffset_from_minmax(data_type, 0.0f, 1.0f)); + Tensor anchors = create_tensor(TensorShape(4U, 6U), data_type, 1, qinfo_scaleoffset_from_minmax(data_type, 0.0f, 100.5f)); box_encoding.allocator()->allocate(); class_prediction.allocator()->allocate(); @@ -137,17 +149,31 @@ inline void base_test_case(DetectionPostProcessLayerInfo info, DataType data_typ }; // Fill the tensors with random pre-generated values - if(data_type == DataType::F32) + switch(data_type) { - fill_tensor(Accessor(box_encoding), box_encoding_vector); - fill_tensor(Accessor(class_prediction), class_prediction_vector); - fill_tensor(Accessor(anchors), anchors_vector); - } - else - { - quantize_and_fill_tensor(Accessor(box_encoding), box_encoding_vector); - quantize_and_fill_tensor(Accessor(class_prediction), class_prediction_vector); - quantize_and_fill_tensor(Accessor(anchors), anchors_vector); + case DataType::F32: + { + fill_tensor(Accessor(box_encoding), box_encoding_vector); + fill_tensor(Accessor(class_prediction), class_prediction_vector); + fill_tensor(Accessor(anchors), anchors_vector); + } + break; + case DataType::QASYMM8: + { + quantize_and_fill_tensor(Accessor(box_encoding), box_encoding_vector); + quantize_and_fill_tensor(Accessor(class_prediction), class_prediction_vector); + quantize_and_fill_tensor(Accessor(anchors), anchors_vector); + } + break; + case DataType::QASYMM8_SIGNED: + { + quantize_and_fill_tensor(Accessor(box_encoding), box_encoding_vector); + quantize_and_fill_tensor(Accessor(class_prediction), class_prediction_vector); + quantize_and_fill_tensor(Accessor(anchors), anchors_vector); + } + break; + default: + return; } // Determine the output through the CPP kernel @@ -189,19 +215,22 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zi TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::S8), // Unsupported data type TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32), // Wrong Detection Info TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32), // Wrong boxes dimensions - TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8)}), // Wrong score dimension + TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8), // Wrong score dimension + TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8_SIGNED)}), // Wrong score dimension framework::dataset::make("ClassPredsInfo",{ TensorInfo(TensorShape(3U ,10U), 1, DataType::F32), TensorInfo(TensorShape(3U ,10U), 1, DataType::F32), TensorInfo(TensorShape(3U ,10U), 1, DataType::F32), TensorInfo(TensorShape(3U ,10U), 1, DataType::F32), TensorInfo(TensorShape(3U ,10U), 1, DataType::F32), - TensorInfo(TensorShape(3U ,10U), 1, DataType::QASYMM8)})), + TensorInfo(TensorShape(3U ,10U), 1, DataType::QASYMM8), + TensorInfo(TensorShape(3U ,10U), 1, DataType::QASYMM8_SIGNED)})), framework::dataset::make("AnchorsInfo",{ TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32), TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32), TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32), TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32), TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::F32), - TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8)})), + TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8), + TensorInfo(TensorShape(4U, 10U, 1U), 1, DataType::QASYMM8_SIGNED)})), framework::dataset::make("OutputBoxInfo", { TensorInfo(TensorShape(4U, 3U, 1U), 1, DataType::F32), TensorInfo(TensorShape(4U, 3U, 1U), 1, DataType::F32), TensorInfo(TensorShape(4U, 3U, 1U), 1, DataType::S8), @@ -383,6 +412,75 @@ TEST_CASE(Quantized_regular, framework::DatasetMode::ALL) TEST_SUITE_END() // QASYMM8 +TEST_SUITE(QASYMM8_SIGNED) +TEST_CASE(Quantized_general, framework::DatasetMode::ALL) +{ + DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo(3 /*max_detections*/, 1 /*max_classes_per_detection*/, 0.0 /*nms_score_threshold*/, + 0.5 /*nms_iou_threshold*/, 2 /*num_classes*/, { 11.0, 11.0, 6.0, 6.0 } /*scale*/); + + // Fill expected detection boxes + SimpleTensor expected_output_boxes(TensorShape(4U, 3U), DataType::F32); + fill_tensor(expected_output_boxes, std::vector { -0.15, 9.85, 0.95, 10.95, -0.15, -0.15, 0.95, 0.95, -0.15, 99.85, 0.95, 100.95 }); + // Fill expected detection classes + SimpleTensor expected_output_classes(TensorShape(3U), DataType::F32); + fill_tensor(expected_output_classes, std::vector { 1.0f, 0.0f, 0.0f }); + // Fill expected detection scores + SimpleTensor expected_output_scores(TensorShape(3U), DataType::F32); + fill_tensor(expected_output_scores, std::vector { 0.97f, 0.95f, 0.31f }); + // Fill expected num detections + SimpleTensor expected_num_detection(TensorShape(1U), DataType::F32); + fill_tensor(expected_num_detection, std::vector { 3.f }); + // Run test + base_test_case(info, DataType::QASYMM8_SIGNED, expected_output_boxes, expected_output_classes, expected_output_scores, expected_num_detection, AbsoluteTolerance(0.3f)); +} + +TEST_CASE(Quantized_fast, framework::DatasetMode::ALL) +{ + DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo(3 /*max_detections*/, 1 /*max_classes_per_detection*/, 0.0 /*nms_score_threshold*/, + 0.5 /*nms_iou_threshold*/, 2 /*num_classes*/, { 11.0, 11.0, 6.0, 6.0 } /*scale*/, + false /*use_regular_nms*/, 1 /*detections_per_class*/); + + // Fill expected detection boxes + SimpleTensor expected_output_boxes(TensorShape(4U, 3U), DataType::F32); + fill_tensor(expected_output_boxes, std::vector { -0.15, 9.85, 0.95, 10.95, -0.15, -0.15, 0.95, 0.95, -0.15, 99.85, 0.95, 100.95 }); + // Fill expected detection classes + SimpleTensor expected_output_classes(TensorShape(3U), DataType::F32); + fill_tensor(expected_output_classes, std::vector { 1.0f, 0.0f, 0.0f }); + // Fill expected detection scores + SimpleTensor expected_output_scores(TensorShape(3U), DataType::F32); + fill_tensor(expected_output_scores, std::vector { 0.97f, 0.95f, 0.31f }); + // Fill expected num detections + SimpleTensor expected_num_detection(TensorShape(1U), DataType::F32); + fill_tensor(expected_num_detection, std::vector { 3.f }); + + // Run base test + base_test_case(info, DataType::QASYMM8_SIGNED, expected_output_boxes, expected_output_classes, expected_output_scores, expected_num_detection, AbsoluteTolerance(0.3f)); +} + +TEST_CASE(Quantized_regular, framework::DatasetMode::ALL) +{ + DetectionPostProcessLayerInfo info = DetectionPostProcessLayerInfo(3 /*max_detections*/, 1 /*max_classes_per_detection*/, 0.0 /*nms_score_threshold*/, + 0.5 /*nms_iou_threshold*/, 2 /*num_classes*/, { 11.0, 11.0, 6.0, 6.0 } /*scale*/, + true /*use_regular_nms*/, 1 /*detections_per_class*/); + // Fill expected detection boxes + SimpleTensor expected_output_boxes(TensorShape(4U, 3U), DataType::F32); + fill_tensor(expected_output_boxes, std::vector { -0.15, 9.85, 0.95, 10.95, -0.15, 9.85, 0.95, 10.95, 0.0f, 0.0f, 0.0f, 0.0f }); + // Fill expected detection classes + SimpleTensor expected_output_classes(TensorShape(3U), DataType::F32); + fill_tensor(expected_output_classes, std::vector { 1.0f, 0.0f, 0.0f }); + // Fill expected detection scores + SimpleTensor expected_output_scores(TensorShape(3U), DataType::F32); + fill_tensor(expected_output_scores, std::vector { 0.95f, 0.91f, 0.0f }); + // Fill expected num detections + SimpleTensor expected_num_detection(TensorShape(1U), DataType::F32); + fill_tensor(expected_num_detection, std::vector { 2.f }); + + // Run test + base_test_case(info, DataType::QASYMM8_SIGNED, expected_output_boxes, expected_output_classes, expected_output_scores, expected_num_detection, AbsoluteTolerance(0.3f)); +} + +TEST_SUITE_END() // QASYMM8_SIGNED + TEST_SUITE_END() // DetectionPostProcessLayer TEST_SUITE_END() // CPP } // namespace validation -- cgit v1.2.1