aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2020-09-03 10:57:43 +0100
committerSadik Armagan <sadik.armagan@arm.com>2020-09-03 11:23:20 +0000
commitbaa1f9f24eb8e5d5bc2011e2eee6278140846a9a (patch)
tree8b965d1ec2d96c6abe80a5a51312b7f79f82a407
parentdbda4b7eb52ba3c271b06cc5172a40bf84aafde8 (diff)
downloadandroid-nn-driver-baa1f9f24eb8e5d5bc2011e2eee6278140846a9a.tar.gz
IVGCVSW-5271 'Update ConvertQuantized16BitLstm function to use ShapeInferenceMethod'
* Enabled DynamicTensors on Quantized16BitLSTM operator. !android-nn-driver:3897 Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: Ic86c5af5a4d1b1d12fc6879dfb94fddd889b85de
-rw-r--r--ConversionUtils_1_2.hpp49
1 files changed, 37 insertions, 12 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp
index 4e0ccb6e..760312e7 100644
--- a/ConversionUtils_1_2.hpp
+++ b/ConversionUtils_1_2.hpp
@@ -1879,16 +1879,31 @@ bool ConvertQuantized16BitLstm(const HalOperation& operation, const HalModel& mo
paramsInfo.m_OutputGateBias = &(params.m_OutputGateBias->GetInfo());
bool isSupported = false;
- FORWARD_LAYER_SUPPORT_FUNC(__func__,
- IsQuantizedLstmSupported,
- data.m_Backends,
- isSupported,
- inputInfo,
- previousCellStateInInfo,
- previousOutputInInfo,
- cellStateOutInfo,
- outputInfo,
- paramsInfo);
+ auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
+ {
+ FORWARD_LAYER_SUPPORT_FUNC(__func__,
+ IsQuantizedLstmSupported,
+ data.m_Backends,
+ isSupported,
+ inputInfo,
+ previousCellStateInInfo,
+ previousOutputInInfo,
+ cellStateOutInfo,
+ outputInfo,
+ paramsInfo);
+ };
+
+ bool isDynamic = false;
+ if (!IsDynamicTensor(cellStateOutInfo) &&
+ !IsDynamicTensor(outputInfo))
+ {
+ validateFunc(outputInfo, isSupported);
+ }
+ else
+ {
+ isDynamic = true;
+ isSupported = AreDynamicTensorsSupported();
+ }
if (!isSupported)
{
@@ -1900,8 +1915,18 @@ bool ConvertQuantized16BitLstm(const HalOperation& operation, const HalModel& mo
previousCellStateIn.Connect(layer->GetInputSlot(1));
previousOutputIn.Connect(layer->GetInputSlot(2));
- return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) &&
- SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data));
+ if (!isDynamic)
+ {
+ return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) &&
+ SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 1, *layer, 1, model, data));
+ }
+ else
+ {
+ return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data) &&
+ SetupAndTrackLayerOutputSlot<HalPolicy>(
+ operation, 1, *layer, 1, model, data, nullptr, validateFunc, true));
+ }
+
}
template<typename HalPolicy,