aboutsummaryrefslogtreecommitdiff
path: root/Utils.cpp
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-07-23 12:55:12 +0100
committerTeresa Charlin <teresa.charlinreyes@arm.com>2020-08-12 17:06:37 +0100
commita4983cec09a3e24bf4e99abd31aa11842e8b365f (patch)
tree0a09be7cdad7ad5c15bb24c67abf0f4cd4d6e47e /Utils.cpp
parentbf866e2dc7bd5936788fe213b5c0f74483ec1532 (diff)
downloadandroid-nn-driver-a4983cec09a3e24bf4e99abd31aa11842e8b365f.tar.gz
IVGCVSW-4931 Update NN Driver to support dynamic tensors
* Change NN Driver m_Network to now have ShapeInferenceMethod::InferAndValidate * Implement dynamic tensor support for: - ArgMinMax layer - Pooling2d layer - Activation layer * Skip dynamic tensor tests for any HAL other than 1.3 Change-Id: Icf66c968e49cdd4822b8c79c5f18b3f9e97dc53f Signed-off-by: Finn Williams <Finn.Williams@Arm.com> Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Diffstat (limited to 'Utils.cpp')
-rw-r--r--Utils.cpp99
1 files changed, 89 insertions, 10 deletions
diff --git a/Utils.cpp b/Utils.cpp
index 8a2812ad..db1b6e68 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -80,7 +80,8 @@ void* GetMemoryFromPool(DataLocation location, const std::vector<android::nn::Ru
armnn::TensorInfo GetTensorInfoForOperand(const V1_0::Operand& operand)
{
- armnn::DataType type;
+ using namespace armnn;
+ DataType type;
switch (operand.type)
{
@@ -97,7 +98,30 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_0::Operand& operand)
throw UnsupportedOperand<V1_0::OperandType>(operand.type);
}
- armnn::TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type);
+ TensorInfo ret;
+ if (operand.dimensions.size() == 0)
+ {
+ TensorShape tensorShape(Dimensionality::NotSpecified);
+ ret = TensorInfo(tensorShape, type);
+ }
+ else
+ {
+ bool dimensionsSpecificity[5] = { true, true, true, true, true };
+ int count = 0;
+ std::for_each(operand.dimensions.data(),
+ operand.dimensions.data() + operand.dimensions.size(),
+ [&](const unsigned int val)
+ {
+ if (val == 0)
+ {
+ dimensionsSpecificity[count] = false;
+ }
+ count++;
+ });
+
+ TensorShape tensorShape(operand.dimensions.size(), operand.dimensions.data(), dimensionsSpecificity);
+ ret = TensorInfo(tensorShape, type);
+ }
ret.SetQuantizationScale(operand.scale);
ret.SetQuantizationOffset(operand.zeroPoint);
@@ -143,7 +167,31 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand)
throw UnsupportedOperand<V1_2::OperandType>(operand.type);
}
- TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type);
+ TensorInfo ret;
+ if (operand.dimensions.size() == 0)
+ {
+ TensorShape tensorShape(Dimensionality::NotSpecified);
+ ret = TensorInfo(tensorShape, type);
+ }
+ else
+ {
+ bool dimensionsSpecificity[5] = { true, true, true, true, true };
+ int count = 0;
+ std::for_each(operand.dimensions.data(),
+ operand.dimensions.data() + operand.dimensions.size(),
+ [&](const unsigned int val)
+ {
+ if (val == 0)
+ {
+ dimensionsSpecificity[count] = false;
+ }
+ count++;
+ });
+
+ TensorShape tensorShape(operand.dimensions.size(), operand.dimensions.data(), dimensionsSpecificity);
+ ret = TensorInfo(tensorShape, type);
+ }
+
if (perChannel)
{
// ExtraParams is expected to be of type channelQuant
@@ -219,7 +267,29 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand)
}
else
{
- ret = TensorInfo(operand.dimensions.size(), operand.dimensions.data(), type);
+ if (operand.dimensions.size() == 0)
+ {
+ TensorShape tensorShape(Dimensionality::NotSpecified);
+ ret = TensorInfo(tensorShape, type);
+ }
+ else
+ {
+ bool dimensionsSpecificity[5] = { true, true, true, true, true };
+ int count = 0;
+ std::for_each(operand.dimensions.data(),
+ operand.dimensions.data() + operand.dimensions.size(),
+ [&](const unsigned int val)
+ {
+ if (val == 0)
+ {
+ dimensionsSpecificity[count] = false;
+ }
+ count++;
+ });
+
+ TensorShape tensorShape(operand.dimensions.size(), operand.dimensions.data(), dimensionsSpecificity);
+ ret = TensorInfo(tensorShape, type);
+ }
}
if (perChannel)
@@ -501,10 +571,22 @@ std::string ExportNetworkGraphToDotFile(const armnn::IOptimizedNetwork& optimize
return fileName;
}
-bool IsDynamicTensor(const armnn::TensorInfo& outputInfo)
+bool IsDynamicTensor(const armnn::TensorInfo& tensorInfo)
+{
+ if (tensorInfo.GetShape().GetDimensionality() == armnn::Dimensionality::NotSpecified)
+ {
+ return true;
+ }
+ return !tensorInfo.GetShape().AreAllDimensionsSpecified();
+}
+
+bool AreDynamicTensorsSupported()
{
- // Dynamic tensors have at least one 0-sized dimension
- return outputInfo.GetNumElements() == 0u;
+#if defined(ARMNN_ANDROID_NN_V1_3)
+ return true;
+#else
+ return false;
+#endif
}
std::string GetFileTimestamp()
@@ -568,7 +650,4 @@ void CommitPools(std::vector<::android::nn::RunTimePoolInfo>& memPools)
#endif
}
}
-
-
-
} // namespace armnn_driver