aboutsummaryrefslogtreecommitdiff
path: root/tests/SimpleTensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r--tests/SimpleTensor.h27
1 files changed, 22 insertions, 5 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h
index 07474ff779..419621e808 100644
--- a/tests/SimpleTensor.h
+++ b/tests/SimpleTensor.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 ARM Limited.
+ * Copyright (c) 2017-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,7 +27,6 @@
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
-#include "support/MemorySupport.h"
#include "tests/IAccessor.h"
#include "tests/Utils.h"
@@ -174,6 +173,15 @@ public:
*/
QuantizationInfo quantization_info() const override;
+ /** Set the quantization information of the tensor.
+ *
+ * This function does not have any effect on the raw quantized data of the tensor.
+ * It simply changes the quantization information, hence changes the dequantized values.
+ *
+ * @return A reference to the current object.
+ */
+ SimpleTensor<T> &quantization_info(const QuantizationInfo &qinfo);
+
/** Constant pointer to the underlying buffer.
*
* @return a constant pointer to the data.
@@ -268,7 +276,7 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, Format format)
_data_layout(DataLayout::NCHW)
{
_num_channels = num_channels();
- _buffer = support::cpp14::make_unique<T[]>(num_elements() * _num_channels);
+ _buffer = std::make_unique<T[]>(num_elements() * _num_channels);
}
template <typename T>
@@ -280,7 +288,7 @@ SimpleTensor<T>::SimpleTensor(TensorShape shape, DataType data_type, int num_cha
_quantization_info(quantization_info),
_data_layout(data_layout)
{
- _buffer = support::cpp14::make_unique<T[]>(this->_shape.total_size() * _num_channels);
+ _buffer = std::make_unique<T[]>(this->_shape.total_size() * _num_channels);
}
template <typename T>
@@ -293,7 +301,7 @@ SimpleTensor<T>::SimpleTensor(const SimpleTensor &tensor)
_quantization_info(tensor.quantization_info()),
_data_layout(tensor.data_layout())
{
- _buffer = support::cpp14::make_unique<T[]>(tensor.num_elements() * _num_channels);
+ _buffer = std::make_unique<T[]>(tensor.num_elements() * _num_channels);
std::copy_n(tensor.data(), this->_shape.total_size() * _num_channels, _buffer.get());
}
@@ -336,6 +344,13 @@ QuantizationInfo SimpleTensor<T>::quantization_info() const
}
template <typename T>
+SimpleTensor<T> &SimpleTensor<T>::quantization_info(const QuantizationInfo &qinfo)
+{
+ _quantization_info = qinfo;
+ return *this;
+}
+
+template <typename T>
size_t SimpleTensor<T>::size() const
{
const size_t size = std::accumulate(_shape.cbegin(), _shape.cend(), 1, std::multiplies<size_t>());
@@ -377,6 +392,8 @@ int SimpleTensor<T>::num_channels() const
case Format::S16:
case Format::U32:
case Format::S32:
+ case Format::U64:
+ case Format::S64:
case Format::F16:
case Format::F32:
return 1;