diff options
Diffstat (limited to 'tests/RawTensor.h')
-rw-r--r-- | tests/RawTensor.h | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/tests/RawTensor.h b/tests/RawTensor.h index fd0ab2b9fd..116275d617 100644 --- a/tests/RawTensor.h +++ b/tests/RawTensor.h @@ -55,6 +55,44 @@ public: */ RawTensor(TensorShape shape, DataType data_type, int num_channels = 1, int fixed_point_position = 0); + /** Conversion constructor from SimpleTensor. + * + * The passed SimpleTensor will be destroyed after it has been converted to + * a RawTensor. + * + * @param[in,out] tensor SimpleTensor to be converted to a RawTensor. + */ + template <typename T> + RawTensor(SimpleTensor<T> &&tensor) + { + _buffer = std::unique_ptr<uint8_t[]>(reinterpret_cast<uint8_t *>(tensor._buffer.release())); + _shape = std::move(tensor._shape); + _format = tensor._format; + _data_type = tensor._data_type; + _num_channels = tensor._num_channels; + _fixed_point_position = tensor._fixed_point_position; + } + + /** Conversion operator to SimpleTensor. + * + * The current RawTensor must not be used after the conversion. + * + * @return SimpleTensor of the given type. + */ + template <typename T> + operator SimpleTensor<T>() + { + SimpleTensor<T> cast; + cast._buffer = std::unique_ptr<T[]>(reinterpret_cast<T *>(_buffer.release())); + cast._shape = std::move(_shape); + cast._format = _format; + cast._data_type = _data_type; + cast._num_channels = _num_channels; + cast._fixed_point_position = _fixed_point_position; + + return cast; + } + /** Create a deep copy of the given @p tensor. * * @param[in] tensor To be copied tensor. |