aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2019-06-26 13:10:09 +0100
committerJan Eilers <jan.eilers@arm.com>2019-07-02 09:59:37 +0000
commit38e05bd2836b1b65b440330a9c283038ba4192c3 (patch)
treec232f71ce6a101c70ed65e046678f7b22593dbe4 /src/backends/backendsCommon/test/WorkloadDataValidation.cpp
parentd0c0cc3e27f1ada9df167d3b9ff248be432d16e1 (diff)
downloadarmnn-38e05bd2836b1b65b440330a9c283038ba4192c3.tar.gz
IVGCVSW-3236 Extend Ref LSTM with layer normalization support
* Add descriptor values * Update lstm queue descriptor validate function * Update lstm workload * Update isLstmSupported (Cl and Ref), LayerSupportBase, ILayerSupport * Update lstm layer * Add unit tests Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I932175d550facfb342325051eaa7bd2084ebdc18 Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/test/WorkloadDataValidation.cpp')
-rw-r--r--src/backends/backendsCommon/test/WorkloadDataValidation.cpp149
1 files changed, 133 insertions, 16 deletions
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 7c7af2ddce..c6960986b3 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -453,22 +453,139 @@ BOOST_AUTO_TEST_CASE(ReshapeQueueDescriptor_Validate_MismatchingNumElements)
BOOST_AUTO_TEST_CASE(LstmQueueDescriptor_Validate)
{
- armnn::TensorInfo inputTensorInfo;
- armnn::TensorInfo outputTensorInfo;
-
- unsigned int inputShape[] = { 1, 2 };
- unsigned int outputShape[] = { 1 };
-
- inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::DataType::Float32);
- outputTensorInfo = armnn::TensorInfo(1, outputShape, armnn::DataType::Float32);
-
- LstmQueueDescriptor invalidData;
- WorkloadInfo invalidInfo;
-
- AddInputToWorkload(invalidData, invalidInfo, inputTensorInfo, nullptr);
- AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
-
- BOOST_CHECK_THROW(invalidData.Validate(invalidInfo), armnn::InvalidArgumentException);
+ armnn::DataType dataType = armnn::DataType::Float32;
+
+ float qScale = 0.0f;
+ int32_t qOffset = 0;
+
+ unsigned int batchSize = 2;
+ unsigned int outputSize = 3;
+ unsigned int inputSize = 5;
+ unsigned numUnits = 4;
+
+ armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, dataType, qScale, qOffset );
+ armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, dataType, qScale, qOffset);
+ armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, dataType, qScale, qOffset);
+
+ // Scratch buffer size with CIFG [batchSize, numUnits * 4]
+ armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset);
+ armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset);
+ armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
+ armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
+
+ armnn::TensorInfo tensorInfo3({outputSize}, dataType, qScale, qOffset);
+ armnn::TensorInfo tensorInfo4({numUnits}, dataType, qScale, qOffset);
+ armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, dataType, qScale, qOffset);
+ armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, dataType, qScale, qOffset);
+ armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, dataType, qScale, qOffset);
+
+ LstmQueueDescriptor data;
+ WorkloadInfo info;
+
+ AddInputToWorkload(data, info, inputTensorInfo, nullptr);
+ AddInputToWorkload(data, info, outputStateInTensorInfo, nullptr);
+ AddInputToWorkload(data, info, cellStateInTensorInfo, nullptr);
+
+ AddOutputToWorkload(data, info, scratchBufferTensorInfo, nullptr);
+ AddOutputToWorkload(data, info, outputStateOutTensorInfo, nullptr);
+ AddOutputToWorkload(data, info, cellStateOutTensorInfo, nullptr);
+ // AddOutputToWorkload(data, info, outputTensorInfo, nullptr); is left out
+
+ armnn::ScopedCpuTensorHandle inputToInputWeightsTensor(tensorInfo4x5);
+ armnn::ScopedCpuTensorHandle inputToForgetWeightsTensor(tensorInfo4x5);
+ armnn::ScopedCpuTensorHandle inputToCellWeightsTensor(tensorInfo4x5);
+ armnn::ScopedCpuTensorHandle inputToOutputWeightsTensor(tensorInfo4x5);
+ armnn::ScopedCpuTensorHandle recurrentToForgetWeightsTensor(tensorInfo4x3);
+ armnn::ScopedCpuTensorHandle recurrentToInputWeightsTensor(tensorInfo4x3);
+ armnn::ScopedCpuTensorHandle recurrentToCellWeightsTensor(tensorInfo4x3);
+ armnn::ScopedCpuTensorHandle recurrentToOutputWeightsTensor(tensorInfo4x3);
+ armnn::ScopedCpuTensorHandle cellToInputWeightsTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle inputGateBiasTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle forgetGateBiasTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle cellBiasTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle outputGateBiasTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle cellToForgetWeightsTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle cellToOutputWeightsTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle projectionWeightsTensor(tensorInfo3x4);
+ armnn::ScopedCpuTensorHandle projectionBiasTensor(tensorInfo3);
+ armnn::ScopedCpuTensorHandle inputLayerNormWeightsTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle forgetLayerNormWeightsTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle cellLayerNormWeightsTensor(tensorInfo4);
+ armnn::ScopedCpuTensorHandle outputLayerNormWeightsTensor(tensorInfo4);
+
+ data.m_InputToInputWeights = &inputToInputWeightsTensor;
+ data.m_InputToForgetWeights = &inputToForgetWeightsTensor;
+ data.m_InputToCellWeights = &inputToCellWeightsTensor;
+ data.m_InputToOutputWeights = &inputToOutputWeightsTensor;
+ data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor;
+ data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor;
+ data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor;
+ data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor;
+ data.m_CellToInputWeights = &cellToInputWeightsTensor;
+ data.m_InputGateBias = &inputGateBiasTensor;
+ data.m_ForgetGateBias = &forgetGateBiasTensor;
+ data.m_CellBias = &cellBiasTensor;
+ data.m_OutputGateBias = &outputGateBiasTensor;
+ data.m_CellToForgetWeights = &cellToForgetWeightsTensor;
+ data.m_CellToOutputWeights = &cellToOutputWeightsTensor;
+ data.m_ProjectionWeights = &projectionWeightsTensor;
+ data.m_ProjectionBias = &projectionBiasTensor;
+
+ data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
+ data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor;
+ data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor;
+ data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor;
+
+ // Flags to set test configuration
+ data.m_Parameters.m_ActivationFunc = 4;
+ data.m_Parameters.m_CifgEnabled = false;
+ data.m_Parameters.m_PeepholeEnabled = true;
+ data.m_Parameters.m_ProjectionEnabled = true;
+ data.m_Parameters.m_LayerNormEnabled = true;
+
+ // check wrong number of outputs
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ AddOutputToWorkload(data, info, outputTensorInfo, nullptr);
+
+ // check wrong cifg parameter configuration
+ data.m_Parameters.m_CifgEnabled = true;
+ armnn::TensorInfo scratchBufferTensorInfo2({batchSize, numUnits * 3}, dataType, qScale, qOffset);
+ SetWorkloadOutput(data, info, 0, scratchBufferTensorInfo2, nullptr);
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ data.m_Parameters.m_CifgEnabled = false;
+ SetWorkloadOutput(data, info, 0, scratchBufferTensorInfo, nullptr);
+
+ // check wrong inputGateBias configuration
+ data.m_InputGateBias = nullptr;
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ data.m_InputGateBias = &inputGateBiasTensor;
+
+ // check inconsistant projection parameters
+ data.m_Parameters.m_ProjectionEnabled = false;
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ data.m_Parameters.m_ProjectionEnabled = true;
+ data.m_ProjectionWeights = nullptr;
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ data.m_ProjectionWeights = &projectionWeightsTensor;
+
+ // check missing input layer normalisation weights
+ data.m_InputLayerNormWeights = nullptr;
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor;
+
+ // layer norm disabled but normalisation weights are present
+ data.m_Parameters.m_LayerNormEnabled = false;
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ data.m_Parameters.m_LayerNormEnabled = true;
+
+ // check invalid outputTensor shape
+ armnn::TensorInfo incorrectOutputTensorInfo({batchSize, outputSize + 1}, dataType, qScale, qOffset);
+ SetWorkloadOutput(data, info, 3, incorrectOutputTensorInfo, nullptr);
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ SetWorkloadOutput(data, info, 3, outputTensorInfo, nullptr);
+
+ // check correct configuration
+ BOOST_CHECK_NO_THROW(data.Validate(info));
}
BOOST_AUTO_TEST_SUITE_END()