diff options
-rw-r--r-- | 1.2/HalPolicy.cpp | 64 |
1 files changed, 60 insertions, 4 deletions
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index 60bbf1d5..ac78e96b 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -2101,6 +2101,36 @@ bool HalPolicy::ConvertTanH(const Operation& operation, const Model& model, Conv return ::ConvertTanH<hal_1_2::HalPolicy>(operation, model, data); } +template<typename HalPolicy, + typename HalOperation = typename HalPolicy::Operation, + typename HalModel = typename HalPolicy::Model> +bool SetupAndTrackLayerOutputSlotAndOverrideTensorInfo(const HalOperation& operation, + uint32_t operationOutputIndex, + armnn::IConnectableLayer& layer, + uint32_t layerOutputIndex, + const HalModel& model, + ConversionData& data, + const armnn::TensorInfo tensor_info) +{ + using HalOperand = typename HalPolicy::Operand; + + const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, operationOutputIndex, model); + if ((outputOperand == nullptr) || (operationOutputIndex >= layer.GetNumOutputSlots())) + { + return false; + } + + armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex); + + const uint32_t operandIndex = operation.outputs[operationOutputIndex]; + data.m_OutputSlotForOperand[operandIndex] = &outputSlot; + + outputSlot.SetTensorInfo(tensor_info); + + return true; +} + + bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, ConversionData& data) { ALOGV("hal_1_2::HalPolicy::ConvertLstm()"); @@ -2399,8 +2429,28 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut); const TensorInfo& outputInfo = GetTensorInfoForOperand(*output); - if (IsDynamicTensor(scratchBufferInfo) || - IsDynamicTensor(outputStateOutInfo) || + // Check if the scratch buffer shape was initialized, + // In some cases the shape could be (0,0) which requires the driver + // to infer the shape and set it up accordingly. + // The code below does that. + TensorInfo fixSbInfo = scratchBufferInfo; + if (IsDynamicTensor(scratchBufferInfo)) + { + auto & s = fixSbInfo.GetShape(); + s[0] = outputStateInInfo.GetShape()[0]; + if (desc.m_CifgEnabled) + { + // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG + s[1] = cellStateOutInfo.GetShape()[1]*3; + } + else + { + // scratch_buffer [num_units * 4, batch_size] without CIFG + s[1] = cellStateOutInfo.GetShape()[1]*4; + } + } + + if (IsDynamicTensor(outputStateOutInfo) || IsDynamicTensor(cellStateOutInfo) || IsDynamicTensor(outputInfo)) { @@ -2467,7 +2517,7 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv inputInfo, outputStateInInfo, cellStateInInfo, - scratchBufferInfo, + fixSbInfo, outputStateOutInfo, cellStateOutInfo, outputInfo, @@ -2485,7 +2535,13 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv outputStateIn.Connect(layer->GetInputSlot(1)); cellStateIn.Connect(layer->GetInputSlot(2)); - return (SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, 0, model, data) && + + return ( + (IsDynamicTensor(scratchBufferInfo)? + SetupAndTrackLayerOutputSlotAndOverrideTensorInfo<hal_1_2::HalPolicy>( + operation, 0, *layer, 0, model, data,fixSbInfo): + SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>( + operation, 0, *layer, 0, model, data)) && SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 1, *layer, 1, model, data) && SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 2, *layer, 2, model, data) && SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 3, *layer, 3, model, data)); |