diff options
-rw-r--r-- | ConversionUtils.hpp | 62 |
1 files changed, 56 insertions, 6 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 755e3bef..5ebec6b3 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -5,6 +5,7 @@ #pragma once +#include "OutputShapeUtils.hpp" #include "Utils.hpp" #include <armnn/ArmNN.hpp> @@ -1157,7 +1158,12 @@ bool ConvertToActivation(const HalOperation& operation, { return false; } - const armnn::TensorInfo outInfo = GetTensorInfoForOperand(*outputOperand); + armnn::TensorInfo outInfo = GetTensorInfoForOperand(*outputOperand); + if (IsDynamicTensor(outInfo)) + { + ALOGD("Output shape not set, will infer from input"); + outInfo.SetShape(input.GetTensorInfo().GetShape()); + } bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, @@ -1176,7 +1182,11 @@ bool ConvertToActivation(const HalOperation& operation, BOOST_ASSERT(layer != nullptr); input.Connect(layer->GetInputSlot(0)); - return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, model, data); + return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, + 0, + *layer, + model, + data,armnn::Optional<armnn::TensorInfo>(outInfo)); } template<typename HalPolicy, @@ -1344,7 +1354,7 @@ bool ConvertConv2d(const HalOperation& operation, const HalModel& model, Convers } const armnn::TensorInfo& inputInfo = input.GetTensorInfo(); - const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + armnn::TensorInfo outputInfo = GetTensorInfoForOperand(*output); // ArmNN does not currently support non-fixed weights or bias const ConstTensorPin weightsPin = ConvertOperationInputToConstTensorPin<HalPolicy>(operation, 1, model, data); @@ -1408,6 +1418,21 @@ bool ConvertConv2d(const HalOperation& operation, const HalModel& model, Convers desc.m_BiasEnabled = true; armnn::Optional<armnn::TensorInfo> biases(bias.GetInfo()); + if (IsDynamicTensor(outputInfo)) + { + try + { + ALOGD("Output shape not set, will infer from inputs"); + outputInfo.SetShape(InferConvolution2dOutputShape(inputInfo.GetShape(), + weights.GetInfo().GetShape(), + desc)); + } + catch (armnn::Exception& e) + { + return Fail("%s: Could not infer dynamic output shape: %s", __func__, e.what()); + } + } + bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, IsConvolution2dSupported, @@ -1440,7 +1465,12 @@ bool ConvertConv2d(const HalOperation& operation, const HalModel& model, Convers input.Connect(startLayer->GetInputSlot(0)); - return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *endLayer, model, data); + return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, + 0, + *endLayer, + model, + data, + armnn::Optional<armnn::TensorInfo>(outputInfo)); } template<typename HalPolicy, @@ -1466,7 +1496,7 @@ bool ConvertDepthwiseConv2d(const HalOperation& operation, const HalModel& model } const armnn::TensorInfo& inputInfo = input.GetTensorInfo(); - const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + armnn::TensorInfo outputInfo = GetTensorInfoForOperand(*output); // ArmNN does not currently support non-fixed weights or bias @@ -1568,6 +1598,21 @@ bool ConvertDepthwiseConv2d(const HalOperation& operation, const HalModel& model desc.m_BiasEnabled = true; armnn::Optional<armnn::TensorInfo> biases(bias.GetInfo()); + if (IsDynamicTensor(outputInfo)) + { + try + { + ALOGD("Output shape not set, will infer from inputs"); + outputInfo.SetShape(InferDepthwiseConvolution2dOutputShape(inputInfo.GetShape(), + weights.GetInfo().GetShape(), + desc)); + } + catch (armnn::Exception& e) + { + return Fail("%s: Could not infer dynamic output shape: %s", __func__, e.what()); + } + } + bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, IsDepthwiseConvolutionSupported, @@ -1598,7 +1643,12 @@ bool ConvertDepthwiseConv2d(const HalOperation& operation, const HalModel& model input.Connect(startLayer->GetInputSlot(0)); - return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *endLayer, model, data); + return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, + 0, + *endLayer, + model, + data, + armnn::Optional<armnn::TensorInfo>(outputInfo)); } } // namespace armnn_driver |