aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/test/ClBackendTests.cpp
blob: 33f321653cc68d4961306a73dea44b1e2a4d1c11 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <cl/ClBackend.hpp>
#include <cl/ClTensorHandleFactory.hpp>
#include <cl/ClImportTensorHandleFactory.hpp>
#include <cl/test/ClContextControlFixture.hpp>

#include <doctest/doctest.h>

#include <memory>

using namespace armnn;

TEST_SUITE("ClBackendTests")
{
namespace
{

/// Returns true when @p registry reports the ClImportTensorHandleFactory id as the matching
/// import factory id for the ClTensorHandleFactory id.
/// Takes a non-const reference because GetMatchingImportFactoryId is called on the registry.
bool IsClImportFactoryMatched(TensorHandleFactoryRegistry& registry)
{
    return registry.GetMatchingImportFactoryId(ClTensorHandleFactory::GetIdStatic()) ==
           ClImportTensorHandleFactory::GetIdStatic();
}

} // anonymous namespace

/// RegisterTensorHandleFactories(registry) must register the CopyAndImportFactoryPair for CL.
TEST_CASE("ClRegisterTensorHandleFactoriesMatchingImportFactoryId")
{
    auto clBackend = std::make_unique<ClBackend>();
    TensorHandleFactoryRegistry registry;
    clBackend->RegisterTensorHandleFactories(registry);

    // When calling RegisterTensorHandleFactories, CopyAndImportFactoryPair is registered
    // Get ClImportTensorHandleFactory id as the matching import factory id
    CHECK(IsClImportFactoryMatched(registry));
}

/// The MemorySourceFlags overload of RegisterTensorHandleFactories must also register the pair.
TEST_CASE("ClRegisterTensorHandleFactoriesWithMemorySourceFlagsMatchingImportFactoryId")
{
    auto clBackend = std::make_unique<ClBackend>();
    TensorHandleFactoryRegistry registry;
    clBackend->RegisterTensorHandleFactories(registry,
                                             static_cast<MemorySourceFlags>(MemorySource::Malloc),
                                             static_cast<MemorySourceFlags>(MemorySource::Malloc));

    // When calling RegisterTensorHandleFactories with MemorySourceFlags, CopyAndImportFactoryPair is registered
    // Get ClImportTensorHandleFactory id as the matching import factory id
    CHECK(IsClImportFactoryMatched(registry));
}

/// CreateWorkloadFactory(registry) must register the CopyAndImportFactoryPair as a side effect.
TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryMatchingImportFactoryId")
{
    auto clBackend = std::make_unique<ClBackend>();
    TensorHandleFactoryRegistry registry;
    clBackend->CreateWorkloadFactory(registry);

    // When calling CreateWorkloadFactory, CopyAndImportFactoryPair is registered
    // Get ClImportTensorHandleFactory id as the matching import factory id
    CHECK(IsClImportFactoryMatched(registry));
}

/// The ModelOptions overload of CreateWorkloadFactory must also register the pair.
TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWithOptionsMatchingImportFactoryId")
{
    auto clBackend = std::make_unique<ClBackend>();
    TensorHandleFactoryRegistry registry;
    ModelOptions modelOptions;
    clBackend->CreateWorkloadFactory(registry, modelOptions);

    // When calling CreateWorkloadFactory with ModelOptions, CopyAndImportFactoryPair is registered
    // Get ClImportTensorHandleFactory id as the matching import factory id
    CHECK(IsClImportFactoryMatched(registry));
}

/// The ModelOptions + MemorySourceFlags overload of CreateWorkloadFactory must also register the pair.
TEST_CASE_FIXTURE(ClContextControlFixture, "ClCreateWorkloadFactoryWithMemoryFlagsMatchingImportFactoryId")
{
    auto clBackend = std::make_unique<ClBackend>();
    TensorHandleFactoryRegistry registry;
    ModelOptions modelOptions;
    clBackend->CreateWorkloadFactory(registry, modelOptions,
                                     static_cast<MemorySourceFlags>(MemorySource::Malloc),
                                     static_cast<MemorySourceFlags>(MemorySource::Malloc));

    // When calling CreateWorkloadFactory with ModelOptions and MemorySourceFlags,
    // CopyAndImportFactoryPair is registered
    // Get ClImportTensorHandleFactory id as the matching import factory id
    CHECK(IsClImportFactoryMatched(registry));
}
}