diff options
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r-- | tests/SimpleTensor.h | 38 |
1 files changed, 27 insertions, 11 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h index 0f79a3899a..6091991e66 100644 --- a/tests/SimpleTensor.h +++ b/tests/SimpleTensor.h @@ -76,8 +76,11 @@ public: * @param[in] data_type Data type of the new raw tensor. * @param[in] num_channels (Optional) Number of channels (default = 1). * @param[in] fixed_point_position (Optional) Number of bits for the fractional part of the fixed point numbers (default = 0). + * @param[in] quantization_info (Optional) Quantization info for asymmetric quantization (default = empty). */ - SimpleTensor(TensorShape shape, DataType data_type, int num_channels = 1, int fixed_point_position = 0); + SimpleTensor(TensorShape shape, DataType data_type, + int num_channels = 1, + int fixed_point_position = 0, QuantizationInfo quantization_info = QuantizationInfo()); /** Create a deep copy of the given @p tensor. * @@ -137,6 +140,9 @@ public: /** The number of bits for the fractional part of the fixed point numbers. */ int fixed_point_position() const override; + /** Quantization info in case of asymmetric quantized type */ + QuantizationInfo quantization_info() const override; + /** Constant pointer to the underlying buffer. */ const T *data() const; @@ -168,12 +174,13 @@ public: friend void swap(SimpleTensor<U> &tensor1, SimpleTensor<U> &tensor2); protected: - Buffer _buffer{ nullptr }; - TensorShape _shape{}; - Format _format{ Format::UNKNOWN }; - DataType _data_type{ DataType::UNKNOWN }; - int _num_channels{ 0 }; - int _fixed_point_position{ 0 }; + Buffer _buffer{ nullptr }; + TensorShape _shape{}; + Format _format{ Format::UNKNOWN }; + DataType _data_type{ DataType::UNKNOWN }; + int _num_channels{ 0 }; + int _fixed_point_position{ 0 }; + QuantizationInfo _quantization_info{}; }; template <typename T> @@ -181,18 +188,20 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, Format format, int fixed_point_ : _buffer(nullptr), _shape(shape), _format(format), - _fixed_point_position(fixed_point_position) + _fixed_point_position(fixed_point_position), + _quantization_info() { _buffer = support::cpp14::make_unique<T[]>(num_elements() * num_channels()); } template <typename T> -SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position) +SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position, QuantizationInfo quantization_info) : _buffer(nullptr), _shape(shape), _data_type(data_type), _num_channels(num_channels), - _fixed_point_position(fixed_point_position) + _fixed_point_position(fixed_point_position), + _quantization_info(quantization_info) { _buffer = support::cpp14::make_unique<T[]>(num_elements() * this->num_channels()); } @@ -204,7 +213,8 @@ SimpleTensor<T>::SimpleTensor(const SimpleTensor &tensor) _format(tensor.format()), _data_type(tensor.data_type()), _num_channels(tensor.num_channels()), - _fixed_point_position(tensor.fixed_point_position()) + _fixed_point_position(tensor.fixed_point_position()), + _quantization_info(tensor.quantization_info()) { _buffer = support::cpp14::make_unique<T[]>(tensor.num_elements() * num_channels()); std::copy_n(tensor.data(), num_elements() * num_channels(), _buffer.get()); @@ -249,6 +259,12 @@ int SimpleTensor<T>::fixed_point_position() const } template <typename T> +QuantizationInfo SimpleTensor<T>::quantization_info() const +{ + return _quantization_info; +} + +template <typename T> size_t SimpleTensor<T>::size() const { const size_t size = std::accumulate(_shape.cbegin(), _shape.cend(), 1, std::multiplies<size_t>()); |