aboutsummaryrefslogtreecommitdiff
path: root/tests/SimpleTensor.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/SimpleTensor.h')
-rw-r--r--tests/SimpleTensor.h20
1 files changed, 19 insertions, 1 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h
index c1bd7f87b5..419621e808 100644
--- a/tests/SimpleTensor.h
+++ b/tests/SimpleTensor.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -173,6 +173,15 @@ public:
*/
QuantizationInfo quantization_info() const override;
+ /** Set the quantization information of the tensor.
+ *
+ * This function does not have any effect on the raw quantized data of the tensor.
+ * It simply changes the quantization information, hence changes the dequantized values.
+ *
+ * @return A reference to the current object.
+ */
+ SimpleTensor<T> &quantization_info(const QuantizationInfo &qinfo);
+
/** Constant pointer to the underlying buffer.
*
* @return a constant pointer to the data.
@@ -335,6 +344,13 @@ QuantizationInfo SimpleTensor<T>::quantization_info() const
}
template <typename T>
+SimpleTensor<T> &SimpleTensor<T>::quantization_info(const QuantizationInfo &qinfo)
+{
+ _quantization_info = qinfo;
+ return *this;
+}
+
+template <typename T>
size_t SimpleTensor<T>::size() const
{
const size_t size = std::accumulate(_shape.cbegin(), _shape.cend(), 1, std::multiplies<size_t>());
@@ -376,6 +392,8 @@ int SimpleTensor<T>::num_channels() const
case Format::S16:
case Format::U32:
case Format::S32:
+ case Format::U64:
+ case Format::S64:
case Format::F16:
case Format::F32:
return 1;