aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/test/LstmSerializationTests.cpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2022-04-21 11:57:09 +0100
committermike.kelly <mike.kelly@arm.com>2022-05-05 08:29:20 +0000
commit1299496996bc332f02218f926640a9255ed60310 (patch)
tree2d242e142bd8fe7387140bcf8cdf39cd13ffc9eb /src/armnnSerializer/test/LstmSerializationTests.cpp
parent8272a7bda2974c39b6c06e3eb3a000f2bdb749f7 (diff)
downloadarmnn-1299496996bc332f02218f926640a9255ed60310.tar.gz
IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon
* Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported * outputStateOut TensorInfo is not optional. * cellStateOut TensorInfo is not optional. * TensorInfo Order matches other QLSTM/LSTM layers. * Added missing parameters to UnidirectionalSequenceLstmOperator for delegate. * Added quantized UnidirectionalSequenceLstm support to Neon !android-nn-driver:7457 Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4
Diffstat (limited to 'src/armnnSerializer/test/LstmSerializationTests.cpp')
-rw-r--r--src/armnnSerializer/test/LstmSerializationTests.cpp40
1 files changed, 28 insertions, 12 deletions
diff --git a/src/armnnSerializer/test/LstmSerializationTests.cpp b/src/armnnSerializer/test/LstmSerializationTests.cpp
index d8f8967bcd..ae2d813fc0 100644
--- a/src/armnnSerializer/test/LstmSerializationTests.cpp
+++ b/src/armnnSerializer/test/LstmSerializationTests.cpp
@@ -2299,6 +2299,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32);
armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateOutTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateOutTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
@@ -2310,8 +2312,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
CHECK(deserializedNetwork);
@@ -2319,7 +2323,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
layerName,
{inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
- {outputTensorInfo},
+ {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo},
descriptor,
params);
deserializedNetwork->ExecuteStrategy(checker);
@@ -2436,6 +2440,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndPr
armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32);
armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateOutTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateOutTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
@@ -2447,8 +2453,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndPr
cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
CHECK(deserializedNetwork);
@@ -2456,7 +2464,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndPr
VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
layerName,
{inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
- {outputTensorInfo},
+ {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo},
descriptor,
params);
deserializedNetwork->ExecuteStrategy(checker);
@@ -2592,6 +2600,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithP
armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32);
armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32);
armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
@@ -2603,8 +2613,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithP
cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
CHECK(deserializedNetwork);
@@ -2612,7 +2624,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithP
VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
layerName,
{inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
- {outputTensorInfo},
+ {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo},
descriptor,
params);
deserializedNetwork->ExecuteStrategy(checker);
@@ -2697,6 +2709,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
armnn::TensorInfo inputTensorInfo({ timeSize, batchSize, inputSize }, armnn::DataType::Float32);
armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32);
armnn::TensorInfo outputTensorInfo({ timeSize, batchSize, outputSize }, armnn::DataType::Float32);
inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
@@ -2708,8 +2722,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
CHECK(deserializedNetwork);
@@ -2717,7 +2733,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
layerName,
{inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
- {outputTensorInfo},
+ {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo},
descriptor,
params);
deserializedNetwork->ExecuteStrategy(checker);