aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/BoundingBoxTransform.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/BoundingBoxTransform.cpp')
-rw-r--r--tests/validation/reference/BoundingBoxTransform.cpp40
1 files changed, 25 insertions, 15 deletions
diff --git a/tests/validation/reference/BoundingBoxTransform.cpp b/tests/validation/reference/BoundingBoxTransform.cpp
index 55dd165b51..e09bcff1c6 100644
--- a/tests/validation/reference/BoundingBoxTransform.cpp
+++ b/tests/validation/reference/BoundingBoxTransform.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,16 +36,16 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> bounding_box_transform(const SimpleTensor<T> &boxes, const SimpleTensor<T> &deltas, const BoundingBoxTransformInfo &info)
+template <typename T, typename TDeltas>
+SimpleTensor<T> bounding_box_transform(const SimpleTensor<T> &boxes, const SimpleTensor<TDeltas> &deltas, const BoundingBoxTransformInfo &info)
{
- const DataType boxes_data_type = deltas.data_type();
+ const DataType boxes_data_type = boxes.data_type();
SimpleTensor<T> pred_boxes(deltas.shape(), boxes_data_type);
- const size_t num_classes = deltas.shape()[0] / 4;
- const size_t num_boxes = deltas.shape()[1];
- const T *deltas_ptr = deltas.data();
- T *pred_boxes_ptr = pred_boxes.data();
+ const size_t num_classes = deltas.shape()[0] / 4;
+ const size_t num_boxes = deltas.shape()[1];
+ const TDeltas *deltas_ptr = deltas.data();
+ T *pred_boxes_ptr = pred_boxes.data();
const int img_h = floor(info.img_height() / info.scale() + 0.5f);
const int img_w = floor(info.img_width() / info.scale() + 0.5f);
@@ -70,15 +70,15 @@ SimpleTensor<T> bounding_box_transform(const SimpleTensor<T> &boxes, const Simpl
for(size_t j = 0; j < num_classes; ++j)
{
// Extract deltas
- const size_t start_delta = i * num_classes * class_fields + class_fields * j;
- const T dx = deltas_ptr[start_delta] / T(info.weights()[0]);
- const T dy = deltas_ptr[start_delta + 1] / T(info.weights()[1]);
- T dw = deltas_ptr[start_delta + 2] / T(info.weights()[2]);
- T dh = deltas_ptr[start_delta + 3] / T(info.weights()[3]);
+ const size_t start_delta = i * num_classes * class_fields + class_fields * j;
+ const TDeltas dx = deltas_ptr[start_delta] / TDeltas(info.weights()[0]);
+ const TDeltas dy = deltas_ptr[start_delta + 1] / TDeltas(info.weights()[1]);
+ TDeltas dw = deltas_ptr[start_delta + 2] / TDeltas(info.weights()[2]);
+ TDeltas dh = deltas_ptr[start_delta + 3] / TDeltas(info.weights()[3]);
// Clip dw and dh
- dw = std::min(dw, T(info.bbox_xform_clip()));
- dh = std::min(dh, T(info.bbox_xform_clip()));
+ dw = std::min(dw, TDeltas(info.bbox_xform_clip()));
+ dh = std::min(dh, TDeltas(info.bbox_xform_clip()));
// Determine the predictions
const T pred_ctr_x = dx * width + ctr_x;
@@ -98,6 +98,16 @@ SimpleTensor<T> bounding_box_transform(const SimpleTensor<T> &boxes, const Simpl
template SimpleTensor<float> bounding_box_transform(const SimpleTensor<float> &boxes, const SimpleTensor<float> &deltas, const BoundingBoxTransformInfo &info);
template SimpleTensor<half> bounding_box_transform(const SimpleTensor<half> &boxes, const SimpleTensor<half> &deltas, const BoundingBoxTransformInfo &info);
+
+template <>
+SimpleTensor<uint16_t> bounding_box_transform(const SimpleTensor<uint16_t> &boxes, const SimpleTensor<uint8_t> &deltas, const BoundingBoxTransformInfo &info)
+{
+ SimpleTensor<float> boxes_tmp = convert_from_asymmetric(boxes);
+ SimpleTensor<float> deltas_tmp = convert_from_asymmetric(deltas);
+ SimpleTensor<float> pred_boxes_tmp = bounding_box_transform<float, float>(boxes_tmp, deltas_tmp, info);
+ SimpleTensor<uint16_t> pred_boxes = convert_to_asymmetric<uint16_t>(pred_boxes_tmp, boxes.quantization_info());
+ return pred_boxes;
+}
} // namespace reference
} // namespace validation
} // namespace test