From a4983cec09a3e24bf4e99abd31aa11842e8b365f Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Thu, 23 Jul 2020 12:55:12 +0100 Subject: IVGCVSW-4931 Update NN Driver to support dynamic tensors * Change NN Driver m_Network to now have ShapeInferenceMethod::InferAndValidate * Implement dynamic tensor support for: - ArgMinMax layer - Pooling2d layer - Activation layer * Skip dynamic tensor tests for any HAL other than 1.3 Change-Id: Icf66c968e49cdd4822b8c79c5f18b3f9e97dc53f Signed-off-by: Finn Williams Signed-off-by: Teresa Charlin --- Utils.cpp | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 89 insertions(+), 10 deletions(-) (limited to 'Utils.cpp') diff --git a/Utils.cpp b/Utils.cpp index 8a2812ad..db1b6e68 100644 --- a/Utils.cpp +++ b/Utils.cpp @@ -80,7 +80,8 @@ void* GetMemoryFromPool(DataLocation location, const std::vector(operand.type); } - armnn::TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type); + TensorInfo ret; + if (operand.dimensions.size() == 0) + { + TensorShape tensorShape(Dimensionality::NotSpecified); + ret = TensorInfo(tensorShape, type); + } + else + { + bool dimensionsSpecificity[5] = { true, true, true, true, true }; + int count = 0; + std::for_each(operand.dimensions.data(), + operand.dimensions.data() + operand.dimensions.size(), + [&](const unsigned int val) + { + if (val == 0) + { + dimensionsSpecificity[count] = false; + } + count++; + }); + + TensorShape tensorShape(operand.dimensions.size(), operand.dimensions.data(), dimensionsSpecificity); + ret = TensorInfo(tensorShape, type); + } ret.SetQuantizationScale(operand.scale); ret.SetQuantizationOffset(operand.zeroPoint); @@ -143,7 +167,31 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_2::Operand& operand) throw UnsupportedOperand(operand.type); } - TensorInfo ret(operand.dimensions.size(), operand.dimensions.data(), type); + TensorInfo ret; + if (operand.dimensions.size() == 0) + { + TensorShape tensorShape(Dimensionality::NotSpecified); + ret = TensorInfo(tensorShape, type); + } + else + { + bool dimensionsSpecificity[5] = { true, true, true, true, true }; + int count = 0; + std::for_each(operand.dimensions.data(), + operand.dimensions.data() + operand.dimensions.size(), + [&](const unsigned int val) + { + if (val == 0) + { + dimensionsSpecificity[count] = false; + } + count++; + }); + + TensorShape tensorShape(operand.dimensions.size(), operand.dimensions.data(), dimensionsSpecificity); + ret = TensorInfo(tensorShape, type); + } + if (perChannel) { // ExtraParams is expected to be of type channelQuant @@ -219,7 +267,29 @@ armnn::TensorInfo GetTensorInfoForOperand(const V1_3::Operand& operand) } else { - ret = TensorInfo(operand.dimensions.size(), operand.dimensions.data(), type); + if (operand.dimensions.size() == 0) + { + TensorShape tensorShape(Dimensionality::NotSpecified); + ret = TensorInfo(tensorShape, type); + } + else + { + bool dimensionsSpecificity[5] = { true, true, true, true, true }; + int count = 0; + std::for_each(operand.dimensions.data(), + operand.dimensions.data() + operand.dimensions.size(), + [&](const unsigned int val) + { + if (val == 0) + { + dimensionsSpecificity[count] = false; + } + count++; + }); + + TensorShape tensorShape(operand.dimensions.size(), operand.dimensions.data(), dimensionsSpecificity); + ret = TensorInfo(tensorShape, type); + } } if (perChannel) @@ -501,10 +571,22 @@ std::string ExportNetworkGraphToDotFile(const armnn::IOptimizedNetwork& optimize return fileName; } -bool IsDynamicTensor(const armnn::TensorInfo& outputInfo) +bool IsDynamicTensor(const armnn::TensorInfo& tensorInfo) +{ + if (tensorInfo.GetShape().GetDimensionality() == armnn::Dimensionality::NotSpecified) + { + return true; + } + return !tensorInfo.GetShape().AreAllDimensionsSpecified(); +} + +bool AreDynamicTensorsSupported() { - // Dynamic tensors have at least one 0-sized dimension - return outputInfo.GetNumElements() == 0u; +#if defined(ARMNN_ANDROID_NN_V1_3) + return true; +#else + return false; +#endif } std::string GetFileTimestamp() @@ -568,7 +650,4 @@ void CommitPools(std::vector<::android::nn::RunTimePoolInfo>& memPools) #endif } } - - - } // namespace armnn_driver -- cgit v1.2.1