diff options
Diffstat (limited to 'src/armnn/layers/LstmLayer.cpp')
-rw-r--r-- | src/armnn/layers/LstmLayer.cpp | 9 |
1 files changed, 1 insertions, 8 deletions
diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index 866c837357..bd104d49fe 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -123,14 +123,7 @@ std::vector<TensorShape> LstmLayer::InferOutputShapes(const std::vector<TensorSh unsigned int numUnits = inputShapes[2][1]; std::vector<TensorShape> outShapes; - if (!m_Param.m_CifgEnabled) - { - outShapes.push_back(TensorShape({batchSize, numUnits*3})); - } - else - { - outShapes.push_back(TensorShape({batchSize, numUnits*4})); - } + outShapes.push_back(TensorShape({batchSize, numUnits * (m_Param.m_CifgEnabled ? 3 : 4)})); outShapes.push_back(TensorShape({batchSize, outputSize})); outShapes.push_back(TensorShape({batchSize, numUnits})); outShapes.push_back(TensorShape({batchSize, outputSize})); |