diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-07-15 14:29:09 +0100 |
---|---|---|
committer | Áron Virginás-Tar <aron.virginas-tar@arm.com> | 2019-07-16 13:31:26 +0000 |
commit | 2b173126319343e49d1f081cfb58eacd96afc715 (patch) | |
tree | b51eaf9d648cb93753c6adc4a075dcb6aea3a68e /1.2/HalPolicy.cpp | |
parent | d759323d159a50298af937dfb2c519025efe3900 (diff) | |
download | android-nn-driver-2b173126319343e49d1f081cfb58eacd96afc715.tar.gz |
IVGCVSW-3452 Support dynamic output shape in hal_1_2::HalPolicy::ConvertConv2d
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I8694e1f1c62da6f74eb356558b17a63758ccfdad
Diffstat (limited to '1.2/HalPolicy.cpp')
-rw-r--r-- | 1.2/HalPolicy.cpp | 32 |
1 files changed, 24 insertions, 8 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index a82db80b..69cc4713 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -172,13 +172,8 @@ bool HalPolicy::ConvertConv2d(const Operation& operation, const Model& model, Co return Fail("%s: Could not read output 0", __func__); } - const armnn::TensorInfo& inputInfo = input.GetTensorInfo(); - const armnn::TensorInfo& outputInfo = GetTensorInfoForOperand(*output); - - if (IsDynamicOutput(outputInfo)) - { - return Fail("%s: Dynamic output not supported", __func__); - } + const armnn::TensorInfo& inputInfo = input.GetTensorInfo(); + armnn::TensorInfo outputInfo = GetTensorInfoForOperand(*output); armnn::Convolution2dDescriptor desc; desc.m_DataLayout = armnn::DataLayout::NHWC; @@ -272,6 +267,21 @@ bool HalPolicy::ConvertConv2d(const Operation& operation, const Model& model, Co desc.m_BiasEnabled = true; armnn::Optional<armnn::TensorInfo> biases(bias.GetInfo()); + if (IsDynamicOutput(outputInfo)) + { + try + { + ALOGD("Output shape not set, will infer from inputs"); + outputInfo.SetShape(InferConvolution2dOutputShape(inputInfo.GetShape(), + weights.GetInfo().GetShape(), + desc)); + } + catch (armnn::Exception& e) + { + return Fail("%s: Could not infer dynamic output shape: %s", __func__, e.what()); + } + } + bool isSupported = false; FORWARD_LAYER_SUPPORT_FUNC(__func__, IsConvolution2dSupported, @@ -282,6 +292,7 @@ bool HalPolicy::ConvertConv2d(const Operation& operation, const Model& model, Co desc, weights.GetInfo(), biases); + if (!isSupported) { return false; @@ -304,7 +315,12 @@ bool HalPolicy::ConvertConv2d(const Operation& operation, const Model& model, Co input.Connect(startLayer->GetInputSlot(0)); - return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *endLayer, model, data); + return SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, + 0, + *endLayer, + model, + data, + armnn::Optional<armnn::TensorInfo>(outputInfo)); } bool HalPolicy::ConvertDepthwiseConv2d(const Operation& operation, const Model& model, ConversionData& data) |