aboutsummaryrefslogtreecommitdiff
path: root/tests/SimpleTensor.h
diff options
context:
space:
mode:
authorChunosov <N.Chunosov@yandex.ru>2017-11-03 17:33:15 +0700
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitd621bca4e963555a99be4328c8d49d1813789649 (patch)
tree59503f9d4cdbaafefdba5a2569bf3d88082ad09d /tests/SimpleTensor.h
parent5a99ddf2dcf3a5eb49ea85cb8bcc6a43f1496e5e (diff)
downloadComputeLibrary-d621bca4e963555a99be4328c8d49d1813789649.tar.gz
COMPMID-661: directconv-uint8 (#20)
Change-Id: I84f7a1ce3658be0d3c91e65096467258af48f0b6 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/94341 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r--tests/SimpleTensor.h38
1 files changed, 27 insertions, 11 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h
index 0f79a3899a..6091991e66 100644
--- a/tests/SimpleTensor.h
+++ b/tests/SimpleTensor.h
@@ -76,8 +76,11 @@ public:
* @param[in] data_type Data type of the new raw tensor.
* @param[in] num_channels (Optional) Number of channels (default = 1).
* @param[in] fixed_point_position (Optional) Number of bits for the fractional part of the fixed point numbers (default = 0).
+ * @param[in] quantization_info (Optional) Quantization info for asymmetric quantization (default = empty).
*/
- SimpleTensor(TensorShape shape, DataType data_type, int num_channels = 1, int fixed_point_position = 0);
+ SimpleTensor(TensorShape shape, DataType data_type,
+ int num_channels = 1,
+ int fixed_point_position = 0, QuantizationInfo quantization_info = QuantizationInfo());
/** Create a deep copy of the given @p tensor.
*
@@ -137,6 +140,9 @@ public:
/** The number of bits for the fractional part of the fixed point numbers. */
int fixed_point_position() const override;
+ /** Quantization info in case of asymmetric quantized type */
+ QuantizationInfo quantization_info() const override;
+
/** Constant pointer to the underlying buffer. */
const T *data() const;
@@ -168,12 +174,13 @@ public:
friend void swap(SimpleTensor<U> &tensor1, SimpleTensor<U> &tensor2);
protected:
- Buffer _buffer{ nullptr };
- TensorShape _shape{};
- Format _format{ Format::UNKNOWN };
- DataType _data_type{ DataType::UNKNOWN };
- int _num_channels{ 0 };
- int _fixed_point_position{ 0 };
+ Buffer _buffer{ nullptr };
+ TensorShape _shape{};
+ Format _format{ Format::UNKNOWN };
+ DataType _data_type{ DataType::UNKNOWN };
+ int _num_channels{ 0 };
+ int _fixed_point_position{ 0 };
+ QuantizationInfo _quantization_info{};
};
template <typename T>
@@ -181,18 +188,20 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, Format format, int fixed_point_
: _buffer(nullptr),
_shape(shape),
_format(format),
- _fixed_point_position(fixed_point_position)
+ _fixed_point_position(fixed_point_position),
+ _quantization_info()
{
_buffer = support::cpp14::make_unique<T[]>(num_elements() * num_channels());
}
template <typename T>
-SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position)
+SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_channels, int fixed_point_position, QuantizationInfo quantization_info)
: _buffer(nullptr),
_shape(shape),
_data_type(data_type),
_num_channels(num_channels),
- _fixed_point_position(fixed_point_position)
+ _fixed_point_position(fixed_point_position),
+ _quantization_info(quantization_info)
{
_buffer = support::cpp14::make_unique<T[]>(num_elements() * this->num_channels());
}
@@ -204,7 +213,8 @@ SimpleTensor<T>::SimpleTensor(const SimpleTensor &tensor)
_format(tensor.format()),
_data_type(tensor.data_type()),
_num_channels(tensor.num_channels()),
- _fixed_point_position(tensor.fixed_point_position())
+ _fixed_point_position(tensor.fixed_point_position()),
+ _quantization_info(tensor.quantization_info())
{
_buffer = support::cpp14::make_unique<T[]>(tensor.num_elements() * num_channels());
std::copy_n(tensor.data(), num_elements() * num_channels(), _buffer.get());
@@ -249,6 +259,12 @@ int SimpleTensor<T>::fixed_point_position() const
}
template <typename T>
+QuantizationInfo SimpleTensor<T>::quantization_info() const
+{
+ return _quantization_info;
+}
+
+template <typename T>
size_t SimpleTensor<T>::size() const
{
const size_t size = std::accumulate(_shape.cbegin(), _shape.cend(), 1, std::multiplies<size_t>());