diff options
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r-- | ConversionUtils.hpp | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 8313d045..5a111317 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -877,35 +877,40 @@ bool GetInputScalar(const HalOperation& operation, HalOperandType type, OutputType& outValue, const HalModel& model, - const ConversionData& data) + const ConversionData& data, + bool optional = false) { using HalOperand = typename HalPolicy::Operand; const HalOperand* operand = GetInputOperand<HalPolicy>(operation, inputIndex, model); - if (!operand) + if (!optional && !operand) { return Fail("%s: invalid input operand at index %i", __func__, inputIndex); } - if (operand->type != type) + if (!optional && operand->type != type) { return Fail("%s: unexpected operand type: %s (should be %s)", __func__, toString(operand->type).c_str(), toString(type).c_str()); } - if (operand->location.length != sizeof(OutputType)) + if (!optional && operand->location.length != sizeof(OutputType)) { return Fail("%s: incorrect operand location length: %i (should be %i)", __func__, operand->location.length, sizeof(OutputType)); } const void* valueAddress = GetOperandValueReadOnlyAddress<HalPolicy>(*operand, model, data); - if (!valueAddress) + if (!optional && !valueAddress) { return Fail("%s: failed to get address for operand", __func__); } - outValue = *(static_cast<const OutputType*>(valueAddress)); + if(!optional) + { + outValue = *(static_cast<const OutputType*>(valueAddress)); + } + return true; } @@ -1374,7 +1379,8 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation, armnn::IConnectableLayer& layer, uint32_t layerOutputIndex, const HalModel& model, - ConversionData& data) + ConversionData& data, + const armnn::TensorInfo* overrideOutputInfo = nullptr) { using HalOperand = typename HalPolicy::Operand; @@ -1389,7 +1395,14 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation, const uint32_t operandIndex = operation.outputs[operationOutputIndex]; data.m_OutputSlotForOperand[operandIndex] = &outputSlot; - outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand)); + if (overrideOutputInfo == nullptr) + { + outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand)); + } + else + { + outputSlot.SetTensorInfo(*overrideOutputInfo); + } return true; } |