diff options
-rw-r--r-- | tests/CL/CLAccessor.h | 8 | ||||
-rw-r--r-- | tests/GLES_COMPUTE/GCAccessor.h | 8 | ||||
-rw-r--r-- | tests/IAccessor.h | 5 | ||||
-rw-r--r-- | tests/NEON/Accessor.h | 8 | ||||
-rw-r--r-- | tests/RawTensor.h | 4 | ||||
-rw-r--r-- | tests/SimpleTensor.h | 22 | ||||
-rw-r--r-- | utils/TypePrinter.h | 28 |
7 files changed, 74 insertions, 9 deletions
diff --git a/tests/CL/CLAccessor.h b/tests/CL/CLAccessor.h index 9e7b73f34f..f2e13f1232 100644 --- a/tests/CL/CLAccessor.h +++ b/tests/CL/CLAccessor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -56,6 +56,7 @@ public: size_t element_size() const override; size_t size() const override; Format format() const override; + DataLayout data_layout() const override; DataType data_type() const override; int num_channels() const override; int num_elements() const override; @@ -102,6 +103,11 @@ inline Format CLAccessor::format() const return _tensor.info()->format(); } +inline DataLayout CLAccessor::data_layout() const +{ + return _tensor.info()->data_layout(); +} + inline DataType CLAccessor::data_type() const { return _tensor.info()->data_type(); diff --git a/tests/GLES_COMPUTE/GCAccessor.h b/tests/GLES_COMPUTE/GCAccessor.h index 0f7c491c3c..ccf4caabaf 100644 --- a/tests/GLES_COMPUTE/GCAccessor.h +++ b/tests/GLES_COMPUTE/GCAccessor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -56,6 +56,7 @@ public: size_t element_size() const override; size_t size() const override; Format format() const override; + DataLayout data_layout() const override; DataType data_type() const override; int num_channels() const override; int num_elements() const override; @@ -100,6 +101,11 @@ inline Format GCAccessor::format() const return _tensor.info()->format(); } +inline DataLayout GCAccessor::data_layout() const +{ + return _tensor.info()->data_layout(); +} + inline DataType GCAccessor::data_type() const { return _tensor.info()->data_type(); diff --git a/tests/IAccessor.h b/tests/IAccessor.h index 3744fc8c02..6170bc0ba1 100644 --- a/tests/IAccessor.h +++ b/tests/IAccessor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -52,6 +52,9 @@ public: /** Image format of the tensor. */ virtual Format format() const = 0; + /** Data layout of the tensor. */ + virtual DataLayout data_layout() const = 0; + /** Data type of the tensor. */ virtual DataType data_type() const = 0; diff --git a/tests/NEON/Accessor.h b/tests/NEON/Accessor.h index 2bad53b3fe..e285f227de 100644 --- a/tests/NEON/Accessor.h +++ b/tests/NEON/Accessor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -50,6 +50,7 @@ public: size_t element_size() const override; size_t size() const override; Format format() const override; + DataLayout data_layout() const override; DataType data_type() const override; int num_channels() const override; int num_elements() const override; @@ -90,6 +91,11 @@ inline Format Accessor::format() const return _tensor.info()->format(); } +inline DataLayout Accessor::data_layout() const +{ + return _tensor.info()->data_layout(); +} + inline DataType Accessor::data_type() const { return _tensor.info()->data_type(); diff --git a/tests/RawTensor.h b/tests/RawTensor.h index 116275d617..6b1b904e13 100644 --- a/tests/RawTensor.h +++ b/tests/RawTensor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -71,6 +71,7 @@ public: _data_type = tensor._data_type; _num_channels = tensor._num_channels; _fixed_point_position = tensor._fixed_point_position; + _data_layout = tensor._data_layout; } /** Conversion operator to SimpleTensor. @@ -89,6 +90,7 @@ public: cast._data_type = _data_type; cast._num_channels = _num_channels; cast._fixed_point_position = _fixed_point_position; + cast._data_layout = _data_layout; return cast; } diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h index f3155ffeab..f9e49bc85b 100644 --- a/tests/SimpleTensor.h +++ b/tests/SimpleTensor.h @@ -77,10 +77,13 @@ public: * @param[in] num_channels (Optional) Number of channels (default = 1). * @param[in] fixed_point_position (Optional) Number of bits for the fractional part of the fixed point numbers (default = 0). * @param[in] quantization_info (Optional) Quantization info for asymmetric quantization (default = empty). + * @param[in] data_layout (Optional) Data layout of the tensor (default = NCHW). */ SimpleTensor(TensorShape shape, DataType data_type, - int num_channels = 1, - int fixed_point_position = 0, QuantizationInfo quantization_info = QuantizationInfo()); + int num_channels = 1, + int fixed_point_position = 0, + QuantizationInfo quantization_info = QuantizationInfo(), + DataLayout data_layout = DataLayout::NCHW); /** Create a deep copy of the given @p tensor. * @@ -122,6 +125,9 @@ public: /** Total size of the tensor in bytes. */ size_t size() const override; + /** Data layout of the tensor. */ + DataLayout data_layout() const override; + /** Image format of the tensor. */ Format format() const override; @@ -181,6 +187,7 @@ protected: int _num_channels{ 0 }; int _fixed_point_position{ 0 }; QuantizationInfo _quantization_info{}; + DataLayout _data_layout{ DataLayout::UNKNOWN }; }; template <typename T> @@ -196,13 +203,14 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, Format format, int fixed_point_ } template <typename T> -SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position, QuantizationInfo quantization_info) +SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position, QuantizationInfo quantization_info, DataLayout data_layout) : _buffer(nullptr), _shape(shape), _data_type(data_type), _num_channels(num_channels), _fixed_point_position(fixed_point_position), - _quantization_info(quantization_info) + _quantization_info(quantization_info), + _data_layout(data_layout) { _buffer = support::cpp14::make_unique<T[]>(num_elements() * this->num_channels()); } @@ -279,6 +287,12 @@ Format SimpleTensor<T>::format() const } template <typename T> +DataLayout SimpleTensor<T>::data_layout() const +{ + return _data_layout; +} + +template <typename T> DataType SimpleTensor<T>::data_type() const { if(_format != Format::UNKNOWN) diff --git a/utils/TypePrinter.h b/utils/TypePrinter.h index 63fba35052..e5f860812d 100644 --- a/utils/TypePrinter.h +++ b/utils/TypePrinter.h @@ -338,6 +338,34 @@ inline std::string to_string(const RoundingPolicy &rounding_policy) return str.str(); } +/** Formatted output of the DataLayout type. */ +inline ::std::ostream &operator<<(::std::ostream &os, const DataLayout &data_layout) +{ + switch(data_layout) + { + case DataLayout::UNKNOWN: + os << "UNKNOWN"; + break; + case DataLayout::NHWC: + os << "NHWC"; + break; + case DataLayout::NCHW: + os << "NCHW"; + break; + default: + ARM_COMPUTE_ERROR("NOT_SUPPORTED!"); + } + + return os; +} + +inline std::string to_string(const arm_compute::DataLayout &data_layout) +{ + std::stringstream str; + str << data_layout; + return str.str(); +} + /** Formatted output of the DataType type. */ inline ::std::ostream &operator<<(::std::ostream &os, const DataType &data_type) { |