aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--ConversionUtils.hpp40
1 files changed, 19 insertions, 21 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index cdb57d1f..fe8e026e 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1399,7 +1399,8 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
const HalModel& model,
ConversionData& data,
const armnn::TensorInfo* overrideOutputInfo = nullptr,
- const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr)
+ const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr,
+ bool inferOutputShapes = false)
{
using HalOperand = typename HalPolicy::Operand;
@@ -1410,7 +1411,6 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
}
armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex);
-
if (overrideOutputInfo == nullptr)
{
outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
@@ -1420,32 +1420,30 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
outputSlot.SetTensorInfo(*overrideOutputInfo);
}
- // Type one dynamic tensors require the previous layer's output shape for inference
- if (!layer.GetInputSlot(0).GetConnection() &&
- IsDynamicTensor(outputSlot.GetTensorInfo()))
- {
- return false;
- }
-
bool isSupported = false;
- if (validateFunc &&
- layer.GetInputSlot(0).GetConnection() &&
- IsDynamicTensor(outputSlot.GetTensorInfo()))
+ if (validateFunc && (IsDynamicTensor(outputSlot.GetTensorInfo()) || inferOutputShapes))
{
+ // Type one dynamic tensors require the previous layer's output shape for inference
+ for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
+ {
+ if(!layer.GetInputSlot(inputSlotIndex).GetConnection())
+ {
+ return false;
+ }
+ }
// IsTensorInfoSet will infer the dynamic output shape
outputSlot.IsTensorInfoSet();
// Once the shape is inferred we can validate it
validateFunc(outputSlot.GetTensorInfo(), isSupported);
- if(!isSupported)
- {
- for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
- {
- layer.GetInputSlot(inputSlotIndex).GetConnection()->Disconnect(layer.GetInputSlot(inputSlotIndex));
- }
-
- return false;
- }
+ if(!isSupported)
+ {
+ for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex)
+ {
+ layer.GetInputSlot(inputSlotIndex).GetConnection()->Disconnect(layer.GetInputSlot(inputSlotIndex));
+ }
+ return false;
+ }
}
const uint32_t operandIndex = operation.outputs[operationOutputIndex];