aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/UnitTests.hpp
blob: 9b750b5b33c7ac7127621c1c05571b3b22e81520 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once

#include "Logging.hpp"
#include "armnn/Utils.hpp"
#include "backends/RefWorkloadFactory.hpp"
#include "backends/test/LayerTests.hpp"
#include <boost/test/unit_test.hpp>

/// Sets up logging for unit tests.
/// Configures both the ArmNN library logger and this test program's Boost.Log core
/// with the same arguments, using Fatal as the severity level.
inline void ConfigureLoggingTest()
{
    const auto severity = armnn::LogSeverity::Fatal;

    // Library-side logging first, then the test program's own Boost.Log core.
    armnn::ConfigureLogging(true, true, severity);
    armnnUtils::ConfigureLogging(boost::log::core::get().get(), true, true, severity);
}

// The following macros require the caller to have defined a FactoryType alias in scope,
// typically with one of these alias declarations:
//
//      using FactoryType = armnn::RefWorkloadFactory;
//      using FactoryType = armnn::ClWorkloadFactory;
//      using FactoryType = armnn::NeonWorkloadFactory;

/// Checks a LayerTestResult, comparing its tensors only when the result reports itself as supported.
///
/// The test name doubles as a declaration of expected support. A name containing "UNSUPPORTED" states
/// that the feature is not supported; any other name states that it is. BOOST_CHECK_MESSAGE reports a
/// failure if the name and testResult.supported disagree, so:
///  - if support is added for a feature, its UNSUPPORTED-tagged test fails until the tag is removed;
///  - if support is removed, its untagged test fails until the tag is added.
/// This prevents a passing test named after a feature from implying that the feature is supported.
///
/// When the result is supported, CompareTensors(output, outputExpected) is evaluated inside BOOST_TEST
/// so that its predicate_result message is included in the report.
///
/// @param testName   Name of the test case; searched for the "UNSUPPORTED" tag.
/// @param testResult Result returned by the layer test function.
template <typename T, std::size_t n>
void CompareTestResultIfSupported(const std::string& testName, const LayerTestResult<T, n>& testResult)
{
    const bool nameClaimsUnsupported = (testName.find("UNSUPPORTED") != std::string::npos);

    // An UNSUPPORTED tag must coincide exactly with an unsupported result, and vice versa.
    const bool nameAgreesWithResult = (nameClaimsUnsupported != testResult.supported);
    BOOST_CHECK_MESSAGE(nameAgreesWithResult,
        "The test name does not match the supportedness it is reporting");

    if (!testResult.supported)
    {
        // Nothing meaningful to compare for an unsupported configuration.
        return;
    }

    BOOST_TEST(CompareTensors(testResult.output, testResult.outputExpected));
}

/// Runs a single-factory layer test and checks its result.
/// A FactoryType workload factory is default-initialised, passed to testFunction as the first
/// argument (followed by the copies of args), and the returned result is handed to
/// CompareTestResultIfSupported together with testName.
/// @param testName     Name of the test case (used for the UNSUPPORTED-tag check).
/// @param testFunction Pointer to the layer test function to invoke.
/// @param args         Extra arguments forwarded (by copy) to testFunction after the factory.
template<typename FactoryType, typename TFuncPtr, typename... Args>
void RunTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
{
    FactoryType factory;
    CompareTestResultIfSupported(testName, (*testFunction)(factory, args...));
}

/// Declares a Boost auto test case named TestName that calls RunTestFunction<FactoryType>.
/// TestFunction is invoked with a FactoryType workload factory followed by any extra macro arguments,
/// and the stringised TestName is used for the UNSUPPORTED-tag check in CompareTestResultIfSupported.
/// Requires a FactoryType alias to be in scope at the point of use.
/// Note: `##__VA_ARGS__` is a GCC/Clang preprocessor extension that drops the preceding comma
/// when no extra arguments are supplied.
#define ARMNN_AUTO_TEST_CASE(TestName, TestFunction, ...) \
    BOOST_AUTO_TEST_CASE(TestName) \
    { \
        RunTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }

template<typename FactoryType, typename TFuncPtr, typename... Args>
void CompareRefTestFunction(const char* testName, TFuncPtr testFunction, Args... args)
{
    FactoryType workloadFactory;
    armnn::RefWorkloadFactory refWorkloadFactory;
    auto testResult = (*testFunction)(workloadFactory, refWorkloadFactory, args...);
    CompareTestResultIfSupported(testName, testResult);
}

/// Declares a Boost auto test case named TestName that calls CompareRefTestFunction<FactoryType>.
/// TestFunction is invoked with a FactoryType workload factory, an armnn::RefWorkloadFactory, and then
/// any extra macro arguments; the stringised TestName is used for the UNSUPPORTED-tag check.
/// Requires a FactoryType alias to be in scope at the point of use.
/// Note: `##__VA_ARGS__` is a GCC/Clang preprocessor extension that drops the preceding comma
/// when no extra arguments are supplied.
#define ARMNN_COMPARE_REF_AUTO_TEST_CASE(TestName, TestFunction, ...) \
    BOOST_AUTO_TEST_CASE(TestName) \
    { \
        CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }

/// Same as ARMNN_COMPARE_REF_AUTO_TEST_CASE, but declares the test case with BOOST_FIXTURE_TEST_CASE
/// so that the given Fixture type is set up for the test body.
/// TestFunction is invoked with a FactoryType workload factory, an armnn::RefWorkloadFactory, and then
/// any extra macro arguments; the stringised TestName is used for the UNSUPPORTED-tag check.
/// Requires a FactoryType alias to be in scope at the point of use.
/// Note: `##__VA_ARGS__` is a GCC/Clang preprocessor extension that drops the preceding comma
/// when no extra arguments are supplied.
#define ARMNN_COMPARE_REF_FIXTURE_TEST_CASE(TestName, Fixture, TestFunction, ...) \
    BOOST_FIXTURE_TEST_CASE(TestName, Fixture) \
    { \
        CompareRefTestFunction<FactoryType>(#TestName, &TestFunction, ##__VA_ARGS__); \
    }