aboutsummaryrefslogtreecommitdiff
path: root/ConversionUtils_1_2.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'ConversionUtils_1_2.hpp')
-rw-r--r--ConversionUtils_1_2.hpp25
1 files changed, 17 insertions, 8 deletions
diff --git a/ConversionUtils_1_2.hpp b/ConversionUtils_1_2.hpp
index 1bcd9f2e..0ff50cff 100644
--- a/ConversionUtils_1_2.hpp
+++ b/ConversionUtils_1_2.hpp
@@ -3455,6 +3455,18 @@ bool ConvertUnidirectionalSequenceLstm(const HalOperation& operation,
// Outputs
const TensorInfo& outputInfo = GetTensorInfoForOperand(*output);
+ unsigned int batchSize = inputInfo.GetShape()[0];
+ unsigned int outputSize = outputInfo.GetShape()[2];
+ unsigned int numUnits = cellStateInInfo.GetShape()[1];
+
+ armnn::DataType dataType = inputInfo.GetDataType();
+ float qScale = inputInfo.GetQuantizationScale();
+ int qOffset = inputInfo.GetQuantizationOffset();
+
+ armnn::TensorInfo cellStateOutInfo({batchSize, numUnits}, cellStateInInfo.GetDataType(),
+ cellStateInInfo.GetQuantizationScale(), cellStateInInfo.GetQuantizationOffset());
+ armnn::TensorInfo outputStateOutInfo({batchSize, outputSize}, dataType, qScale, qOffset);
+
// Basic parameters
LstmInputParamsInfo paramsInfo;
paramsInfo.m_InputToForgetWeights = &(params.m_InputToForgetWeights->GetInfo());
@@ -3505,9 +3517,6 @@ bool ConvertUnidirectionalSequenceLstm(const HalOperation& operation,
paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
}
- auto hiddenStateOutInfo = EmptyOptional();
- auto cellStateOutInfo = EmptyOptional();
-
bool isSupported = false;
auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
{
@@ -3518,9 +3527,9 @@ bool ConvertUnidirectionalSequenceLstm(const HalOperation& operation,
inputInfo,
outputStateInInfo,
cellStateInInfo,
- outputInfo,
- hiddenStateOutInfo,
+ outputStateOutInfo,
cellStateOutInfo,
+ outputInfo,
desc,
paramsInfo);
};
@@ -3552,12 +3561,12 @@ bool ConvertUnidirectionalSequenceLstm(const HalOperation& operation,
if (!isDynamic)
{
- return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 0, model, data));
+ return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 2, model, data));
}
else
{
- return (SetupAndTrackLayerOutputSlot<HalPolicy>(
- operation, 0, *layer, 0, model, data, nullptr, validateFunc, ActivationFn::kActivationNone, true));
+ return (SetupAndTrackLayerOutputSlot<HalPolicy>(operation, 0, *layer, 2, model, data, nullptr,
+ validateFunc, ActivationFn::kActivationNone, true));
}
}