aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/Helpers.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/Helpers.h')
-rw-r--r--tests/validation/Helpers.h28
1 files changed, 28 insertions, 0 deletions
diff --git a/tests/validation/Helpers.h b/tests/validation/Helpers.h
index a551da731e..cae1976bd6 100644
--- a/tests/validation/Helpers.h
+++ b/tests/validation/Helpers.h
@@ -172,6 +172,34 @@ inline void fill_mask_from_pattern(uint8_t *mask, int cols, int rows, MatrixPatt
}
}
+/** Calculate output tensor shape give a vector of input tensor to concatenate
+ *
+ * @param[in] input_shapes Shapes of the tensors to concatenate across depth.
+ *
+ * @return The shape of output concatenated tensor.
+ */
+inline TensorShape calculate_depth_concatenate_shape(std::vector<TensorShape> input_shapes)
+{
+ TensorShape out_shape = input_shapes.at(0);
+
+ unsigned int max_x = 0;
+ unsigned int max_y = 0;
+ unsigned int depth = 0;
+
+ for(auto const &shape : input_shapes)
+ {
+ max_x = std::max<unsigned int>(shape.x(), max_x);
+ max_y = std::max<unsigned int>(shape.y(), max_y);
+ depth += shape.z();
+ }
+
+ out_shape.set(0, max_x);
+ out_shape.set(1, max_y);
+ out_shape.set(2, depth);
+
+ return out_shape;
+}
+
/** Create a vector of random ROIs.
*
* @param[in] shape The shape of the input tensor.