aboutsummaryrefslogtreecommitdiff
path: root/tests/RawTensor.h
diff options
context:
space:
mode:
authorMoritz Pflanzer <moritz.pflanzer@arm.com>2017-09-08 09:53:14 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitcde1e8adeacea5c33a1682ef7b05a0ef643463b8 (patch)
tree47e58abdf5bb6ef39db362a2ac777c93b3f76666 /tests/RawTensor.h
parent86b53339679e12c952a24a8845a5409ac3d52de6 (diff)
downloadComputeLibrary-cde1e8adeacea5c33a1682ef7b05a0ef643463b8.tar.gz
COMPMID-415: Add tests for ConvolutionLayer reshaped weights
Change-Id: I6c1209a2afafccba2cbdbcda16aceb3ae0cc7b4b Reviewed-on: http://mpd-gerrit.cambridge.arm.com/87000 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests/RawTensor.h')
-rw-r--r--tests/RawTensor.h38
1 files changed, 38 insertions, 0 deletions
diff --git a/tests/RawTensor.h b/tests/RawTensor.h
index fd0ab2b9fd..116275d617 100644
--- a/tests/RawTensor.h
+++ b/tests/RawTensor.h
@@ -55,6 +55,44 @@ public:
*/
RawTensor(TensorShape shape, DataType data_type, int num_channels = 1, int fixed_point_position = 0);
+ /** Conversion constructor from SimpleTensor.
+ *
+ * The passed SimpleTensor will be destroyed after it has been converted to
+ * a RawTensor.
+ *
+ * @param[in,out] tensor SimpleTensor to be converted to a RawTensor.
+ */
+ template <typename T>
+ RawTensor(SimpleTensor<T> &&tensor)
+ {
+ _buffer = std::unique_ptr<uint8_t[]>(reinterpret_cast<uint8_t *>(tensor._buffer.release()));
+ _shape = std::move(tensor._shape);
+ _format = tensor._format;
+ _data_type = tensor._data_type;
+ _num_channels = tensor._num_channels;
+ _fixed_point_position = tensor._fixed_point_position;
+ }
+
+ /** Conversion operator to SimpleTensor.
+ *
+ * The current RawTensor must not be used after the conversion.
+ *
+ * @return SimpleTensor of the given type.
+ */
+ template <typename T>
+ operator SimpleTensor<T>()
+ {
+ SimpleTensor<T> cast;
+ cast._buffer = std::unique_ptr<T[]>(reinterpret_cast<T *>(_buffer.release()));
+ cast._shape = std::move(_shape);
+ cast._format = _format;
+ cast._data_type = _data_type;
+ cast._num_channels = _num_channels;
+ cast._fixed_point_position = _fixed_point_position;
+
+ return cast;
+ }
+
/** Create a deep copy of the given @p tensor.
*
* @param[in] tensor To be copied tensor.