diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-04-25 16:18:57 +0100 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-05-04 15:45:44 +0000 |
commit | 0ae102a0fc9d66b7067cf8d7a0ed1af5ed65ae50 (patch) | |
tree | 106e684da730f0eb9b3f29922af2244c26f2e528 /ConversionUtils_1_2.hpp | |
parent | be9d99e354b54ecf50578478203a86efa5442789 (diff) | |
download | android-nn-driver-0ae102a0fc9d66b7067cf8d7a0ed1af5ed65ae50.tar.gz |
IVGCVSW-6806 Fixed issue with missing TensorInfos in UnidirectionalSequenceLSTM
* Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported
* outputStateOut TensorInfo is not optional.
* cellStateOut TensorInfo is not optional.
* TensorInfo Order matches other QLSTM/LSTM layers.
!armnn:7455
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I5b0e0fa4b6e1c3da6689d9aefc9b959172c2e7d4
Diffstat (limited to 'ConversionUtils_1_2.hpp')
-rw-r--r-- | ConversionUtils_1_2.hpp | 25 |
1 files changed, 17 insertions, 8 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp index 1bcd9f2e..0ff50cff 100644 --- a/ConversionUtils_1_2.hpp +++ b/ConversionUtils_1_2.hpp @@ -3455,6 +3455,18 @@ bool ConvertUnidirectionalSequenceLstm(const HalOperation& operation, // Outputs const TensorInfo& outputInfo = GetTensorInfoForOperand(*output); + unsigned int batchSize = inputInfo.GetShape()[0]; + unsigned int outputSize = outputInfo.GetShape()[2]; + unsigned int numUnits = cellStateInInfo.GetShape()[1]; + + armnn::DataType dataType = inputInfo.GetDataType(); + float qScale = inputInfo.GetQuantizationScale(); + int qOffset = inputInfo.GetQuantizationOffset(); + + armnn::TensorInfo cellStateOutInfo({batchSize, numUnits}, cellStateInInfo.GetDataType(), + cellStateInInfo.GetQuantizationScale(), cellStateInInfo.GetQuantizationOffset()); + armnn::TensorInfo outputStateOutInfo({batchSize, outputSize}, dataType, qScale, qOffset); + // Basic parameters LstmInputParamsInfo paramsInfo; paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo()); @@ -3505,9 +3517,6 @@ bool ConvertUnidirectionalSequenceLstm(const HalOperation& operation, paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo()); } - auto hiddenStateOutInfo = EmptyOptional(); - auto cellStateOutInfo = EmptyOptional(); - bool isSupported = false; auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported) { @@ -3518,9 +3527,9 @@ bool ConvertUnidirectionalSequenceLstm(const HalOperation& operation, inputInfo, outputStateInInfo, cellStateInInfo, - outputInfo, - hiddenStateOutInfo, + outputStateOutInfo, cellStateOutInfo, + outputInfo, desc, paramsInfo); }; @@ -3552,12 +3561,12 @@ bool ConvertUnidirectionalSequenceLstm(const HalOperation& operation, if (!isDynamic) { - return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data)); + return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 2, model, data)); } else { - return (SetupAndTrackLayerOutputSlot<HalPolicy>( - operation, 0, *layer, 0, model, data, nullptr, validateFunc, ActivationFn::kActivationNone, true)); + return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 2, model, data, nullptr, + validateFunc, ActivationFn::kActivationNone, true)); } } |