diff options
author | Mike Kelly <mike.kelly@arm.com> | 2022-04-21 11:57:09 +0100 |
---|---|---|
committer | mike.kelly <mike.kelly@arm.com> | 2022-05-05 08:29:20 +0000 |
commit | 1299496996bc332f02218f926640a9255ed60310 (patch) | |
tree | 2d242e142bd8fe7387140bcf8cdf39cd13ffc9eb /src/armnnSerializer/test | |
parent | 8272a7bda2974c39b6c06e3eb3a000f2bdb749f7 (diff) | |
download | armnn-1299496996bc332f02218f926640a9255ed60310.tar.gz |
IVGCVSW-6806 Add Unidirectional Sequence Lstm support to Neon
* Corrected TensorInfo order for IsUnidirectionalSequenceLstmSupported
* outputStateOut TensorInfo is not optional.
* cellStateOut TensorInfo is not optional.
* TensorInfo Order matches other QLSTM/LSTM layers.
* Added missing parameters to UnidirectionalSequenceLstmOperator for
delegate.
* Added quantized UnidirectionalSequenceLstm support to Neon
!android-nn-driver:7457
Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I26dde1bb96793dd25eb9081ca5ae5f63752288c4
Diffstat (limited to 'src/armnnSerializer/test')
-rw-r--r-- | src/armnnSerializer/test/LstmSerializationTests.cpp | 40 |
1 files changed, 28 insertions, 12 deletions
diff --git a/src/armnnSerializer/test/LstmSerializationTests.cpp b/src/armnnSerializer/test/LstmSerializationTests.cpp index d8f8967bcd..ae2d813fc0 100644 --- a/src/armnnSerializer/test/LstmSerializationTests.cpp +++ b/src/armnnSerializer/test/LstmSerializationTests.cpp @@ -2299,6 +2299,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32); armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputStateOutTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateOutTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); @@ -2310,8 +2312,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); - unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); - unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); CHECK(deserializedNetwork); @@ -2319,7 +2323,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker( layerName, {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, - {outputTensorInfo}, + {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo}, descriptor, params); deserializedNetwork->ExecuteStrategy(checker); @@ -2436,6 +2440,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndPr armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32); armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputStateOutTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo cellStateOutTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); @@ -2447,8 +2453,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndPr cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); - unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); - unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); CHECK(deserializedNetwork); @@ -2456,7 +2464,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeAndPr VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker( layerName, {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, - {outputTensorInfo}, + {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo}, descriptor, params); deserializedNetwork->ExecuteStrategy(checker); @@ -2592,6 +2600,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithP armnn::TensorInfo inputTensorInfo({ batchSize, timeSize, inputSize }, armnn::DataType::Float32); armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32); + armnn::TensorInfo cellStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32); armnn::TensorInfo outputTensorInfo({ batchSize, timeSize, outputSize }, armnn::DataType::Float32); inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); @@ -2603,8 +2613,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithP cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); - unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); - unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); CHECK(deserializedNetwork); @@ -2612,7 +2624,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmNoCifgWithPeepholeWithP VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker( layerName, {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, - {outputTensorInfo}, + {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo}, descriptor, params); deserializedNetwork->ExecuteStrategy(checker); @@ -2697,6 +2709,8 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio armnn::TensorInfo inputTensorInfo({ timeSize, batchSize, inputSize }, armnn::DataType::Float32); armnn::TensorInfo cellStateTensorInfo({ batchSize, numUnits}, armnn::DataType::Float32); armnn::TensorInfo outputStateTensorInfo({ batchSize, outputSize }, armnn::DataType::Float32); + armnn::TensorInfo outputStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32); + armnn::TensorInfo cellStateOutTensorInfo({batchSize, timeSize, outputSize}, armnn::DataType::Float32); armnn::TensorInfo outputTensorInfo({ timeSize, batchSize, outputSize }, armnn::DataType::Float32); inputLayer->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(0)); @@ -2708,8 +2722,10 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio cellStateIn->GetOutputSlot(0).Connect(unidirectionalSequenceLstmLayer->GetInputSlot(2)); cellStateIn->GetOutputSlot(0).SetTensorInfo(cellStateTensorInfo); - unidirectionalSequenceLstmLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); - unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo); + unidirectionalSequenceLstmLayer->GetOutputSlot(2).Connect(outputLayer->GetInputSlot(0)); + unidirectionalSequenceLstmLayer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo); armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); CHECK(deserializedNetwork); @@ -2717,7 +2733,7 @@ TEST_CASE("SerializeDeserializeUnidirectionalSequenceLstmCifgPeepholeNoProjectio VerifyLstmLayer<armnn::UnidirectionalSequenceLstmDescriptor> checker( layerName, {inputTensorInfo, outputStateTensorInfo, cellStateTensorInfo}, - {outputTensorInfo}, + {outputStateOutTensorInfo, cellStateOutTensorInfo, outputTensorInfo}, descriptor, params); deserializedNetwork->ExecuteStrategy(checker); |