From b082ed076b489f17bad3663005801b251d642108 Mon Sep 17 00:00:00 2001 From: Jan Eilers Date: Fri, 14 May 2021 11:10:39 +0100 Subject: IVGCVSW-6027 Add IsConstant flag to TensorInfo Signed-off-by: Jan Eilers Change-Id: I7cb0a6a8856d8cd9949bec83c1ddce0a454fdf63 --- include/armnn/Tensor.hpp | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) (limited to 'include/armnn/Tensor.hpp') diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp index 95898743bc..6f6abe187b 100644 --- a/include/armnn/Tensor.hpp +++ b/include/armnn/Tensor.hpp @@ -158,24 +158,28 @@ public: TensorInfo(const TensorShape& shape, DataType dataType, float quantizationScale = 0.0f, - int32_t quantizationOffset = 0); + int32_t quantizationOffset = 0, + bool isConstant = false); TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, float quantizationScale = 0.0f, - int32_t quantizationOffset = 0); + int32_t quantizationOffset = 0, + bool isConstant = false); TensorInfo(const TensorShape& shape, DataType dataType, const std::vector& quantizationScales, - unsigned int quantizationDim); + unsigned int quantizationDim, + bool isConstant = false); TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, const std::vector& quantizationScales, - unsigned int quantizationDim); + unsigned int quantizationDim, + bool isConstant = false); TensorInfo(const TensorInfo& other); @@ -212,6 +216,14 @@ public: bool IsQuantized() const; + bool IsConstant() const; + + /// Marks the data corresponding to this tensor info as constant. + /// + /// @details: This can allow further optimization on execution + /// @Note: The user has to ensure that the underlying data actually is constant. + void SetConstant(const bool IsConstant=true); + /// Check that the types are the same and, if quantize, that the quantization parameters are the same. bool IsTypeSpaceMatch(const TensorInfo& other) const; @@ -220,6 +232,7 @@ public: private: TensorShape m_Shape; DataType m_DataType; + bool m_IsConstant; /// Vectors of scale and offset are used for per-axis quantization. struct Quantization @@ -316,10 +329,16 @@ class ConstTensor : public BaseTensor public: /// Brings in the constructors and assignment operator. using BaseTensor::BaseTensor; - ConstTensor() : BaseTensor() {} // This needs to be redefined explicitly?? + ConstTensor() : BaseTensor() + { + this->GetInfo().SetConstant(); + } /// Can be implicitly constructed from non-const Tensor. - ConstTensor(const Tensor& other) : BaseTensor(other.GetInfo(), other.GetMemoryArea()) {} + ConstTensor(const Tensor& other) : BaseTensor(other.GetInfo(), other.GetMemoryArea()) + { + this->GetInfo().SetConstant(); + } /// Constructor from a backing container. /// @param container - An stl-like container type which implements data() and size() methods. @@ -330,6 +349,7 @@ public: ConstTensor(const TensorInfo& info, const ContainerType& container) : BaseTensor(info, container.data()) { + this->GetInfo().SetConstant(); if (container.size() * sizeof(T) != info.GetNumBytes()) { throw InvalidArgumentException("Container size is not correct"); -- cgit v1.2.1