diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2020-09-03 11:33:07 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2020-09-03 11:33:07 +0100 |
commit | dbda4b7eb52ba3c271b06cc5172a40bf84aafde8 (patch) | |
tree | e195783e782222d0a62ccc468cd1b4b6181ccb1e /ConversionUtils.hpp | |
parent | 346e8119e51cd9a3fbb56eb6fa85ad48a256c8eb (diff) | |
download | android-nn-driver-dbda4b7eb52ba3c271b06cc5172a40bf84aafde8.tar.gz |
IVGCVSW-5272 'Update ConvertLstm function to use ShapeInferenceMethod'
* Enabled Dynamic Tensors on LSTM operator
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Change-Id: I9cae539559570a44088a986870d3d3e41aee9468
Diffstat (limited to 'ConversionUtils.hpp')
-rw-r--r-- | ConversionUtils.hpp | 40 |
1 files changed, 19 insertions, 21 deletions
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index cdb57d1f..fe8e026e 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -1399,7 +1399,8 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation, const HalModel& model, ConversionData& data, const armnn::TensorInfo* overrideOutputInfo = nullptr, - const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr) + const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr, + bool inferOutputShapes = false) { using HalOperand = typename HalPolicy::Operand; @@ -1410,7 +1411,6 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation, } armnn::IOutputSlot& outputSlot = layer.GetOutputSlot(layerOutputIndex); - if (overrideOutputInfo == nullptr) { outputSlot.SetTensorInfo(GetTensorInfoForOperand(*outputOperand)); @@ -1420,32 +1420,30 @@ bool SetupAndTrackLayerOutputSlot(const HalOperation& operation, outputSlot.SetTensorInfo(*overrideOutputInfo); } - // Type one dynamic tensors require the previous layer's output shape for inference - if (!layer.GetInputSlot(0).GetConnection() && - IsDynamicTensor(outputSlot.GetTensorInfo())) - { - return false; - } - bool isSupported = false; - if (validateFunc && - layer.GetInputSlot(0).GetConnection() && - IsDynamicTensor(outputSlot.GetTensorInfo())) + if (validateFunc && (IsDynamicTensor(outputSlot.GetTensorInfo()) || inferOutputShapes)) { + // Type one dynamic tensors require the previous layer's output shape for inference + for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex) + { + if(!layer.GetInputSlot(inputSlotIndex).GetConnection()) + { + return false; + } + } // IsTensorInfoSet will infer the dynamic output shape outputSlot.IsTensorInfoSet(); // Once the shape is inferred we can validate it validateFunc(outputSlot.GetTensorInfo(), isSupported); - if(!isSupported) - { - for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex) - { - layer.GetInputSlot(inputSlotIndex).GetConnection()->Disconnect(layer.GetInputSlot(inputSlotIndex)); - } - - return false; - } + if(!isSupported) + { + for (unsigned int inputSlotIndex = 0; inputSlotIndex < layer.GetNumInputSlots(); ++inputSlotIndex) + { + layer.GetInputSlot(inputSlotIndex).GetConnection()->Disconnect(layer.GetInputSlot(inputSlotIndex)); + } + return false; + } } const uint32_t operandIndex = operation.outputs[operationOutputIndex]; |