aboutsummaryrefslogtreecommitdiff
path: root/src/armnnSerializer/test/LstmSerializationTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/test/LstmSerializationTests.cpp')
-rw-r--r--src/armnnSerializer/test/LstmSerializationTests.cpp40
1 files changed, 28 insertions, 12 deletions
diff --git a/src/armnnSerializer/test/LstmSerializationTests.cpp b/src/armnnSerializer/test/LstmSerializationTests.cpp
index d8f8967bcd..ae2d813fc0 100644
--- a/src/armnnSerializer/test/LstmSerializationTests.cpp
+++ b/src/armnnSerializer/test/LstmSerializationTests.cpp
@@ -2299,6 +2299,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32);
armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateOutTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateOutTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
@@ -2310,8 +2312,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
CHECK(deserializedNetwork);
@@ -2319,7 +2323,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
layerName,
{inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
- {outputTensorInfo},
+ {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo},
descriptor,
params);
deserializedNetwork->ExecuteStrategy(checker);
@@ -2436,6 +2440,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndPr
armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32);
armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateOutTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateOutTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
@@ -2447,8 +2453,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndPr
cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
CHECK(deserializedNetwork);
@@ -2456,7 +2464,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndPr
VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
layerName,
{inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
- {outputTensorInfo},
+ {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo},
descriptor,
params);
deserializedNetwork->ExecuteStrategy(checker);
@@ -2592,6 +2600,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithP
armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32);
armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32);
armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32);
inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
@@ -2603,8 +2613,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithP
cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
CHECK(deserializedNetwork);
@@ -2612,7 +2624,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithP
VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
layerName,
{inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
- {outputTensorInfo},
+ {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo},
descriptor,
params);
deserializedNetwork->ExecuteStrategy(checker);
@@ -2697,6 +2709,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
armnn::TensorInfo inputTensorInfo({ timeSize, batchSize, inputSize }, armnn::DataType::Float32);
armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32);
armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32);
+ armnn::TensorInfo outputStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32);
+ armnn::TensorInfo cellStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32);
armnn::TensorInfo outputTensorInfo({ timeSize, batchSize, outputSize }, armnn::DataType::Float32);
inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0));
@@ -2708,8 +2722,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2));
cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo);
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
- unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0));
+ unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
CHECK(deserializedNetwork);
@@ -2717,7 +2733,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio
VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker(
layerName,
{inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo},
- {outputTensorInfo},
+ {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo},
descriptor,
params);
deserializedNetwork->ExecuteStrategy(checker);