diff options
author | Pablo Tello <pablo.tello@arm.com> | 2019-11-28 15:21:41 +0000 |
---|---|---|
committer | Pablo Tello <pablo.tello@arm.com> | 2019-12-13 17:34:31 +0000 |
commit | 972603fb38e2f5a75b2389bd8a265b6e3173a02f (patch) | |
tree | c006c89dd47e54ed8f41aad597e6a00fb4e178de | |
parent | 624fe9f02d4cf229cc7f33d9c970b39fcf560d1f (diff) | |
download | android-nn-driver-972603fb38e2f5a75b2389bd8a265b6e3173a02f.tar.gz |
MLCE-133: Driver infers scratchBuffer shape.
Change-Id: I7977d697772349b8ea7eb300937409ce0a3a4dee
Signed-off-by: Pablo Tello <pablo.tello@arm.com>
-rw-r--r-- | 1.2/HalPolicy.cpp | 64 |
1 file changed, 60 insertions, 4 deletions
```diff
diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp
index 60bbf1d5..ac78e96b 100644
--- a/1.2/HalPolicy.cpp
+++ b/1.2/HalPolicy.cpp
@@ -2101,6 +2101,36 @@ bool HalPolicy::ConvertTanH(const Operation& operation, const Model& model, Conv
     return ::ConvertTanH<hal_1_2::HalPolicy>(operation, model, data);
 }
 
+template<typename HalPolicy,
+         typename HalOperation = typename HalPolicy::Operation,
+         typename HalModel = typename HalPolicy::Model>
+bool SetupAndTrackLayerOutputSlotAndOverrideTensorInfo(const HalOperation& operation,
+                                                       uint32_t operationOutputIndex,
+                                                       armnn::IConnectableLayer& layer,
+                                                       uint32_t layerOutputIndex,
+                                                       const HalModel& model,
+                                                       ConversionData& data,
+                                                       const armnn::TensorInfo tensor_info)
+{
+    using HalOperand = typename HalPolicy::Operand;
+
+    const HalOperand* outputOperand = GetOutputOperand<HalPolicy>(operation, operationOutputIndex, model);
+    if ((outputOperand == nullptr) || (operationOutputIndex >= layer.GetNumOutputSlots()))
+    {
+        return false;
+    }
+
+    armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex);
+
+    const uint32_t operandIndex = operation.outputs[operationOutputIndex];
+    data.m_OutputSlotForOperand[operandIndex] = &outputSlot;
+
+    outputSlot.SetTensorInfo(tensor_info);
+
+    return true;
+}
+
+
 bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, ConversionData& data)
 {
     ALOGV("hal_1_2::HalPolicy::ConvertLstm()");
@@ -2399,8 +2429,28 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
     const TensorInfo& cellStateOutInfo = GetTensorInfoForOperand(*cellStateOut);
     const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
 
-    if (IsDynamicTensor(scratchBufferInfo) ||
-        IsDynamicTensor(outputStateOutInfo) ||
+    // Check if the scratch buffer shape was initialized,
+    // In some cases the shape could be (0,0) which requires the driver
+    // to infer the shape and set it up accordingly.
+    // The code below does that.
+    TensorInfo fixSbInfo = scratchBufferInfo;
+    if (IsDynamicTensor(scratchBufferInfo))
+    {
+        auto & s = fixSbInfo.GetShape();
+        s[0] = outputStateInInfo.GetShape()[0];
+        if (desc.m_CifgEnabled)
+        {
+            // 2D tensor with dimensions [num_units * 3, batch_size] with CIFG
+            s[1] = cellStateOutInfo.GetShape()[1]*3;
+        }
+        else
+        {
+            // scratch_buffer [num_units * 4, batch_size] without CIFG
+            s[1] = cellStateOutInfo.GetShape()[1]*4;
+        }
+    }
+
+    if (IsDynamicTensor(outputStateOutInfo) ||
         IsDynamicTensor(cellStateOutInfo) ||
         IsDynamicTensor(outputInfo))
     {
@@ -2467,7 +2517,7 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
                                inputInfo,
                                outputStateInInfo,
                                cellStateInInfo,
-                               scratchBufferInfo,
+                               fixSbInfo,
                                outputStateOutInfo,
                                cellStateOutInfo,
                                outputInfo,
@@ -2485,7 +2535,13 @@ bool HalPolicy::ConvertLstm(const Operation& operation, const Model& model, Conv
     outputStateIn.Connect(layer->GetInputSlot(1));
     cellStateIn.Connect(layer->GetInputSlot(2));
 
-    return (SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 0, *layer, 0, model, data) &&
+
+    return (
+        (IsDynamicTensor(scratchBufferInfo)?
+            SetupAndTrackLayerOutputSlotAndOverrideTensorInfo<hal_1_2::HalPolicy>(
+                operation, 0, *layer, 0, model, data,fixSbInfo):
+            SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(
+                operation, 0, *layer, 0, model, data)) &&
             SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 1, *layer, 1, model, data) &&
             SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 2, *layer, 2, model, data) &&
             SetupAndTrackLayerOutputSlot<hal_1_2::HalPolicy>(operation, 3, *layer, 3, model, data));
```