From 4aff98fcfd3c736115f3983dc448c3280e570841 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Wed, 28 Aug 2019 16:27:26 +0100 Subject: COMPMID-2247: Extend support of CLBoundingBoxTransform for QUANT16_ASYMM Change-Id: I8af7a382c0bccf55cf7f4a64f46ce9e6cd965afe Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/1833 Comments-Addressed: Arm Jenkins Reviewed-by: Pablo Marquez Tested-by: Arm Jenkins --- .../validation/reference/BoundingBoxTransform.cpp | 40 ++++++++++++++-------- 1 file changed, 25 insertions(+), 15 deletions(-) (limited to 'tests/validation/reference/BoundingBoxTransform.cpp') diff --git a/tests/validation/reference/BoundingBoxTransform.cpp b/tests/validation/reference/BoundingBoxTransform.cpp index 55dd165b51..e09bcff1c6 100644 --- a/tests/validation/reference/BoundingBoxTransform.cpp +++ b/tests/validation/reference/BoundingBoxTransform.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -36,16 +36,16 @@ namespace validation { namespace reference { -template -SimpleTensor bounding_box_transform(const SimpleTensor &boxes, const SimpleTensor &deltas, const BoundingBoxTransformInfo &info) +template +SimpleTensor bounding_box_transform(const SimpleTensor &boxes, const SimpleTensor &deltas, const BoundingBoxTransformInfo &info) { - const DataType boxes_data_type = deltas.data_type(); + const DataType boxes_data_type = boxes.data_type(); SimpleTensor pred_boxes(deltas.shape(), boxes_data_type); - const size_t num_classes = deltas.shape()[0] / 4; - const size_t num_boxes = deltas.shape()[1]; - const T *deltas_ptr = deltas.data(); - T *pred_boxes_ptr = pred_boxes.data(); + const size_t num_classes = deltas.shape()[0] / 4; + const size_t num_boxes = deltas.shape()[1]; + const TDeltas *deltas_ptr = deltas.data(); + T *pred_boxes_ptr = pred_boxes.data(); const int img_h = floor(info.img_height() / info.scale() + 0.5f); const int img_w = floor(info.img_width() / info.scale() + 0.5f); @@ -70,15 +70,15 @@ SimpleTensor bounding_box_transform(const SimpleTensor &boxes, const Simpl for(size_t j = 0; j < num_classes; ++j) { // Extract deltas - const size_t start_delta = i * num_classes * class_fields + class_fields * j; - const T dx = deltas_ptr[start_delta] / T(info.weights()[0]); - const T dy = deltas_ptr[start_delta + 1] / T(info.weights()[1]); - T dw = deltas_ptr[start_delta + 2] / T(info.weights()[2]); - T dh = deltas_ptr[start_delta + 3] / T(info.weights()[3]); + const size_t start_delta = i * num_classes * class_fields + class_fields * j; + const TDeltas dx = deltas_ptr[start_delta] / TDeltas(info.weights()[0]); + const TDeltas dy = deltas_ptr[start_delta + 1] / TDeltas(info.weights()[1]); + TDeltas dw = deltas_ptr[start_delta + 2] / TDeltas(info.weights()[2]); + TDeltas dh = deltas_ptr[start_delta + 3] / TDeltas(info.weights()[3]); // Clip dw and dh - dw = std::min(dw, T(info.bbox_xform_clip())); - dh = std::min(dh, T(info.bbox_xform_clip())); + dw = std::min(dw, TDeltas(info.bbox_xform_clip())); + dh = std::min(dh, TDeltas(info.bbox_xform_clip())); // Determine the predictions const T pred_ctr_x = dx * width + ctr_x; @@ -98,6 +98,16 @@ SimpleTensor bounding_box_transform(const SimpleTensor &boxes, const Simpl template SimpleTensor bounding_box_transform(const SimpleTensor &boxes, const SimpleTensor &deltas, const BoundingBoxTransformInfo &info); template SimpleTensor bounding_box_transform(const SimpleTensor &boxes, const SimpleTensor &deltas, const BoundingBoxTransformInfo &info); + +template <> +SimpleTensor bounding_box_transform(const SimpleTensor &boxes, const SimpleTensor &deltas, const BoundingBoxTransformInfo &info) +{ + SimpleTensor boxes_tmp = convert_from_asymmetric(boxes); + SimpleTensor deltas_tmp = convert_from_asymmetric(deltas); + SimpleTensor pred_boxes_tmp = bounding_box_transform(boxes_tmp, deltas_tmp, info); + SimpleTensor pred_boxes = convert_to_asymmetric(pred_boxes_tmp, boxes.quantization_info()); + return pred_boxes; +} } // namespace reference } // namespace validation } // namespace test -- cgit v1.2.1