From 34db1872566a1737fd94305d0b3f3e7741d99b60 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Thu, 3 Sep 2020 15:22:29 +0100 Subject: IVGCVSW-5274 'Update ConvertQuantizedLstm function to use ShapeInferenceMethod' * Enabled Dynamic Tensors in QUANTIZED_LSTM operator. !android-nn-driver:3897 Signed-off-by: Sadik Armagan Change-Id: I415014d19729aac255479099e372e5ff1a6dd3e2 --- ConversionUtils_1_3.hpp | 66 +++++++++++++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 24 deletions(-) (limited to 'ConversionUtils_1_3.hpp') diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp index e6961253..445b9ea7 100644 --- a/ConversionUtils_1_3.hpp +++ b/ConversionUtils_1_3.hpp @@ -600,29 +600,36 @@ bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, } // Check if the layer is supported - - if (IsDynamicTensor(constOutputStateOutInfo) || - IsDynamicTensor(cellStateOutInfo) || - IsDynamicTensor(constOutputInfo)) + bool isSupported = false; + auto validateFunc = [&](const armnn::TensorInfo& cellStateOutInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsQLstmSupported, + data.m_Backends, + isSupported, + inputInfo, + outputStatePrevTimeStepInfo, + cellStatePrevTimeStepInfo, + constOutputStateOutInfo, + cellStateOutInfo, + constOutputInfo, + desc, + paramsInfo); + }; + + bool isDynamic = false; + if (!IsDynamicTensor(constOutputStateOutInfo) && + !IsDynamicTensor(cellStateOutInfo) && + !IsDynamicTensor(constOutputInfo)) + { + validateFunc(outputInfo, isSupported); + } + else { - return Fail("%s: Dynamic output tensors are not supported %d %d %d %d", __func__, - IsDynamicTensor(constOutputStateOutInfo), IsDynamicTensor(cellStateOutInfo), - IsDynamicTensor(constOutputInfo)); + isDynamic = true; + isSupported = AreDynamicTensorsSupported(); } - bool isSupported = false; - FORWARD_LAYER_SUPPORT_FUNC(__func__, - IsQLstmSupported, - data.m_Backends, - isSupported, - inputInfo, - outputStatePrevTimeStepInfo, - cellStatePrevTimeStepInfo, - constOutputStateOutInfo, - cellStateOutInfo, - constOutputInfo, - desc, - paramsInfo); if (!isSupported) { return false; @@ -635,10 +642,21 @@ bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, outputStatePrevTimeStep.Connect(layer->GetInputSlot(1)); cellStatePrevTimeStep.Connect(layer->GetInputSlot(2)); - return ( SetupAndTrackLayerOutputSlot(operation, 0, *layer, 0, model, data, - &constOutputStateOutInfo) && - SetupAndTrackLayerOutputSlot(operation, 1, *layer, 1, model, data) && - SetupAndTrackLayerOutputSlot(operation, 2, *layer, 2, model, data, &constOutputInfo)); + if (!isDynamic) + { + return ( SetupAndTrackLayerOutputSlot( + operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) && + SetupAndTrackLayerOutputSlot(operation, 1, *layer, 1, model, data) && + SetupAndTrackLayerOutputSlot(operation, 2, *layer, 2, model, data, &constOutputInfo)); + } + else + { + return ( SetupAndTrackLayerOutputSlot( + operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) && + SetupAndTrackLayerOutputSlot( + operation, 1, *layer, 1, model, data, nullptr, validateFunc, true) && + SetupAndTrackLayerOutputSlot(operation, 2, *layer, 2, model, data, &constOutputInfo)); + } } template