aboutsummaryrefslogtreecommitdiff
path: root/tests/SimpleTensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r--tests/SimpleTensor.h22
1 files changed, 18 insertions, 4 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h
index f3155ffeab..f9e49bc85b 100644
--- a/tests/SimpleTensor.h
+++ b/tests/SimpleTensor.h
@@ -77,10 +77,13 @@ public:
* @param[in] num_channels (Optional) Number of channels (default = 1).
* @param[in] fixed_point_position (Optional) Number of bits for the fractional part of the fixed point numbers (default = 0).
* @param[in] quantization_info (Optional) Quantization info for asymmetric quantization (default = empty).
+ * @param[in] data_layout (Optional) Data layout of the tensor (default = NCHW).
*/
SimpleTensor(TensorShape shape, DataType data_type,
- int num_channels = 1,
- int fixed_point_position = 0, QuantizationInfo quantization_info = QuantizationInfo());
+ int num_channels = 1,
+ int fixed_point_position = 0,
+ QuantizationInfo quantization_info = QuantizationInfo(),
+ DataLayout data_layout = DataLayout::NCHW);
/** Create a deep copy of the given @p tensor.
*
@@ -122,6 +125,9 @@ public:
/** Total size of the tensor in bytes. */
size_t size() const override;
+ /** Data layout of the tensor. */
+ DataLayout data_layout() const override;
+
/** Image format of the tensor. */
Format format() const override;
@@ -181,6 +187,7 @@ protected:
int _num_channels{ 0 };
int _fixed_point_position{ 0 };
QuantizationInfo _quantization_info{};
+ DataLayout _data_layout{ DataLayout::UNKNOWN };
};
template <typename T>
@@ -196,13 +203,14 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, Format format, int fixed_point_
}
template <typename T>
-SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position, QuantizationInfo quantization_info)
+SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position, QuantizationInfo quantization_info, DataLayout data_layout)
: _buffer(nullptr),
_shape(shape),
_data_type(data_type),
_num_channels(num_channels),
_fixed_point_position(fixed_point_position),
- _quantization_info(quantization_info)
+ _quantization_info(quantization_info),
+ _data_layout(data_layout)
{
_buffer = support::cpp14::make_unique<T[]>(num_elements() * this->num_channels());
}
@@ -279,6 +287,12 @@ Format SimpleTensor<T>::format() const
}
template <typename T>
+DataLayout SimpleTensor<T>::data_layout() const
+{
+ return _data_layout;
+}
+
+template <typename T>
DataType SimpleTensor<T>::data_type() const
{
if(_format != Format::UNKNOWN)