diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-09-03 10:57:43 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2020-09-03 11:23:20 +0000 |
commit | baa1f9f24eb8e5d5bc2011e2eee6278140846a9a (patch) | |
tree | 8b965d1ec2d96c6abe80a5a51312b7f79f82a407 | |
parent | dbda4b7eb52ba3c271b06cc5172a40bf84aafde8 (diff) | |
download | android-nn-driver-baa1f9f24eb8e5d5bc2011e2eee6278140846a9a.tar.gz |
IVGCVSW-5271 'Update ConvertQuantized16BitLstm function to use ShapeInferenceMethod'
* Enabled DynamicTensors on Quantized16BitLSTM operator.
!android-nn-driver:3897
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: Ic86c5af5a4d1b1d12fc6879dfb94fddd889b85de
-rw-r--r-- | ConversionUtils_1_2.hpp | 49 |
1 files changed, 37 insertions, 12 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp index 4e0ccb6e..760312e7 100644 --- a/ConversionUtils_1_2.hpp +++ b/ConversionUtils_1_2.hpp @@ -1879,16 +1879,31 @@ bool ConvertQuantized16BitLstm(const HalOperation& operation, const HalModel& mo paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo()); bool isSupported = false; - FORWARD_LAYER_SUPPORT_FUNC(__func__, - IsQuantizedLstmSupported, - data.m_Backends, - isSupported, - inputInfo, - previousCellStateInInfo, - previousOutputInInfo, - cellStateOutInfo, - outputInfo, - paramsInfo); + auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsQuantizedLstmSupported, + data.m_Backends, + isSupported, + inputInfo, + previousCellStateInInfo, + previousOutputInInfo, + cellStateOutInfo, + outputInfo, + paramsInfo); + }; + + bool isDynamic = false; + if (!IsDynamicTensor(cellStateOutInfo) && + !IsDynamicTensor(outputInfo)) + { + validateFunc(outputInfo, isSupported); + } + else + { + isDynamic = true; + isSupported = AreDynamicTensorsSupported(); + } if (!isSupported) { @@ -1900,8 +1915,18 @@ bool ConvertQuantized16BitLstm(const HalOperation& operation, const HalModel& mo previousCellStateIn.Connect(layer->GetInputSlot(1)); previousOutputIn.Connect(layer->GetInputSlot(2)); - return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) && - SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data)); + if (!isDynamic) + { + return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) && + SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data)); + } + else + { + return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) && + SetupAndTrackLayerOutputSlot<HalPolicy>( + operation, 1, *layer, 1, model, data, nullptr, validateFunc, true)); + } + } template<typename HalPolicy, |