diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-05-19 14:10:30 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2020-05-19 14:10:30 +0100 |
commit | 813f23049d73177edfc1f1cff71147c39f4b695e (patch) | |
tree | cef15224e1e7a4bd8117bdc9f57a71be84e4325d /ConversionUtils.hpp | |
parent | abc95d04dfb2462ffb42bc1facde4f45ecc65319 (diff) | |
download | android-nn-driver-813f23049d73177edfc1f1cff71147c39f4b695e.tar.gz |
IVGCVSW-4453 Add Support for ANEURALNETWORKS_QLSTM to HAL 1.3 Driver
* Add QLSTM support for Android NN Driver
* Add overrideOutputInfo parameter to SetupAndTrackLayerOutputSlot
* Add optional condition to GetInputScalar
* Refactor Quantized 16 Bit LSTM impl
Change-Id: Ie8fa98ad5ee4a62174ef91ca80f1df62b7fde937
Signed-off-by: Keith Davis <keith.davis@arm.com>
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r-- | ConversionUtils.hpp | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 8313d045..5a111317 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -877,35 +877,40 @@ bool GetInputScalar(const HalOperation& operation, HalOperandType type, OutputType& outValue, const HalModel& model, - const ConversionData& data) + const ConversionData& data, + bool optional = false) { using HalOperand = typename HalPolicy::Operand; const HalOperand* operand = GetInputOperand<HalPolicy>(operation, inputIndex, model); - if (!operand) + if (!optional && !operand) { return Fail("%s: invalid input operand at index %i", __func__, inputIndex); } - if (operand->type != type) + if (!optional && operand->type != type) { return Fail("%s: unexpected operand type: %s (should be %s)", __func__, toString(operand->type).c_str(), toString(type).c_str()); } - if (operand->location.length != sizeof(OutputType)) + if (!optional && operand->location.length != sizeof(OutputType)) { return Fail("%s: incorrect operand location length: %i (should be %i)", __func__, operand->location.length, sizeof(OutputType)); } const void* valueAddress = GetOperandValueReadOnlyAddress<HalPolicy>(*operand, model, data); - if (!valueAddress) + if (!optional && !valueAddress) { return Fail("%s: failed to get address for operand", __func__); } - outValue = *(static_cast<const OutputType*>(valueAddress)); + if(!optional) + { + outValue = *(static_cast<const OutputType*>(valueAddress)); + } + return true; } @@ -1374,7 +1379,8 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation, armnn::IConnectableLayer& layer, uint32_t layerOutputIndex, const HalModel& model, - ConversionData& data) + ConversionData& data, + const armnn::TensorInfo* overrideOutputInfo = nullptr) { using HalOperand = typename HalPolicy::Operand; @@ -1389,7 +1395,14 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation, const uint32_t operandIndex = operation.outputs[operationOutputIndex]; data.m_OutputSlotForOperand[operandIndex] = &outputSlot; - outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand)); + if (overrideOutputInfo == nullptr) + { + outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand)); + } + else + { + outputSlot.SetTensorInfo(*overrideOutputInfo); + } return true; } |