diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2018-03-02 11:21:38 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:48:33 +0000 |
commit | 4a65b9855f71fff11a4c18d2fa4bccc74303e5c6 (patch) | |
tree | 13b62f158e57913a55cd9bf409c77b7b4b3e71f2 /tests/SimpleTensor.h | |
parent | 898db6f31f015d077fb87ef82e22abea74a8f710 (diff) | |
download | ComputeLibrary-4a65b9855f71fff11a4c18d2fa4bccc74303e5c6.tar.gz |
COMPMID-991: Add data layout information to the test framework
Change-Id: Iccf66f7476e697e8fdee5a7441cc06b936bbba09
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/122986
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r-- | tests/SimpleTensor.h | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h index f3155ffeab..f9e49bc85b 100644 --- a/tests/SimpleTensor.h +++ b/tests/SimpleTensor.h @@ -77,10 +77,13 @@ public: * @param[in] num_channels (Optional) Number of channels (default = 1). * @param[in] fixed_point_position (Optional) Number of bits for the fractional part of the fixed point numbers (default = 0). * @param[in] quantization_info (Optional) Quantization info for asymmetric quantization (default = empty). + * @param[in] data_layout (Optional) Data layout of the tensor (default = NCHW). */ SimpleTensor(TensorShape shape, DataType data_type, - int num_channels = 1, - int fixed_point_position = 0, QuantizationInfo quantization_info = QuantizationInfo()); + int num_channels = 1, + int fixed_point_position = 0, + QuantizationInfo quantization_info = QuantizationInfo(), + DataLayout data_layout = DataLayout::NCHW); /** Create a deep copy of the given @p tensor. * @@ -122,6 +125,9 @@ public: /** Total size of the tensor in bytes. */ size_t size() const override; + /** Data layout of the tensor. */ + DataLayout data_layout() const override; + /** Image format of the tensor. */ Format format() const override; @@ -181,6 +187,7 @@ protected: int _num_channels{ 0 }; int _fixed_point_position{ 0 }; QuantizationInfo _quantization_info{}; + DataLayout _data_layout{ DataLayout::UNKNOWN }; }; template <typename T> @@ -196,13 +203,14 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, Format format, int fixed_point_ } template <typename T> -SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position, QuantizationInfo quantization_info) +SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position, QuantizationInfo quantization_info, DataLayout data_layout) : _buffer(nullptr), _shape(shape), _data_type(data_type), _num_channels(num_channels), _fixed_point_position(fixed_point_position), - _quantization_info(quantization_info) + _quantization_info(quantization_info), + _data_layout(data_layout) { _buffer = support::cpp14::make_unique<T[]>(num_elements() * this->num_channels()); } @@ -279,6 +287,12 @@ Format SimpleTensor<T>::format() const } template <typename T> +DataLayout SimpleTensor<T>::data_layout() const +{ + return _data_layout; +} + +template <typename T> DataType SimpleTensor<T>::data_type() const { if(_format != Format::UNKNOWN) |