aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r--ConversionUtils.hpp24
1 files changed, 20 insertions, 4 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index d30b8a4e..c9be0003 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -1028,7 +1028,8 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
armnn::IConnectableLayer& layer,
uint32_t layerOutputIndex,
const HalModel& model,
- ConversionData& data)
+ ConversionData& data,
+ const armnn::Optional<armnn::TensorInfo>& outputInfo = armnn::EmptyOptional())
{
using HalOperand = typename HalPolicy::Operand;
@@ -1043,7 +1044,15 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
const uint32_t operandIndex = operation.outputs[operationOutputIndex];
data.m_OutputSlotForOperand[operandIndex] = &outputSlot;
- outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
+ if (outputInfo.has_value())
+ {
+ outputSlot.SetTensorInfo(outputInfo.value());
+ ALOGD("Output info overwritten");
+ }
+ else
+ {
+ outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand));
+ }
return true;
}
@@ -1092,9 +1101,16 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation,
uint32_t outputIndex,
armnn::IConnectableLayer& layer,
const HalModel& model,
- ConversionData& data)
+ ConversionData& data,
+ const armnn::Optional<armnn::TensorInfo>& outputInfo = armnn::EmptyOptional())
{
- return SetupAndTrackLayerOutputSlot<HalPolicy>(operation, outputIndex, layer, outputIndex, model, data);
+ return SetupAndTrackLayerOutputSlot<HalPolicy>(operation,
+ outputIndex,
+ layer,
+ outputIndex,
+ model,
+ data,
+ outputInfo);
}
template<typename HalPolicy,