From b082ed076b489f17bad3663005801b251d642108 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Fri, 14 May 2021 11:10:39 +0100 Subject: IVGCVSW-6027 Add IsConstant flag to TensorInfo Signed-off-by: Jan Eilers Change-Id: I7cb0a6a8856d8cd9949bec83c1ddce0a454fdf63 --- include/armnn/Tensor.hpp | 32 ++++++++++++++++++++++++++------ src/armnn/Tensor.cpp | 33 +++++++++++++++++++++++++++------ 2 files changed, 53 insertions(+), 12 deletions(-) diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp index 95898743bc..6f6abe187b 100644 --- a/include/armnn/Tensor.hpp +++ b/include/armnn/Tensor.hpp @@ -158,24 +158,28 @@ public: TensorInfo(const TensorShape& shape, DataType dataType, float quantizationScale = 0.0f, - int32_t quantizationOffset = 0); + int32_t quantizationOffset = 0, + bool isConstant = false); TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, float quantizationScale = 0.0f, - int32_t quantizationOffset = 0); + int32_t quantizationOffset = 0, + bool isConstant = false); TensorInfo(const TensorShape& shape, DataType dataType, const std::vector& quantizationScales, - unsigned int quantizationDim); + unsigned int quantizationDim, + bool isConstant = false); TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, const std::vector& quantizationScales, - unsigned int quantizationDim); + unsigned int quantizationDim, + bool isConstant = false); TensorInfo(const TensorInfo& other); @@ -212,6 +216,14 @@ public: bool IsQuantized() const; + bool IsConstant() const; + + /// Marks the data corresponding to this tensor info as constant. + /// + /// @details: This can allow further optimization on execution + /// @Note: The user has to ensure that the underlying data actually is constant. + void SetConstant(const bool IsConstant=true); + /// Check that the types are the same and, if quantize, that the quantization parameters are the same. bool IsTypeSpaceMatch(const TensorInfo& other) const; @@ -220,6 +232,7 @@ public: private: TensorShape m_Shape; DataType m_DataType; + bool m_IsConstant; /// Vectors of scale and offset are used for per-axis quantization. struct Quantization @@ -316,10 +329,16 @@ class ConstTensor : public BaseTensor public: /// Brings in the constructors and assignment operator. using BaseTensor::BaseTensor; - ConstTensor() : BaseTensor() {} // This needs to be redefined explicitly?? + ConstTensor() : BaseTensor() + { + this->GetInfo().SetConstant(); + } /// Can be implicitly constructed from non-const Tensor. - ConstTensor(const Tensor& other) : BaseTensor(other.GetInfo(), other.GetMemoryArea()) {} + ConstTensor(const Tensor& other) : BaseTensor(other.GetInfo(), other.GetMemoryArea()) + { + this->GetInfo().SetConstant(); + } /// Constructor from a backing container. /// @param container - An stl-like container type which implements data() and size() methods. @@ -330,6 +349,7 @@ public: ConstTensor(const TensorInfo& info, const ContainerType& container) : BaseTensor(info, container.data()) { + this->GetInfo().SetConstant(); if (container.size() * sizeof(T) != info.GetNumBytes()) { throw InvalidArgumentException("Container size is not correct"); diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp index 449fdf1f04..6a4dbf8dae 100644 --- a/src/armnn/Tensor.cpp +++ b/src/armnn/Tensor.cpp @@ -339,16 +339,18 @@ void TensorShape::CheckSpecifiedNumDimensions() const // --- TensorInfo::TensorInfo() -: m_DataType(DataType::Float32) +: m_DataType(DataType::Float32), m_IsConstant(false) { } TensorInfo::TensorInfo(const TensorShape& shape, DataType dataType, float quantizationScale, - int32_t quantizationOffset) + int32_t quantizationOffset, + bool isConstant) : m_Shape(shape) , m_DataType(dataType) + , m_IsConstant(isConstant) { SetQuantizationScale(quantizationScale); SetQuantizationOffset(quantizationOffset); @@ -358,9 +360,11 @@ TensorInfo::TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, float quantizationScale, - int32_t quantizationOffset) + int32_t quantizationOffset, + bool isConstant) : m_Shape(numDimensions, dimensionSizes) , m_DataType(dataType) + , m_IsConstant(isConstant) { SetQuantizationScale(quantizationScale); SetQuantizationOffset(quantizationOffset); @@ -369,9 +373,11 @@ TensorInfo::TensorInfo(unsigned int numDimensions, TensorInfo::TensorInfo(const TensorShape& shape, DataType dataType, const std::vector& quantizationScales, - unsigned int quantizationDim) + unsigned int quantizationDim, + bool isConstant) : m_Shape(shape) , m_DataType(dataType) + , m_IsConstant(isConstant) { SetQuantizationScales(quantizationScales); SetQuantizationDim(MakeOptional(quantizationDim)); @@ -381,9 +387,11 @@ TensorInfo::TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, const std::vector& quantizationScales, - unsigned int quantizationDim) + unsigned int quantizationDim, + bool isConstant) : m_Shape(numDimensions, dimensionSizes) , m_DataType(dataType) + , m_IsConstant(isConstant) { SetQuantizationScales(quantizationScales); SetQuantizationDim(MakeOptional(quantizationDim)); @@ -392,6 +400,7 @@ TensorInfo::TensorInfo(unsigned int numDimensions, TensorInfo::TensorInfo(const TensorInfo& other) : m_Shape(other.m_Shape) , m_DataType(other.m_DataType) +, m_IsConstant(other.m_IsConstant) , m_Quantization(other.m_Quantization) {} @@ -400,6 +409,7 @@ TensorInfo& TensorInfo::operator=(const TensorInfo& other) m_Shape = other.m_Shape; m_DataType = other.m_DataType; m_Quantization = other.m_Quantization; + m_IsConstant = other.m_IsConstant; return *this; } @@ -407,7 +417,8 @@ bool TensorInfo::operator==(const TensorInfo& other) const { return ((m_Shape == other.m_Shape) && (m_DataType == other.m_DataType) && - (m_Quantization == other.m_Quantization)); + (m_Quantization == other.m_Quantization) && + (m_IsConstant == other.m_IsConstant)); } bool TensorInfo::operator!=(const TensorInfo& other) const @@ -497,6 +508,16 @@ bool TensorInfo::IsQuantized() const return IsQuantizedType(m_DataType); } +bool TensorInfo::IsConstant() const +{ + return m_IsConstant; +} + +void TensorInfo::SetConstant(const bool IsConstant) +{ + m_IsConstant = IsConstant; +} + // --- // --- BaseTensor // --- -- cgit v1.2.1