diff options
Diffstat (limited to 'ConversionUtils_1_3.hpp')
-rw-r--r-- | ConversionUtils_1_3.hpp | 66 |
1 files changed, 42 insertions, 24 deletions
diff --git a/ConversionUtils_1_3.hpp b/ConversionUtils_1_3.hpp index e6961253..445b9ea7 100644 --- a/ConversionUtils_1_3.hpp +++ b/ConversionUtils_1_3.hpp @@ -600,29 +600,36 @@ bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, } // Check if the layer is supported - - if (IsDynamicTensor(constOutputStateOutInfo) || - IsDynamicTensor(cellStateOutInfo) || - IsDynamicTensor(constOutputInfo)) + bool isSupported = false; + auto validateFunc = [&](const armnn::TensorInfo& cellStateOutInfo, bool& isSupported) + { + FORWARD_LAYER_SUPPORT_FUNC(__func__, + IsQLstmSupported, + data.m_Backends, + isSupported, + inputInfo, + outputStatePrevTimeStepInfo, + cellStatePrevTimeStepInfo, + constOutputStateOutInfo, + cellStateOutInfo, + constOutputInfo, + desc, + paramsInfo); + }; + + bool isDynamic = false; + if (!IsDynamicTensor(constOutputStateOutInfo) && + !IsDynamicTensor(cellStateOutInfo) && + !IsDynamicTensor(constOutputInfo)) + { + validateFunc(outputInfo, isSupported); + } + else { - return Fail("%s: Dynamic output tensors are not supported %d %d %d %d", __func__, - IsDynamicTensor(constOutputStateOutInfo), IsDynamicTensor(cellStateOutInfo), - IsDynamicTensor(constOutputInfo)); + isDynamic = true; + isSupported = AreDynamicTensorsSupported(); } - bool isSupported = false; - FORWARD_LAYER_SUPPORT_FUNC(__func__, - IsQLstmSupported, - data.m_Backends, - isSupported, - inputInfo, - outputStatePrevTimeStepInfo, - cellStatePrevTimeStepInfo, - constOutputStateOutInfo, - cellStateOutInfo, - constOutputInfo, - desc, - paramsInfo); if (!isSupported) { return false; @@ -635,10 +642,21 @@ bool ConvertQuantizedLstm(const HalOperation& operation, const HalModel& model, outputStatePrevTimeStep.Connect(layer->GetInputSlot(1)); cellStatePrevTimeStep.Connect(layer->GetInputSlot(2)); - return ( SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data, - &constOutputStateOutInfo) && - SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) && - SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo)); + if (!isDynamic) + { + return ( SetupAndTrackLayerOutputSlot<HalPolicy>( + operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) && + SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data) && + SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo)); + } + else + { + return ( SetupAndTrackLayerOutputSlot<HalPolicy>( + operation, 0, *layer, 0, model, data, &constOutputStateOutInfo) && + SetupAndTrackLayerOutputSlot<HalPolicy>( + operation, 1, *layer, 1, model, data, nullptr, validateFunc, true) && + SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 2, *layer, 2, model, data, &constOutputInfo)); + } } template<typename HalPolicy, |