aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2021-05-14 11:10:39 +0100
committerJan Eilers <jan.eilers@arm.com>2021-06-29 11:03:30 +0100
commitb082ed076b489f17bad3663005801b251d642108 (patch)
treebe4bc760626251a7411e0de7cd32e1a8a7631d5b
parent89fd793e179bf250c6c390c08dc42760343aa21b (diff)
downloadarmnn-b082ed076b489f17bad3663005801b251d642108.tar.gz
IVGCVSW-6027 Add IsConstant flag to TensorInfo
Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I7cb0a6a8856d8cd9949bec83c1ddce0a454fdf63
-rw-r--r--include/armnn/Tensor.hpp32
-rw-r--r--src/armnn/Tensor.cpp33
2 files changed, 53 insertions, 12 deletions
diff --git a/include/armnn/Tensor.hpp b/include/armnn/Tensor.hpp
index 95898743bc..6f6abe187b 100644
--- a/include/armnn/Tensor.hpp
+++ b/include/armnn/Tensor.hpp
@@ -158,24 +158,28 @@ public:
TensorInfo(const TensorShape& shape,
DataType dataType,
float quantizationScale = 0.0f,
- int32_t quantizationOffset = 0);
+ int32_t quantizationOffset = 0,
+ bool isConstant = false);
TensorInfo(unsigned int numDimensions,
const unsigned int* dimensionSizes,
DataType dataType,
float quantizationScale = 0.0f,
- int32_t quantizationOffset = 0);
+ int32_t quantizationOffset = 0,
+ bool isConstant = false);
TensorInfo(const TensorShape& shape,
DataType dataType,
const std::vector<float>& quantizationScales,
- unsigned int quantizationDim);
+ unsigned int quantizationDim,
+ bool isConstant = false);
TensorInfo(unsigned int numDimensions,
const unsigned int* dimensionSizes,
DataType dataType,
const std::vector<float>& quantizationScales,
- unsigned int quantizationDim);
+ unsigned int quantizationDim,
+ bool isConstant = false);
TensorInfo(const TensorInfo& other);
@@ -212,6 +216,14 @@ public:
bool IsQuantized() const;
+ bool IsConstant() const;
+
+ /// Marks the data corresponding to this tensor info as constant.
+ ///
+ /// @details: This can allow further optimization on execution
+ /// @Note: The user has to ensure that the underlying data actually is constant.
+ void SetConstant(const bool IsConstant=true);
+
/// Check that the types are the same and, if quantize, that the quantization parameters are the same.
bool IsTypeSpaceMatch(const TensorInfo& other) const;
@@ -220,6 +232,7 @@ public:
private:
TensorShape m_Shape;
DataType m_DataType;
+ bool m_IsConstant;
/// Vectors of scale and offset are used for per-axis quantization.
struct Quantization
@@ -316,10 +329,16 @@ class ConstTensor : public BaseTensor<const void*>
public:
/// Brings in the constructors and assignment operator.
using BaseTensor<const void*>::BaseTensor;
- ConstTensor() : BaseTensor<const void*>() {} // This needs to be redefined explicitly??
+ ConstTensor() : BaseTensor<const void*>()
+ {
+ this->GetInfo().SetConstant();
+ }
/// Can be implicitly constructed from non-const Tensor.
- ConstTensor(const Tensor& other) : BaseTensor<const void*>(other.GetInfo(), other.GetMemoryArea()) {}
+ ConstTensor(const Tensor& other) : BaseTensor<const void*>(other.GetInfo(), other.GetMemoryArea())
+ {
+ this->GetInfo().SetConstant();
+ }
/// Constructor from a backing container.
/// @param container - An stl-like container type which implements data() and size() methods.
@@ -330,6 +349,7 @@ public:
ConstTensor(const TensorInfo& info, const ContainerType<T, ContainerArgs...>& container)
: BaseTensor<const void*>(info, container.data())
{
+ this->GetInfo().SetConstant();
if (container.size() * sizeof(T) != info.GetNumBytes())
{
throw InvalidArgumentException("Container size is not correct");
diff --git a/src/armnn/Tensor.cpp b/src/armnn/Tensor.cpp
index 449fdf1f04..6a4dbf8dae 100644
--- a/src/armnn/Tensor.cpp
+++ b/src/armnn/Tensor.cpp
@@ -339,16 +339,18 @@ void TensorShape::CheckSpecifiedNumDimensions() const
// ---
TensorInfo::TensorInfo()
-: m_DataType(DataType::Float32)
+: m_DataType(DataType::Float32), m_IsConstant(false)
{
}
TensorInfo::TensorInfo(const TensorShape& shape,
DataType dataType,
float quantizationScale,
- int32_t quantizationOffset)
+ int32_t quantizationOffset,
+ bool isConstant)
: m_Shape(shape)
, m_DataType(dataType)
+ , m_IsConstant(isConstant)
{
SetQuantizationScale(quantizationScale);
SetQuantizationOffset(quantizationOffset);
@@ -358,9 +360,11 @@ TensorInfo::TensorInfo(unsigned int numDimensions,
const unsigned int* dimensionSizes,
DataType dataType,
float quantizationScale,
- int32_t quantizationOffset)
+ int32_t quantizationOffset,
+ bool isConstant)
: m_Shape(numDimensions, dimensionSizes)
, m_DataType(dataType)
+ , m_IsConstant(isConstant)
{
SetQuantizationScale(quantizationScale);
SetQuantizationOffset(quantizationOffset);
@@ -369,9 +373,11 @@ TensorInfo::TensorInfo(unsigned int numDimensions,
TensorInfo::TensorInfo(const TensorShape& shape,
DataType dataType,
const std::vector<float>& quantizationScales,
- unsigned int quantizationDim)
+ unsigned int quantizationDim,
+ bool isConstant)
: m_Shape(shape)
, m_DataType(dataType)
+ , m_IsConstant(isConstant)
{
SetQuantizationScales(quantizationScales);
SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim));
@@ -381,9 +387,11 @@ TensorInfo::TensorInfo(unsigned int numDimensions,
const unsigned int* dimensionSizes,
DataType dataType,
const std::vector<float>& quantizationScales,
- unsigned int quantizationDim)
+ unsigned int quantizationDim,
+ bool isConstant)
: m_Shape(numDimensions, dimensionSizes)
, m_DataType(dataType)
+ , m_IsConstant(isConstant)
{
SetQuantizationScales(quantizationScales);
SetQuantizationDim(MakeOptional<unsigned int>(quantizationDim));
@@ -392,6 +400,7 @@ TensorInfo::TensorInfo(unsigned int numDimensions,
TensorInfo::TensorInfo(const TensorInfo& other)
: m_Shape(other.m_Shape)
, m_DataType(other.m_DataType)
+, m_IsConstant(other.m_IsConstant)
, m_Quantization(other.m_Quantization)
{}
@@ -400,6 +409,7 @@ TensorInfo& TensorInfo::operator=(const TensorInfo& other)
m_Shape = other.m_Shape;
m_DataType = other.m_DataType;
m_Quantization = other.m_Quantization;
+ m_IsConstant = other.m_IsConstant;
return *this;
}
@@ -407,7 +417,8 @@ bool TensorInfo::operator==(const TensorInfo& other) const
{
return ((m_Shape == other.m_Shape) &&
(m_DataType == other.m_DataType) &&
- (m_Quantization == other.m_Quantization));
+ (m_Quantization == other.m_Quantization) &&
+ (m_IsConstant == other.m_IsConstant));
}
bool TensorInfo::operator!=(const TensorInfo& other) const
@@ -497,6 +508,16 @@ bool TensorInfo::IsQuantized() const
return IsQuantizedType(m_DataType);
}
+bool TensorInfo::IsConstant() const
+{
+ return m_IsConstant;
+}
+
+void TensorInfo::SetConstant(const bool IsConstant)
+{
+ m_IsConstant = IsConstant;
+}
+
// ---
// --- BaseTensor
// ---