aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMoritz Pflanzer <moritz.pflanzer@arm.com>2017-09-12 15:51:33 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commit5b61fd3fbaf41031232296abde56258d12ba3340 (patch)
treea44cc3071d4d7b91480cb672c3ed4536857bb4e6
parenta3adb3a3bdce1f2ef764c5d5098e99695323f0a3 (diff)
downloadComputeLibrary-5b61fd3fbaf41031232296abde56258d12ba3340.tar.gz
COMPMID-417: Fix validation
Change-Id: I7a745037136bc6e02d177f65fe4f4cd43873b98e Reviewed-on: http://mpd-gerrit.cambridge.arm.com/87406 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
-rw-r--r--tests/framework/Framework.cpp18
-rw-r--r--tests/framework/printers/JSONPrinter.cpp66
-rw-r--r--tests/framework/printers/JSONPrinter.h11
-rw-r--r--tests/framework/printers/PrettyPrinter.cpp5
-rw-r--r--tests/framework/printers/PrettyPrinter.h2
-rw-r--r--tests/framework/printers/Printer.h5
-rw-r--r--tests/validation/CL/SoftmaxLayer.cpp12
-rw-r--r--tests/validation/NEON/SoftmaxLayer.cpp12
-rw-r--r--tests/validation/Validation.cpp4
-rw-r--r--tests/validation/Validation.h16
10 files changed, 101 insertions, 50 deletions
diff --git a/tests/framework/Framework.cpp b/tests/framework/Framework.cpp
index 343b7a8561..31e524338b 100644
--- a/tests/framework/Framework.cpp
+++ b/tests/framework/Framework.cpp
@@ -164,7 +164,8 @@ void Framework::log_failed_expectation(const TestError &error)
{
if(_log_level >= error.level() && _printer != nullptr)
{
- _printer->print_error(error);
+ constexpr bool expected_error = true;
+ _printer->print_error(error, expected_error);
}
if(_current_test_result != nullptr)
@@ -232,6 +233,8 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory)
_printer->print_errors_header();
}
+ const bool is_expected_failure = test_factory.status() == TestCaseFactory::Status::EXPECTED_FAILURE;
+
try
{
std::unique_ptr<TestCase> test_case = test_factory.make();
@@ -265,7 +268,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory)
{
if(_log_level >= error.level() && _printer != nullptr)
{
- _printer->print_error(error);
+ _printer->print_error(error, is_expected_failure);
}
result.status = TestResult::Status::FAILED;
@@ -282,7 +285,8 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory)
{
std::stringstream stream;
stream << "Error code: " << error.err();
- _printer->print_error(TestError(error.what(), LogLevel::ERRORS, stream.str()));
+ TestError test_error(error.what(), LogLevel::ERRORS, stream.str());
+ _printer->print_error(test_error, is_expected_failure);
}
result.status = TestResult::Status::FAILED;
@@ -297,7 +301,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory)
{
if(_log_level >= LogLevel::ERRORS && _printer != nullptr)
{
- _printer->print_error(error);
+ _printer->print_error(error, is_expected_failure);
}
result.status = TestResult::Status::CRASHED;
@@ -311,7 +315,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory)
{
if(_log_level >= LogLevel::ERRORS && _printer != nullptr)
{
- _printer->print_error(TestError("Received unknown exception"));
+ _printer->print_error(TestError("Received unknown exception"), is_expected_failure);
}
result.status = TestResult::Status::CRASHED;
@@ -326,7 +330,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory)
{
if(_log_level >= LogLevel::ERRORS && _printer != nullptr)
{
- _printer->print_error(error);
+ _printer->print_error(error, is_expected_failure);
}
result.status = TestResult::Status::CRASHED;
@@ -340,7 +344,7 @@ void Framework::run_test(const TestInfo &info, TestCaseFactory &test_factory)
{
if(_log_level >= LogLevel::ERRORS && _printer != nullptr)
{
- _printer->print_error(TestError("Received unknown exception"));
+ _printer->print_error(TestError("Received unknown exception"), is_expected_failure);
}
result.status = TestResult::Status::CRASHED;
diff --git a/tests/framework/printers/JSONPrinter.cpp b/tests/framework/printers/JSONPrinter.cpp
index ae19cae67c..4f17e6277b 100644
--- a/tests/framework/printers/JSONPrinter.cpp
+++ b/tests/framework/printers/JSONPrinter.cpp
@@ -46,6 +46,31 @@ void JSONPrinter::print_separator(bool &flag)
}
}
+template <typename T>
+void JSONPrinter::print_strings(T &&first, T &&last)
+{
+ bool first_entry = true;
+ std::stringstream log;
+
+ while(first != last)
+ {
+ print_separator(first_entry);
+
+ *_stream << R"(")";
+
+ log.str(*first);
+
+ for(std::string line; !std::getline(log, line).fail();)
+ {
+ *_stream << line << "; ";
+ }
+
+ *_stream << R"(")";
+
+ ++first;
+ }
+}
+
void JSONPrinter::print_entry(const std::string &name, const std::string &value)
{
print_separator(_first_entry);
@@ -90,38 +115,43 @@ void JSONPrinter::print_test_footer()
void JSONPrinter::print_errors_header()
{
- print_separator(_first_test_entry);
-
- _first_error = true;
- *_stream << R"("errors" : [)";
+ _errors.clear();
+ _expected_errors.clear();
+ _infos.clear();
}
void JSONPrinter::print_errors_footer()
{
+ print_separator(_first_test_entry);
+
+ *_stream << R"("errors" : [)";
+ print_strings(_errors.begin(), _errors.end());
+ *_stream << "]";
+
+ *_stream << R"(, "expected_errors" : [)";
+ print_strings(_expected_errors.begin(), _expected_errors.end());
+ *_stream << "]";
+
+ *_stream << R"(, "infos" : [)";
+ print_strings(_infos.begin(), _infos.end());
*_stream << "]";
}
-void JSONPrinter::print_error(const std::exception &error)
+void JSONPrinter::print_error(const std::exception &error, bool expected)
{
- std::stringstream error_log;
- error_log.str(error.what());
-
- for(std::string line; !std::getline(error_log, line).fail();)
+ if(expected)
{
- print_separator(_first_error);
-
- *_stream << R"(")" << line << R"(")";
+ _expected_errors.emplace_back(error.what());
+ }
+ else
+ {
+ _errors.emplace_back(error.what());
}
}
void JSONPrinter::print_info(const std::string &info)
{
- std::istringstream iss(info);
- for(std::string line; !std::getline(iss, line).fail();)
- {
- print_separator(_first_error);
- *_stream << R"(")" << line << R"(")";
- }
+ _infos.push_back(info);
}
void JSONPrinter::print_measurements(const Profiler::MeasurementsMap &measurements)
diff --git a/tests/framework/printers/JSONPrinter.h b/tests/framework/printers/JSONPrinter.h
index 18bd4380b0..a2811ea41b 100644
--- a/tests/framework/printers/JSONPrinter.h
+++ b/tests/framework/printers/JSONPrinter.h
@@ -26,6 +26,8 @@
#include "Printer.h"
+#include <list>
+
namespace arm_compute
{
namespace test
@@ -47,17 +49,22 @@ public:
void print_test_footer() override;
void print_errors_header() override;
void print_errors_footer() override;
- void print_error(const std::exception &error) override;
+ void print_error(const std::exception &error, bool expected) override;
void print_info(const std::string &info) override;
void print_measurements(const Profiler::MeasurementsMap &measurements) override;
private:
void print_separator(bool &flag);
+ template <typename T>
+ void print_strings(T &&first, T &&last);
+
+ std::list<std::string> _infos{};
+ std::list<std::string> _errors{};
+ std::list<std::string> _expected_errors{};
bool _first_entry{ true };
bool _first_test{ true };
bool _first_test_entry{ true };
- bool _first_error{ true };
};
} // namespace framework
} // namespace test
diff --git a/tests/framework/printers/PrettyPrinter.cpp b/tests/framework/printers/PrettyPrinter.cpp
index 2f7df1837a..5eec72a2fe 100644
--- a/tests/framework/printers/PrettyPrinter.cpp
+++ b/tests/framework/printers/PrettyPrinter.cpp
@@ -102,9 +102,10 @@ void PrettyPrinter::print_info(const std::string &info)
*_stream << begin_color("1") << "INFO: " << info << end_color() << "\n";
}
-void PrettyPrinter::print_error(const std::exception &error)
+void PrettyPrinter::print_error(const std::exception &error, bool expected)
{
- *_stream << begin_color("1") << "ERROR: " << error.what() << end_color() << "\n";
+ std::string prefix = expected ? "EXPECTED ERROR: " : "ERROR: ";
+ *_stream << begin_color("1") << prefix << error.what() << end_color() << "\n";
}
void PrettyPrinter::print_measurements(const Profiler::MeasurementsMap &measurements)
diff --git a/tests/framework/printers/PrettyPrinter.h b/tests/framework/printers/PrettyPrinter.h
index 3e2bebd1f7..f72a613868 100644
--- a/tests/framework/printers/PrettyPrinter.h
+++ b/tests/framework/printers/PrettyPrinter.h
@@ -53,7 +53,7 @@ public:
void print_test_footer() override;
void print_errors_header() override;
void print_errors_footer() override;
- void print_error(const std::exception &error) override;
+ void print_error(const std::exception &error, bool expected) override;
void print_info(const std::string &info) override;
void print_measurements(const Profiler::MeasurementsMap &measurements) override;
diff --git a/tests/framework/printers/Printer.h b/tests/framework/printers/Printer.h
index 16a4170d7a..c2a44240ba 100644
--- a/tests/framework/printers/Printer.h
+++ b/tests/framework/printers/Printer.h
@@ -104,9 +104,10 @@ public:
/** Print test error.
*
- * @param[in] error Description of the error.
+ * @param[in] error Description of the error.
+ * @param[in] expected Whether the error was expected or not.
*/
- virtual void print_error(const std::exception &error) = 0;
+ virtual void print_error(const std::exception &error, bool expected) = 0;
/** Print test log info.
*
diff --git a/tests/validation/CL/SoftmaxLayer.cpp b/tests/validation/CL/SoftmaxLayer.cpp
index 6a22eb1bcc..8c143ecd96 100644
--- a/tests/validation/CL/SoftmaxLayer.cpp
+++ b/tests/validation/CL/SoftmaxLayer.cpp
@@ -48,7 +48,7 @@ RelativeTolerance<half_float::half> tolerance_f16(half_float::half(0.2));
RelativeTolerance<float> tolerance_f32(0.001f);
/** Tolerance for fixed point operations */
-constexpr AbsoluteTolerance<int8_t> tolerance_fixed_point(2);
+constexpr AbsoluteTolerance<int16_t> tolerance_fixed_point(2);
/** CNN data types */
const auto CNNDataTypes = framework::dataset::make("DataType",
@@ -145,15 +145,17 @@ TEST_SUITE_END()
TEST_SUITE(QS16)
// Testing for fixed point position [1,14) as reciprocal limits the maximum fixed point position to 14
-FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
- DataType::QS16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataType",
+ DataType::QS16)),
framework::dataset::make("FractionalBits", 1, 14)))
{
// Validate output
validate(CLAccessor(_target), _reference, tolerance_fixed_point);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
- DataType::QS16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLSoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataType",
+ DataType::QS16)),
framework::dataset::make("FractionalBits", 1, 14)))
{
// Validate output
diff --git a/tests/validation/NEON/SoftmaxLayer.cpp b/tests/validation/NEON/SoftmaxLayer.cpp
index 36f1881147..7ac7759c22 100644
--- a/tests/validation/NEON/SoftmaxLayer.cpp
+++ b/tests/validation/NEON/SoftmaxLayer.cpp
@@ -49,7 +49,7 @@ constexpr AbsoluteTolerance<float> tolerance_f32(0.000001f);
constexpr AbsoluteTolerance<float> tolerance_f16(0.0001f);
#endif /* ARM_COMPUTE_ENABLE_FP16*/
/** Tolerance for fixed point operations */
-constexpr AbsoluteTolerance<int8_t> tolerance_fixed_point(2);
+constexpr AbsoluteTolerance<int16_t> tolerance_fixed_point(2);
/** CNN data types */
const auto CNNDataTypes = framework::dataset::make("DataType",
@@ -151,15 +151,17 @@ TEST_SUITE_END()
TEST_SUITE(QS16)
// Testing for fixed point position [1,14) as reciprocal limits the maximum fixed point position to 14
-FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
- DataType::QS16)),
+FIXTURE_DATA_TEST_CASE(RunSmall, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(),
+ framework::dataset::make("DataType",
+ DataType::QS16)),
framework::dataset::make("FractionalBits", 1, 14)))
{
// Validate output
validate(Accessor(_target), _reference, tolerance_fixed_point);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
- DataType::QS16)),
+FIXTURE_DATA_TEST_CASE(RunLarge, NESoftmaxLayerFixedPointFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(),
+ framework::dataset::make("DataType",
+ DataType::QS16)),
framework::dataset::make("FractionalBits", 1, 14)))
{
// Validate output
diff --git a/tests/validation/Validation.cpp b/tests/validation/Validation.cpp
index 690c4eac9e..1a082111a9 100644
--- a/tests/validation/Validation.cpp
+++ b/tests/validation/Validation.cpp
@@ -130,7 +130,7 @@ void check_border_element(const IAccessor &tensor, const Coordinates &id,
const double target = get_double_data(ptr + channel_offset, tensor.data_type());
const double reference = get_double_data(static_cast<const uint8_t *>(border_value) + channel_offset, tensor.data_type());
- if(!compare<AbsoluteTolerance<double>, double>(target, reference))
+ if(!compare<AbsoluteTolerance<double>>(target, reference))
{
ARM_COMPUTE_TEST_INFO("id = " << id);
ARM_COMPUTE_TEST_INFO("channel = " << channel);
@@ -192,7 +192,7 @@ void validate(const IAccessor &tensor, const void *reference_value)
const double target = get_double_data(ptr + channel_offset, tensor.data_type());
const double reference = get_double_data(reference_value, tensor.data_type());
- if(!compare<AbsoluteTolerance<double>, double>(target, reference))
+ if(!compare<AbsoluteTolerance<double>>(target, reference))
{
ARM_COMPUTE_TEST_INFO("id = " << id);
ARM_COMPUTE_TEST_INFO("channel = " << channel);
diff --git a/tests/validation/Validation.h b/tests/validation/Validation.h
index e70c970cc1..6bc42a4ed6 100644
--- a/tests/validation/Validation.h
+++ b/tests/validation/Validation.h
@@ -226,11 +226,11 @@ struct compare_base
T _tolerance{};
};
-template <typename T, typename U>
+template <typename T>
struct compare;
template <typename U>
-struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance<U>>
+struct compare<AbsoluteTolerance<U>> : public compare_base<AbsoluteTolerance<U>>
{
using compare_base<AbsoluteTolerance<U>>::compare_base;
@@ -245,12 +245,16 @@ struct compare<AbsoluteTolerance<U>, U> : public compare_base<AbsoluteTolerance<
return true;
}
- return static_cast<U>(std::abs(this->_target - this->_reference)) <= static_cast<U>(this->_tolerance);
+ using comparison_type = typename std::conditional<std::is_integral<U>::value, int64_t, U>::type;
+
+ const comparison_type abs_difference(std::abs(static_cast<comparison_type>(this->_target) - static_cast<comparison_type>(this->_reference)));
+
+ return abs_difference <= static_cast<comparison_type>(this->_tolerance);
}
};
template <typename U>
-struct compare<RelativeTolerance<U>, U> : public compare_base<RelativeTolerance<U>>
+struct compare<RelativeTolerance<U>> : public compare_base<RelativeTolerance<U>>
{
using compare_base<RelativeTolerance<U>>::compare_base;
@@ -325,7 +329,7 @@ void validate(const IAccessor &tensor, const SimpleTensor<T> &reference, const V
const T &target_value = reinterpret_cast<const T *>(tensor(id))[c];
const T &reference_value = reinterpret_cast<const T *>(reference(id))[c];
- if(!compare<U, typename U::value_type>(target_value, reference_value, tolerance_value))
+ if(!compare<U>(target_value, reference_value, tolerance_value))
{
ARM_COMPUTE_TEST_INFO("id = " << id);
ARM_COMPUTE_TEST_INFO("channel = " << c);
@@ -359,7 +363,7 @@ void validate(T target, T reference, U tolerance)
ARM_COMPUTE_TEST_INFO("reference = " << std::setprecision(5) << framework::make_printable(reference));
ARM_COMPUTE_TEST_INFO("target = " << std::setprecision(5) << framework::make_printable(target));
ARM_COMPUTE_TEST_INFO("tolerance = " << std::setprecision(5) << framework::make_printable(static_cast<typename U::value_type>(tolerance)));
- ARM_COMPUTE_EXPECT((compare<U, typename U::value_type>(target, reference, tolerance)), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT((compare<U>(target, reference, tolerance)), framework::LogLevel::ERRORS);
}
} // namespace validation
} // namespace test