about summary refs log tree commit diff
path: root/src/armnnSerializer/test/LstmSerializationTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnSerializer/test/LstmSerializationTests.cpp')
-rw-r--r-- src/armnnSerializer/test/LstmSerializationTests.cpp | 72
1 file changed, 48 insertions(+), 24 deletions(-)
diff --git a/src/armnnSerializer/test/LstmSerializationTests.cpp b/src/armnnSerializer/test/LstmSerializationTests.cpp
index 3178bc990e..d8f8967bcd 100644
--- a/src/armnnSerializer/test/LstmSerializationTests.cpp
+++ b/src/armnnSerializer/test/LstmSerializationTests.cpp
@@ -1454,7 +1454,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo inputToInputWeightsInfo(inputToInputWeightsShape,
armnn::DataType::QAsymmU8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::ConstTensor inputToInputWeights(inputToInputWeightsInfo, inputToInputWeightsData);
armnn::TensorShape inputToForgetWeightsShape = {4, 2};
@@ -1462,7 +1463,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo inputToForgetWeightsInfo(inputToForgetWeightsShape,
armnn::DataType::QAsymmU8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::ConstTensor inputToForgetWeights(inputToForgetWeightsInfo, inputToForgetWeightsData);
armnn::TensorShape inputToCellWeightsShape = {4, 2};
@@ -1470,7 +1472,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo inputToCellWeightsInfo(inputToCellWeightsShape,
armnn::DataType::QAsymmU8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::ConstTensor inputToCellWeights(inputToCellWeightsInfo, inputToCellWeightsData);
armnn::TensorShape inputToOutputWeightsShape = {4, 2};
@@ -1478,7 +1481,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo inputToOutputWeightsInfo(inputToOutputWeightsShape,
armnn::DataType::QAsymmU8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::ConstTensor inputToOutputWeights(inputToOutputWeightsInfo, inputToOutputWeightsData);
// The shape of recurrent weight data is {outputSize, outputSize} = {4, 4}
@@ -1487,7 +1491,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo recurrentToInputWeightsInfo(recurrentToInputWeightsShape,
armnn::DataType::QAsymmU8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::ConstTensor recurrentToInputWeights(recurrentToInputWeightsInfo, recurrentToInputWeightsData);
armnn::TensorShape recurrentToForgetWeightsShape = {4, 4};
@@ -1495,7 +1500,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo recurrentToForgetWeightsInfo(recurrentToForgetWeightsShape,
armnn::DataType::QAsymmU8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::ConstTensor recurrentToForgetWeights(recurrentToForgetWeightsInfo, recurrentToForgetWeightsData);
armnn::TensorShape recurrentToCellWeightsShape = {4, 4};
@@ -1503,7 +1509,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo recurrentToCellWeightsInfo(recurrentToCellWeightsShape,
armnn::DataType::QAsymmU8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::ConstTensor recurrentToCellWeights(recurrentToCellWeightsInfo, recurrentToCellWeightsData);
armnn::TensorShape recurrentToOutputWeightsShape = {4, 4};
@@ -1511,7 +1518,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo recurrentToOutputWeightsInfo(recurrentToOutputWeightsShape,
armnn::DataType::QAsymmU8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::ConstTensor recurrentToOutputWeights(recurrentToOutputWeightsInfo, recurrentToOutputWeightsData);
// The shape of bias data is {outputSize} = {4}
@@ -1520,7 +1528,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo inputGateBiasInfo(inputGateBiasShape,
armnn::DataType::Signed32,
biasScale,
- biasOffset, true);
+ biasOffset,
+ true);
armnn::ConstTensor inputGateBias(inputGateBiasInfo, inputGateBiasData);
armnn::TensorShape forgetGateBiasShape = {4};
@@ -1528,7 +1537,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo forgetGateBiasInfo(forgetGateBiasShape,
armnn::DataType::Signed32,
biasScale,
- biasOffset, true);
+ biasOffset,
+ true);
armnn::ConstTensor forgetGateBias(forgetGateBiasInfo, forgetGateBiasData);
armnn::TensorShape cellBiasShape = {4};
@@ -1536,7 +1546,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo cellBiasInfo(cellBiasShape,
armnn::DataType::Signed32,
biasScale,
- biasOffset, true);
+ biasOffset,
+ true);
armnn::ConstTensor cellBias(cellBiasInfo, cellBiasData);
armnn::TensorShape outputGateBiasShape = {4};
@@ -1544,7 +1555,8 @@ TEST_CASE("SerializeDeserializeQuantizedLstm")
armnn::TensorInfo outputGateBiasInfo(outputGateBiasShape,
armnn::DataType::Signed32,
biasScale,
- biasOffset, true);
+ biasOffset,
+ true);
armnn::ConstTensor outputGateBias(outputGateBiasInfo, outputGateBiasData);
armnn::QuantizedLstmInputParams params;
@@ -1655,12 +1667,14 @@ TEST_CASE("SerializeDeserializeQLstmBasic")
armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
armnn::DataType::QSymmS8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
armnn::DataType::QSymmS8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::TensorInfo biasInfo({numUnits}, armnn::DataType::Signed32, biasScale, biasOffset, true);
@@ -1816,22 +1830,26 @@ TEST_CASE("SerializeDeserializeQLstmCifgLayerNorm")
armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
armnn::DataType::QSymmS8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
armnn::DataType::QSymmS8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::TensorInfo biasInfo({numUnits},
armnn::DataType::Signed32,
biasScale,
- biasOffset, true);
+ biasOffset,
+ true);
armnn::TensorInfo layerNormWeightsInfo({numUnits},
armnn::DataType::QSymmS16,
layerNormScale,
- layerNormOffset, true);
+ layerNormOffset,
+ true);
// Mandatory params
std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());
@@ -2003,32 +2021,38 @@ TEST_CASE("SerializeDeserializeQLstmAdvanced")
armnn::TensorInfo inputWeightsInfo({numUnits, inputSize},
armnn::DataType::QSymmS8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::TensorInfo recurrentWeightsInfo({numUnits, outputSize},
armnn::DataType::QSymmS8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::TensorInfo biasInfo({numUnits},
armnn::DataType::Signed32,
biasScale,
- biasOffset, true);
+ biasOffset,
+ true);
armnn::TensorInfo peepholeWeightsInfo({numUnits},
armnn::DataType::QSymmS16,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
armnn::TensorInfo layerNormWeightsInfo({numUnits},
armnn::DataType::QSymmS16,
layerNormScale,
- layerNormOffset, true);
+ layerNormOffset,
+ true);
armnn::TensorInfo projectionWeightsInfo({outputSize, numUnits},
armnn::DataType::QSymmS8,
weightsScale,
- weightsOffset, true);
+ weightsOffset,
+ true);
// Mandatory params
std::vector<int8_t> inputToForgetWeightsData = GenerateRandomData<int8_t>(inputWeightsInfo.GetNumElements());