aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjaneil01 <jan.eilers@arm.com>2019-11-15 15:00:16 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-11-15 17:14:05 +0000
commit17d8d85ece077c1272839edea32c55af553ced59 (patch)
tree89a8ae955f6a2562c32a7c21f6c3b4bbe62d6c2d
parent0270524f96c4e21a755d1c71e46c4e8665918237 (diff)
downloadarmnn-17d8d85ece077c1272839edea32c55af553ced59.tar.gz
IVGCVSW-3486 Add clipping parameter validation in LstmQueueDescriptor
* Add clipping parameter validation in LstmQueueDescriptor * Related UnitTest Signed-off-by: janeil01 <jan.eilers@arm.com> Change-Id: I86ff81cacc0e1fff5b78a8d6c2dcbf9ff57e2272
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp14
-rw-r--r--src/backends/backendsCommon/test/WorkloadDataValidation.cpp10
2 files changed, 22 insertions, 2 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 443dc8eae3..c145c4b39f 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1731,10 +1731,20 @@ void LstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
"output_" + std::to_string(i));
}
- // TODO: check clipping parameter is valid
+ // Making sure clipping parameters have valid values.
+ // == 0 means no clipping
+ // > 0 means clipping
+ if (m_Parameters.m_ClippingThresCell < 0.0f)
+ {
+ throw InvalidArgumentException(descriptorName + ": negative cell clipping threshold is invalid");
+ }
+ if (m_Parameters.m_ClippingThresProj < 0.0f)
+ {
+ throw InvalidArgumentException(descriptorName + ": negative projection clipping threshold is invalid");
+ }
+
// Inferring batch size, number of outputs and number of cells from the inputs.
- // TODO: figure out if there is a way to make sure the specific inputs are at that index of workloadInfo
const uint32_t n_input = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
const uint32_t n_batch = workloadInfo.m_InputTensorInfos[0].GetShape()[0];
ValidatePointer(m_InputToOutputWeights, "Null pointer check", "InputToOutputWeights");
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 70d00b3a91..b5acd88e89 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -584,6 +584,16 @@ BOOST_AUTO_TEST_CASE(LstmQueueDescriptor_Validate)
BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
SetWorkloadOutput(data, info, 3, outputTensorInfo, nullptr);
+ // check invalid cell clipping parameters
+ data.m_Parameters.m_ClippingThresCell = -1.0f;
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ data.m_Parameters.m_ClippingThresCell = 0.0f;
+
+ // check invalid projection clipping parameters
+ data.m_Parameters.m_ClippingThresProj = -1.0f;
+ BOOST_CHECK_THROW(data.Validate(info), armnn::InvalidArgumentException);
+ data.m_Parameters.m_ClippingThresProj = 0.0f;
+
// check correct configuration
BOOST_CHECK_NO_THROW(data.Validate(info));
}