diff options
author | Chunosov <N.Chunosov@yandex.ru> | 2017-11-03 17:33:15 +0700 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | d621bca4e963555a99be4328c8d49d1813789649 (patch) | |
tree | 59503f9d4cdbaafefdba5a2569bf3d88082ad09d /tests/SimpleTensor.h | |
parent | 5a99ddf2dcf3a5eb49ea85cb8bcc6a43f1496e5e (diff) | |
download | ComputeLibrary-d621bca4e963555a99be4328c8d49d1813789649.tar.gz |
COMPMID-661: directconv-uint8 (#20)
Change-Id: I84f7a1ce3658be0d3c91e65096467258af48f0b6
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/94341
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r-- | tests/SimpleTensor.h | 38 |
1 files changed, 27 insertions, 11 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h index 0f79a3899a..6091991e66 100644 --- a/tests/SimpleTensor.h +++ b/tests/SimpleTensor.h @@ -76,8 +76,11 @@ public: * @param[in] data_type Data type of the new raw tensor. * @param[in] num_channels (Optional) Number of channels (default = 1). * @param[in] fixed_point_position (Optional) Number of bits for the fractional part of the fixed point numbers (default = 0). + * @param[in] quantization_info (Optional) Quantization info for asymmetric quantization (default = empty). */ - SimpleTensor(TensorShape shape, DataType data_type, int num_channels = 1, int fixed_point_position = 0); + SimpleTensor(TensorShape shape, DataType data_type, + int num_channels = 1, + int fixed_point_position = 0, QuantizationInfo quantization_info = QuantizationInfo()); /** Create a deep copy of the given @p tensor. * @@ -137,6 +140,9 @@ public: /** The number of bits for the fractional part of the fixed point numbers. */ int fixed_point_position() const override; + /** Quantization info in case of asymmetric quantized type */ + QuantizationInfo quantization_info() const override; + /** Constant pointer to the underlying buffer. */ const T *data() const; @@ -168,12 +174,13 @@ public: friend void swap(SimpleTensor<U> &tensor1, SimpleTensor<U> &tensor2); protected: - Buffer _buffer{ nullptr }; - TensorShape _shape{}; - Format _format{ Format::UNKNOWN }; - DataType _data_type{ DataType::UNKNOWN }; - int _num_channels{ 0 }; - int _fixed_point_position{ 0 }; + Buffer _buffer{ nullptr }; + TensorShape _shape{}; + Format _format{ Format::UNKNOWN }; + DataType _data_type{ DataType::UNKNOWN }; + int _num_channels{ 0 }; + int _fixed_point_position{ 0 }; + QuantizationInfo _quantization_info{}; }; template <typename T> @@ -181,18 +188,20 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, Format format, int fixed_point_ : _buffer(nullptr), _shape(shape), _format(format), - _fixed_point_position(fixed_point_position) + _fixed_point_position(fixed_point_position), + _quantization_info() { _buffer = support::cpp14::make_unique<T[]>(num_elements() * num_channels()); } template <typename T> -SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position) +SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position, QuantizationInfo quantization_info) : _buffer(nullptr), _shape(shape), _data_type(data_type), _num_channels(num_channels), - _fixed_point_position(fixed_point_position) + _fixed_point_position(fixed_point_position), + _quantization_info(quantization_info) { _buffer = support::cpp14::make_unique<T[]>(num_elements() * this->num_channels()); } @@ -204,7 +213,8 @@ SimpleTensor<T>::SimpleTensor(const SimpleTensor &tensor) _format(tensor.format()), _data_type(tensor.data_type()), _num_channels(tensor.num_channels()), - _fixed_point_position(tensor.fixed_point_position()) + _fixed_point_position(tensor.fixed_point_position()), + _quantization_info(tensor.quantization_info()) { _buffer = support::cpp14::make_unique<T[]>(tensor.num_elements() * num_channels()); std::copy_n(tensor.data(), num_elements() * num_channels(), _buffer.get()); @@ -249,6 +259,12 @@ int SimpleTensor<T>::fixed_point_position() const } template <typename T> +QuantizationInfo SimpleTensor<T>::quantization_info() const +{ + return _quantization_info; +} + +template <typename T> size_t SimpleTensor<T>::size() const { const size_t size = std::accumulate(_shape.cbegin(), _shape.cend(), 1, std::multiplies<size_t>()); |