diff options
author | Jan Eilers <jan.eilers@arm.com> | 2021-05-14 11:10:39 +0100 |
---|---|---|
committer | Jan Eilers <jan.eilers@arm.com> | 2021-06-29 11:03:30 +0100 |
commit | b082ed076b489f17bad3663005801b251d642108 (patch) | |
tree | be4bc760626251a7411e0de7cd32e1a8a7631d5b | |
parent | 89fd793e179bf250c6c390c08dc42760343aa21b (diff) | |
download | armnn-b082ed076b489f17bad3663005801b251d642108.tar.gz |
IVGCVSW-6027 Add IsConstant flag to TensorInfo
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: I7cb0a6a8856d8cd9949bec83c1ddce0a454fdf63
-rw-r--r-- | include/armnn/Tensor.hpp | 32 | ||||
-rw-r--r-- | src/armnn/Tensor.cpp | 33 |
2 files changed, 53 insertions, 12 deletions
diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp index 95898743bc..6f6abe187b 100644 --- a/include/armnn/Tensor.hpp +++ b/include/armnn/Tensor.hpp @@ -158,24 +158,28 @@ public: TensorInfo(const TensorShape& shape, DataType dataType, float quantizationScale = 0.0f, - int32_t quantizationOffset = 0); + int32_t quantizationOffset = 0, + bool isConstant = false); TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, float quantizationScale = 0.0f, - int32_t quantizationOffset = 0); + int32_t quantizationOffset = 0, + bool isConstant = false); TensorInfo(const TensorShape& shape, DataType dataType, const std::vector<float>& quantizationScales, - unsigned int quantizationDim); + unsigned int quantizationDim, + bool isConstant = false); TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, const std::vector<float>& quantizationScales, - unsigned int quantizationDim); + unsigned int quantizationDim, + bool isConstant = false); TensorInfo(const TensorInfo& other); @@ -212,6 +216,14 @@ public: bool IsQuantized() const; + bool IsConstant() const; + + /// Marks the data corresponding to this tensor info as constant. + /// + /// @details: This can allow further optimization on execution + /// @Note: The user has to ensure that the underlying data actually is constant. + void SetConstant(const bool IsConstant=true); + /// Check that the types are the same and, if quantize, that the quantization parameters are the same. bool IsTypeSpaceMatch(const TensorInfo& other) const; @@ -220,6 +232,7 @@ public: private: TensorShape m_Shape; DataType m_DataType; + bool m_IsConstant; /// Vectors of scale and offset are used for per-axis quantization. struct Quantization @@ -316,10 +329,16 @@ class ConstTensor : public BaseTensor<const void*> public: /// Brings in the constructors and assignment operator. using BaseTensor<const void*>::BaseTensor; - ConstTensor() : BaseTensor<const void*>() {} // This needs to be redefined explicitly?? + ConstTensor() : BaseTensor<const void*>() + { + this->GetInfo().SetConstant(); + } /// Can be implicitly constructed from non-const Tensor. - ConstTensor(const Tensor& other) : BaseTensor<const void*>(other.GetInfo(), other.GetMemoryArea()) {} + ConstTensor(const Tensor& other) : BaseTensor<const void*>(other.GetInfo(), other.GetMemoryArea()) + { + this->GetInfo().SetConstant(); + } /// Constructor from a backing container. /// @param container - An stl-like container type which implements data() and size() methods. @@ -330,6 +349,7 @@ public: ConstTensor(const TensorInfo& info, const ContainerType<T, ContainerArgs...>& container) : BaseTensor<const void*>(info, container.data()) { + this->GetInfo().SetConstant(); if (container.size() * sizeof(T) != info.GetNumBytes()) { throw InvalidArgumentException("Container size is not correct"); diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp index 449fdf1f04..6a4dbf8dae 100644 --- a/src/armnn/Tensor.cpp +++ b/src/armnn/Tensor.cpp @@ -339,16 +339,18 @@ void TensorShape::CheckSpecifiedNumDimensions() const // --- TensorInfo::TensorInfo() -: m_DataType(DataType::Float32) +: m_DataType(DataType::Float32), m_IsConstant(false) { } TensorInfo::TensorInfo(const TensorShape& shape, DataType dataType, float quantizationScale, - int32_t quantizationOffset) + int32_t quantizationOffset, + bool isConstant) : m_Shape(shape) , m_DataType(dataType) + , m_IsConstant(isConstant) { SetQuantizationScale(quantizationScale); SetQuantizationOffset(quantizationOffset); @@ -358,9 +360,11 @@ TensorInfo::TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, float quantizationScale, - int32_t quantizationOffset) + int32_t quantizationOffset, + bool isConstant) : m_Shape(numDimensions, dimensionSizes) , m_DataType(dataType) + , m_IsConstant(isConstant) { SetQuantizationScale(quantizationScale); SetQuantizationOffset(quantizationOffset); @@ -369,9 +373,11 @@ TensorInfo::TensorInfo(unsigned int numDimensions, TensorInfo::TensorInfo(const TensorShape& shape, DataType dataType, const std::vector<float>& quantizationScales, - unsigned int quantizationDim) + unsigned int quantizationDim, + bool isConstant) : m_Shape(shape) , m_DataType(dataType) + , m_IsConstant(isConstant) { SetQuantizationScales(quantizationScales); SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim)); @@ -381,9 +387,11 @@ TensorInfo::TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, const std::vector<float>& quantizationScales, - unsigned int quantizationDim) + unsigned int quantizationDim, + bool isConstant) : m_Shape(numDimensions, dimensionSizes) , m_DataType(dataType) + , m_IsConstant(isConstant) { SetQuantizationScales(quantizationScales); SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim)); @@ -392,6 +400,7 @@ TensorInfo::TensorInfo(unsigned int numDimensions, TensorInfo::TensorInfo(const TensorInfo& other) : m_Shape(other.m_Shape) , m_DataType(other.m_DataType) +, m_IsConstant(other.m_IsConstant) , m_Quantization(other.m_Quantization) {} @@ -400,6 +409,7 @@ TensorInfo& TensorInfo::operator=(const TensorInfo& other) m_Shape = other.m_Shape; m_DataType = other.m_DataType; m_Quantization = other.m_Quantization; + m_IsConstant = other.m_IsConstant; return *this; } @@ -407,7 +417,8 @@ bool TensorInfo::operator==(const TensorInfo& other) const { return ((m_Shape == other.m_Shape) && (m_DataType == other.m_DataType) && - (m_Quantization == other.m_Quantization)); + (m_Quantization == other.m_Quantization) && + (m_IsConstant == other.m_IsConstant)); } bool TensorInfo::operator!=(const TensorInfo& other) const @@ -497,6 +508,16 @@ bool TensorInfo::IsQuantized() const return IsQuantizedType(m_DataType); } +bool TensorInfo::IsConstant() const +{ + return m_IsConstant; +} + +void TensorInfo::SetConstant(const bool IsConstant) +{ + m_IsConstant = IsConstant; +} + // --- // --- BaseTensor // --- |