diff options
Diffstat (limited to 'src/armnn/Tensor.cpp')
-rw-r--r-- | src/armnn/Tensor.cpp | 116 |
1 files changed, 101 insertions, 15 deletions
diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp index 614abc77f5..f4b8b509b6 100644 --- a/src/armnn/Tensor.cpp +++ b/src/armnn/Tensor.cpp @@ -2,6 +2,7 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #include "armnn/Tensor.hpp" #include "armnn/Utils.hpp" #include "armnn/Exceptions.hpp" @@ -138,30 +139,57 @@ TensorInfo::TensorInfo() { } -TensorInfo::TensorInfo(const TensorShape& shape, DataType dataType, - float quantizationScale, int32_t quantizationOffset) - : m_Shape(shape) - , m_DataType(dataType) +TensorInfo::TensorInfo(const TensorShape& shape, + DataType dataType, + float quantizationScale, + int32_t quantizationOffset) + : m_Shape(shape) + , m_DataType(dataType) +{ + SetQuantizationScale(quantizationScale); + SetQuantizationOffset(quantizationOffset); +} + +TensorInfo::TensorInfo(unsigned int numDimensions, + const unsigned int* dimensionSizes, + DataType dataType, + float quantizationScale, + int32_t quantizationOffset) + : m_Shape(numDimensions, dimensionSizes) + , m_DataType(dataType) { - m_Quantization.m_Scale = quantizationScale; - m_Quantization.m_Offset = quantizationOffset; + SetQuantizationScale(quantizationScale); + SetQuantizationOffset(quantizationOffset); } -TensorInfo::TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, - float quantizationScale, int32_t quantizationOffset) - : m_Shape(numDimensions, dimensionSizes) - , m_DataType(dataType) +TensorInfo::TensorInfo(const TensorShape& shape, + DataType dataType, + const std::vector<float>& quantizationScales, + unsigned int quantizationDim) + : m_Shape(shape) + , m_DataType(dataType) { - m_Quantization.m_Scale = quantizationScale; - m_Quantization.m_Offset = quantizationOffset; + SetQuantizationScales(quantizationScales); + SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim)); +} + +TensorInfo::TensorInfo(unsigned int numDimensions, + const unsigned int* dimensionSizes, + DataType dataType, + const std::vector<float>& quantizationScales, + unsigned int quantizationDim) + : m_Shape(numDimensions, dimensionSizes) + , m_DataType(dataType) +{ + SetQuantizationScales(quantizationScales); + SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim)); } TensorInfo::TensorInfo(const TensorInfo& other) : m_Shape(other.m_Shape) , m_DataType(other.m_DataType) , m_Quantization(other.m_Quantization) -{ -} +{} TensorInfo& TensorInfo::operator=(const TensorInfo& other) { @@ -194,7 +222,7 @@ bool TensorInfo::IsTypeSpaceMatch(const TensorInfo& other) const match &= m_DataType == other.m_DataType; - if (IsQuantized()) + if (IsQuantized() && !HasMultipleQuantizationScales()) { match &= GetQuantizationScale() == other.GetQuantizationScale() && GetQuantizationOffset() == other.GetQuantizationOffset(); @@ -202,6 +230,64 @@ bool TensorInfo::IsTypeSpaceMatch(const TensorInfo& other) const return match; } +std::vector<float> TensorInfo::GetQuantizationScales() const +{ + return m_Quantization.m_Scales; +} + +void TensorInfo::SetQuantizationScales(const std::vector<float>& scales) +{ + m_Quantization.m_Scales = scales; +} + +float TensorInfo::GetQuantizationScale() const +{ + if (m_Quantization.m_Scales.empty()) + { + // NOTE: old default for backward compatibility + return 1.0f; + } + + BOOST_ASSERT(!HasMultipleQuantizationScales()); + return m_Quantization.m_Scales[0]; +} + +void TensorInfo::SetQuantizationScale(float scale) +{ + m_Quantization.m_Scales = { scale }; +} + +int32_t TensorInfo::GetQuantizationOffset() const +{ + if (!m_Quantization.m_Offset.has_value()) + { + // NOTE: old default for backward compatibility + return 0; + } + + return m_Quantization.m_Offset.value(); +} + +void TensorInfo::SetQuantizationOffset(int32_t offset) +{ + m_Quantization.m_Offset = MakeOptional<int32_t>(offset); +} + +Optional<unsigned int> TensorInfo::GetQuantizationDim() const +{ + return m_Quantization.m_QuantizationDim; +} + +void TensorInfo::SetQuantizationDim(const Optional<unsigned int>& quantizationDim) +{ + m_Quantization.m_QuantizationDim = quantizationDim; +} + +bool TensorInfo::IsQuantized() const +{ + return m_DataType == DataType::QuantisedAsymm8 || m_DataType == DataType::QuantisedSymm16; +} + // --- // --- BaseTensor // --- |