aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/validation/Validation.h23
-rw-r--r--tests/validation/fixtures/PoolingLayerFixture.h2
2 files changed, 18 insertions, 7 deletions
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h
index 638a1c20ee..4f3f92da24 100644
--- a/tests/validation/Validation.h
+++ b/tests/validation/Validation.h
@@ -45,6 +45,17 @@ namespace test
{
namespace validation
{
+namespace
+{
+// Compare if 2 values are both infinities and if they are "equal" (has the same sign)
+template <typename T>
+bool are_equal_infs(T val0, T val1)
+{
+ const auto same_sign = std::signbit(val0) == std::signbit(val1);
+ return (!support::cpp11::isfinite(val0)) && (!support::cpp11::isfinite(val1)) && same_sign;
+}
+} // namespace
+
/** Class reprensenting an absolute tolerance value. */
template <typename T>
class AbsoluteTolerance
@@ -296,9 +307,9 @@ struct compare<AbsoluteTolerance<U>> : public compare_base<AbsoluteTolerance<U>>
/** Perform comparison */
operator bool() const
{
- if(!support::cpp11::isfinite(this->_target) || !support::cpp11::isfinite(this->_reference))
+ if(are_equal_infs(this->_target, this->_reference))
{
- return false;
+ return true;
}
else if(this->_target == this->_reference)
{
@@ -322,9 +333,9 @@ struct compare<RelativeTolerance<U>> : public compare_base<RelativeTolerance<U>>
/** Perform comparison */
operator bool() const
{
- if(!support::cpp11::isfinite(this->_target) || !support::cpp11::isfinite(this->_reference))
+ if(are_equal_infs(this->_target, this->_reference))
{
- return false;
+ return true;
}
else if(this->_target == this->_reference)
{
@@ -494,9 +505,9 @@ void validate_wrap(const IAccessor &tensor, const SimpleTensor<T> &reference, co
// check for wrapping
if(!equal)
{
- if(!support::cpp11::isfinite(target_value) || !support::cpp11::isfinite(reference_value))
+ if(are_equal_infs(target_value, reference_value))
{
- equal = false;
+ equal = true;
}
else
{
diff --git a/tests/validation/fixtures/PoolingLayerFixture.h b/tests/validation/fixtures/PoolingLayerFixture.h
index ec4e9f80dd..6e9edfbb5d 100644
--- a/tests/validation/fixtures/PoolingLayerFixture.h
+++ b/tests/validation/fixtures/PoolingLayerFixture.h
@@ -213,7 +213,7 @@ public:
template <typename...>
void setup(TensorShape src_shape, PoolingLayerInfo pool_info, DataType data_type)
{
- PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, DataLayout::NCHW);
+ PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, pool_info.data_layout);
}
};