diff options
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 14 | ||||
-rw-r--r-- | src/backends/backendsCommon/test/WorkloadDataValidation.cpp | 10 |
2 files changed, 22 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 443dc8eae3..c145c4b39f 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1731,10 +1731,20 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const "output_" + std::to_string(i)); } - // TODO: check clipping parameter is valid + // Making sure clipping parameters have valid values. + // == 0 means no clipping + // > 0 means clipping + if (m_Parameters.m_ClippingThresCell < 0.0f) + { + throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid"); + } + if (m_Parameters.m_ClippingThresProj < 0.0f) + { + throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid"); + } + // Inferring batch size, number of outputs and number of cells from the inputs. - // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1]; const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0]; ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights"); diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp index 70d00b3a91..b5acd88e89 100644 --- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp +++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp @@ -584,6 +584,16 @@ BOOST_AUTO_TEST_CASE(LstmQueueDescriptor_Validate) BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); SetWorkloadOutput(data, info, 3, outputTensorInfo, nullptr); + // check invalid cell clipping parameters + data.m_Parameters.m_ClippingThresCell = -1.0f; + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + data.m_Parameters.m_ClippingThresCell = 0.0f; + + // check invalid projection clipping parameters + data.m_Parameters.m_ClippingThresProj = -1.0f; + BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException); + data.m_Parameters.m_ClippingThresProj = 0.0f; + // check correct configuration BOOST_CHECK_NO_THROW(data.Validate(info)); } |