diff options
Diffstat (limited to 'include/armnn/Tensor.hpp')
-rw-r--r-- | include/armnn/Tensor.hpp | 32 |
1 files changed, 26 insertions, 6 deletions
diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp index 95898743bc..6f6abe187b 100644 --- a/include/armnn/Tensor.hpp +++ b/include/armnn/Tensor.hpp @@ -158,24 +158,28 @@ public: TensorInfo(const TensorShape& shape, DataType dataType, float quantizationScale = 0.0f, - int32_t quantizationOffset = 0); + int32_t quantizationOffset = 0, + bool isConstant = false); TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, float quantizationScale = 0.0f, - int32_t quantizationOffset = 0); + int32_t quantizationOffset = 0, + bool isConstant = false); TensorInfo(const TensorShape& shape, DataType dataType, const std::vector<float>& quantizationScales, - unsigned int quantizationDim); + unsigned int quantizationDim, + bool isConstant = false); TensorInfo(unsigned int numDimensions, const unsigned int* dimensionSizes, DataType dataType, const std::vector<float>& quantizationScales, - unsigned int quantizationDim); + unsigned int quantizationDim, + bool isConstant = false); TensorInfo(const TensorInfo& other); @@ -212,6 +216,14 @@ public: bool IsQuantized() const; + bool IsConstant() const; + + /// Marks the data corresponding to this tensor info as constant. + /// + /// @details: This can allow further optimization on execution + /// @Note: The user has to ensure that the underlying data actually is constant. + void SetConstant(const bool IsConstant=true); + /// Check that the types are the same and, if quantize, that the quantization parameters are the same. bool IsTypeSpaceMatch(const TensorInfo& other) const; @@ -220,6 +232,7 @@ public: private: TensorShape m_Shape; DataType m_DataType; + bool m_IsConstant; /// Vectors of scale and offset are used for per-axis quantization. struct Quantization @@ -316,10 +329,16 @@ class ConstTensor : public BaseTensor<const void*> public: /// Brings in the constructors and assignment operator. using BaseTensor<const void*>::BaseTensor; - ConstTensor() : BaseTensor<const void*>() {} // This needs to be redefined explicitly?? + ConstTensor() : BaseTensor<const void*>() + { + this->GetInfo().SetConstant(); + } /// Can be implicitly constructed from non-const Tensor. - ConstTensor(const Tensor& other) : BaseTensor<const void*>(other.GetInfo(), other.GetMemoryArea()) {} + ConstTensor(const Tensor& other) : BaseTensor<const void*>(other.GetInfo(), other.GetMemoryArea()) + { + this->GetInfo().SetConstant(); + } /// Constructor from a backing container. /// @param container - An stl-like container type which implements data() and size() methods. @@ -330,6 +349,7 @@ public: ConstTensor(const TensorInfo& info, const ContainerType<T, ContainerArgs...>& container) : BaseTensor<const void*>(info, container.data()) { + this->GetInfo().SetConstant(); if (container.size() * sizeof(T) != info.GetNumBytes()) { throw InvalidArgumentException("Container size is not correct"); |