aboutsummaryrefslogtreecommitdiff
path: root/tests/RawTensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/RawTensor.h')
-rw-r--r--tests/RawTensor.h38
1 files changed, 38 insertions, 0 deletions
diff --git a/tests/RawTensor.h b/tests/RawTensor.h
index fd0ab2b9fd..116275d617 100644
--- a/tests/RawTensor.h
+++ b/tests/RawTensor.h
@@ -55,6 +55,44 @@ public:
*/
RawTensor(TensorShape shape, DataType data_type, int num_channels = 1, int fixed_point_position = 0);
+ /** Conversion constructor from SimpleTensor.
+ *
+ * The passed SimpleTensor will be destroyed after it has been converted to
+ * a RawTensor.
+ *
+ * @param[in,out] tensor SimpleTensor to be converted to a RawTensor.
+ */
+ template <typename T>
+ RawTensor(SimpleTensor<T> &&tensor)
+ {
+ _buffer = std::unique_ptr<uint8_t[]>(reinterpret_cast<uint8_t *>(tensor._buffer.release()));
+ _shape = std::move(tensor._shape);
+ _format = tensor._format;
+ _data_type = tensor._data_type;
+ _num_channels = tensor._num_channels;
+ _fixed_point_position = tensor._fixed_point_position;
+ }
+
+ /** Conversion operator to SimpleTensor.
+ *
+ * The current RawTensor must not be used after the conversion.
+ *
+ * @return SimpleTensor of the given type.
+ */
+ template <typename T>
+ operator SimpleTensor<T>()
+ {
+ SimpleTensor<T> cast;
+ cast._buffer = std::unique_ptr<T[]>(reinterpret_cast<T *>(_buffer.release()));
+ cast._shape = std::move(_shape);
+ cast._format = _format;
+ cast._data_type = _data_type;
+ cast._num_channels = _num_channels;
+ cast._fixed_point_position = _fixed_point_position;
+
+ return cast;
+ }
+
/** Create a deep copy of the given @p tensor.
*
* @param[in] tensor To be copied tensor.