aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSiCongLi <sicong.li@arm.com>2021-12-22 15:37:20 +0000
committerSiCong Li <sicong.li@arm.com>2021-12-24 15:59:25 +0000
commitcb86956e1972be4b2ddbaacaa23a0d21185f8ccb (patch)
tree69a51b7b8a8121756044fa60657d05d534fb6e02
parentb2eba7f307d5ae634ff41bd88d5bd1659466d82d (diff)
downloadComputeLibrary-cb86956e1972be4b2ddbaacaa23a0d21185f8ccb.tar.gz
Fix test validation method
* Allow non-finite values to be equal (inf == inf, -inf == -inf) in validate * Fix SpecialPoolingLayerValidationFixture Partially resolves COMPMID-4998 Change-Id: I3fba1ccee74c1af419a3b6088ddac68c79aa243a Signed-off-by: SiCongLi <sicong.li@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6856 Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--tests/validation/Validation.h23
-rw-r--r--tests/validation/fixtures/PoolingLayerFixture.h2
2 files changed, 18 insertions, 7 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h
index 638a1c20ee..4f3f92da24 100644
--- a/tests/validation/Validation.h
+++ b/tests/validation/Validation.h
@@ -45,6 +45,17 @@ namespace test
{
namespace validation
{
+namespace
+{
+// Compare if 2 values are both infinities and if they are "equal" (has the same sign)
+template <typename T>
+bool are_equal_infs(T val0, T val1)
+{
+ const auto same_sign = std::signbit(val0) == std::signbit(val1);
+ return (!support::cpp11::isfinite(val0)) && (!support::cpp11::isfinite(val1)) && same_sign;
+}
+} // namespace
+
/** Class reprensenting an absolute tolerance value. */
template <typename T>
class AbsoluteTolerance
@@ -296,9 +307,9 @@ struct compare<AbsoluteTolerance<U>> : public compare_base<AbsoluteTolerance<U>>
/** Perform comparison */
operator bool() const
{
- if(!support::cpp11::isfinite(this->_target) || !support::cpp11::isfinite(this->_reference))
+ if(are_equal_infs(this->_target, this->_reference))
{
- return false;
+ return true;
}
else if(this->_target == this->_reference)
{
@@ -322,9 +333,9 @@ struct compare<RelativeTolerance<U>> : public compare_base<RelativeTolerance<U>>
/** Perform comparison */
operator bool() const
{
- if(!support::cpp11::isfinite(this->_target) || !support::cpp11::isfinite(this->_reference))
+ if(are_equal_infs(this->_target, this->_reference))
{
- return false;
+ return true;
}
else if(this->_target == this->_reference)
{
@@ -494,9 +505,9 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, co
// check for wrapping
if(!equal)
{
- if(!support::cpp11::isfinite(target_value) || !support::cpp11::isfinite(reference_value))
+ if(are_equal_infs(target_value, reference_value))
{
- equal = false;
+ equal = true;
}
else
{
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index ec4e9f80dd..6e9edfbb5d 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -213,7 +213,7 @@ public:
template <typename...>
void setup(TensorShape src_shape, PoolingLayerInfo pool_info, DataType data_type)
{
- PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW);
+ PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, pool_info.data_layout);
}
};