diff options
author | Jan Eilers <jan.eilers@arm.com> | 2019-06-26 13:10:09 +0100 |
---|---|---|
committer | Jan Eilers <jan.eilers@arm.com> | 2019-07-02 09:59:37 +0000 |
commit | 38e05bd2836b1b65b440330a9c283038ba4192c3 (patch) | |
tree | c232f71ce6a101c70ed65e046678f7b22593dbe4 /src/backends/backendsCommon/test/WorkloadDataValidation.cpp | |
parent | d0c0cc3e27f1ada9df167d3b9ff248be432d16e1 (diff) | |
download | armnn-38e05bd2836b1b65b440330a9c283038ba4192c3.tar.gz |
IVGCVSW-3236 Extend Ref LSTM with layer normalization support
* Add descriptor values
* Update lstm queue descriptor validate function
* Update lstm workload
* Update isLstmSupported (Cl and Ref), LayerSupportBase, ILayerSupport
* Update lstm layer
* Add unit tests
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Change-Id: I932175d550facfb342325051eaa7bd2084ebdc18
Signed-off-by: Jan Eilers <jan.eilers@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/test/WorkloadDataValidation.cpp')
-rw-r--r-- | src/backends/backendsCommon/test/WorkloadDataValidation.cpp | 149 |
1 files changed, 133 insertions, 16 deletions
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp index 7c7af2ddce..c6960986b3 100644 --- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp +++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp @@ -453,22 +453,139 @@ BOOST_AUTO_TEST_CASE(ReshapeQueueDescriptor_Validate_MismatchingNumElements) BOOST_AUTO_TEST_CASE(LstmQueueDescriptor_Validate) { - armnn::TensorInfo inputTensorInfo; - armnn::TensorInfo outputTensorInfo; - - unsigned int inputShape[] = { 1, 2 }; - unsigned int outputShape[] = { 1 }; - - inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::DataType::Float32); - outputTensorInfo = armnn::TensorInfo(1, outputShape, armnn::DataType::Float32); - - LstmQueueDescriptor invalidData; - WorkloadInfo invalidInfo; - - AddInputToWorkload(invalidData, invalidInfo, inputTensorInfo, nullptr); - AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr); - - BOOST_CHECK_THROW(invalidData.Validate(invalidInfo), armnn::InvalidArgumentException); + armnn::DataType dataType = armnn::DataType::Float32; + + float qScale = 0.0f; + int32_t qOffset = 0; + + unsigned int batchSize = 2; + unsigned int outputSize = 3; + unsigned int inputSize = 5; + unsigned numUnits = 4; + + armnn::TensorInfo inputTensorInfo({batchSize , inputSize}, dataType, qScale, qOffset ); + armnn::TensorInfo outputStateInTensorInfo({batchSize , outputSize}, dataType, qScale, qOffset); + armnn::TensorInfo cellStateInTensorInfo({batchSize , numUnits}, dataType, qScale, qOffset); + + // Scratch buffer size with CIFG [batchSize, numUnits * 4] + armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset); + armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits}, dataType, qScale, qOffset); + armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset); + armnn::TensorInfo outputTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset); + + armnn::TensorInfo tensorInfo3({outputSize}, dataType, qScale, qOffset); + armnn::TensorInfo tensorInfo4({numUnits}, dataType, qScale, qOffset); + armnn::TensorInfo tensorInfo4x5({numUnits, inputSize}, dataType, qScale, qOffset); + armnn::TensorInfo tensorInfo4x3({numUnits, outputSize}, dataType, qScale, qOffset); + armnn::TensorInfo tensorInfo3x4({outputSize, numUnits}, dataType, qScale, qOffset); + + LstmQueueDescriptor data; + WorkloadInfo info; + + AddInputToWorkload(data, info, inputTensorInfo, nullptr); + AddInputToWorkload(data, info, outputStateInTensorInfo, nullptr); + AddInputToWorkload(data, info, cellStateInTensorInfo, nullptr); + + AddOutputToWorkload(data, info, scratchBufferTensorInfo, nullptr); + AddOutputToWorkload(data, info, outputStateOutTensorInfo, nullptr); + AddOutputToWorkload(data, info, cellStateOutTensorInfo, nullptr); + // AddOutputToWorkload(data, info, outputTensorInfo, nullptr); is left out + + armnn::ScopedCpuTensorHandle inputToInputWeightsTensor(tensorInfo4x5); + armnn::ScopedCpuTensorHandle inputToForgetWeightsTensor(tensorInfo4x5); + armnn::ScopedCpuTensorHandle inputToCellWeightsTensor(tensorInfo4x5); + armnn::ScopedCpuTensorHandle inputToOutputWeightsTensor(tensorInfo4x5); + armnn::ScopedCpuTensorHandle recurrentToForgetWeightsTensor(tensorInfo4x3); + armnn::ScopedCpuTensorHandle recurrentToInputWeightsTensor(tensorInfo4x3); + armnn::ScopedCpuTensorHandle recurrentToCellWeightsTensor(tensorInfo4x3); + armnn::ScopedCpuTensorHandle recurrentToOutputWeightsTensor(tensorInfo4x3); + armnn::ScopedCpuTensorHandle cellToInputWeightsTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle inputGateBiasTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle forgetGateBiasTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle cellBiasTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle outputGateBiasTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle cellToForgetWeightsTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle cellToOutputWeightsTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle projectionWeightsTensor(tensorInfo3x4); + armnn::ScopedCpuTensorHandle projectionBiasTensor(tensorInfo3); + armnn::ScopedCpuTensorHandle inputLayerNormWeightsTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle forgetLayerNormWeightsTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle cellLayerNormWeightsTensor(tensorInfo4); + armnn::ScopedCpuTensorHandle outputLayerNormWeightsTensor(tensorInfo4); + + data.m_InputToInputWeights = &inputToInputWeightsTensor; + data.m_InputToForgetWeights = &inputToForgetWeightsTensor; + data.m_InputToCellWeights = &inputToCellWeightsTensor; + data.m_InputToOutputWeights = &inputToOutputWeightsTensor; + data.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; + data.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; + data.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; + data.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; + data.m_CellToInputWeights = &cellToInputWeightsTensor; + data.m_InputGateBias = &inputGateBiasTensor; + data.m_ForgetGateBias = &forgetGateBiasTensor; + data.m_CellBias = &cellBiasTensor; + data.m_OutputGateBias = &outputGateBiasTensor; + data.m_CellToForgetWeights = &cellToForgetWeightsTensor; + data.m_CellToOutputWeights = &cellToOutputWeightsTensor; + data.m_ProjectionWeights = &projectionWeightsTensor; + data.m_ProjectionBias = &projectionBiasTensor; + + data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor; + data.m_ForgetLayerNormWeights = &forgetLayerNormWeightsTensor; + data.m_CellLayerNormWeights = &cellLayerNormWeightsTensor; + data.m_OutputLayerNormWeights = &outputLayerNormWeightsTensor; + + // Flags to set test configuration + data.m_Parameters.m_ActivationFunc = 4; + data.m_Parameters.m_CifgEnabled = false; + data.m_Parameters.m_PeepholeEnabled = true; + data.m_Parameters.m_ProjectionEnabled = true; + data.m_Parameters.m_LayerNormEnabled = true; + + // check wrong number of outputs + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + AddOutputToWorkload(data, info, outputTensorInfo, nullptr); + + // check wrong cifg parameter configuration + data.m_Parameters.m_CifgEnabled = true; + armnn::TensorInfo scratchBufferTensorInfo2({batchSize, numUnits * 3}, dataType, qScale, qOffset); + SetWorkloadOutput(data, info, 0, scratchBufferTensorInfo2, nullptr); + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + data.m_Parameters.m_CifgEnabled = false; + SetWorkloadOutput(data, info, 0, scratchBufferTensorInfo, nullptr); + + // check wrong inputGateBias configuration + data.m_InputGateBias = nullptr; + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + data.m_InputGateBias = &inputGateBiasTensor; + + // check inconsistant projection parameters + data.m_Parameters.m_ProjectionEnabled = false; + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + data.m_Parameters.m_ProjectionEnabled = true; + data.m_ProjectionWeights = nullptr; + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + data.m_ProjectionWeights = &projectionWeightsTensor; + + // check missing input layer normalisation weights + data.m_InputLayerNormWeights = nullptr; + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + data.m_InputLayerNormWeights = &inputLayerNormWeightsTensor; + + // layer norm disabled but normalisation weights are present + data.m_Parameters.m_LayerNormEnabled = false; + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + data.m_Parameters.m_LayerNormEnabled = true; + + // check invalid outputTensor shape + armnn::TensorInfo incorrectOutputTensorInfo({batchSize, outputSize + 1}, dataType, qScale, qOffset); + SetWorkloadOutput(data, info, 3, incorrectOutputTensorInfo, nullptr); + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + SetWorkloadOutput(data, info, 3, outputTensorInfo, nullptr); + + // check correct configuration + BOOST_CHECK_NO_THROW(data.Validate(info)); } BOOST_AUTO_TEST_SUITE_END() |