aboutsummaryrefslogtreecommitdiff
path: root/include/armnnTestUtils/MockBackend.hpp
blob: 8bc41b3f3f57ff6f530f5771537685ea2706a38e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/backends/IBackendInternal.hpp>
#include <armnn/backends/MemCopyWorkload.hpp>
#include <armnnTestUtils/MockTensorHandle.hpp>

namespace armnn
{

// Minimal mock backend for unit tests of simple tensor-manipulation features.
// It reports an identifier but supplies neither a workload factory nor a layer-support object.
class MockBackend : public IBackendInternal
{
public:
    MockBackend() = default;
    ~MockBackend() = default;

    /// Identifier shared by every MockBackend instance (defined out of line).
    static const BackendId& GetIdStatic();

    /// Instance accessor for the backend identifier; forwards to GetIdStatic().
    const BackendId& GetId() const override
    {
        return MockBackend::GetIdStatic();
    }

    /// This mock does not provide a workload factory.
    /// @param memoryManager Ignored.
    /// @return Always an empty (null) factory pointer.
    IBackendInternal::IWorkloadFactoryPtr CreateWorkloadFactory(
        const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override
    {
        IgnoreUnused(memoryManager);
        return IBackendInternal::IWorkloadFactoryPtr{};
    }

    /// This mock does not provide a layer-support object.
    /// @return Always an empty (null) layer-support pointer.
    IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
    {
        return IBackendInternal::ILayerSupportSharedPtr{};
    }
};

// Workload factory companion to MockBackend. Sub-tensors are not supported. The deprecated
// tensor-handle and input-workload creation entry points are implemented inline here.
// The constructors, GetBackendId() and the ABI-stable CreateWorkload() are defined out of line.
class MockWorkloadFactory : public IWorkloadFactory
{
public:
    /// Builds a factory whose manager-backed tensor handles share the given memory manager.
    explicit MockWorkloadFactory(const std::shared_ptr<MockMemoryManager>& memoryManager);
    /// Default constructor (defined out of line).
    MockWorkloadFactory();

    ~MockWorkloadFactory() = default;

    const BackendId& GetBackendId() const override;

    /// The mock never creates sub-tensors.
    bool SupportsSubTensors() const override
    {
        return false;
    }

    /// Sub-tensor creation is unsupported: all arguments are ignored and null is returned.
    ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
    std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle&,
                                                         TensorShape const&,
                                                         unsigned int const*) const override
    {
        return nullptr;
    }

    /// Creates a MockTensorHandle bound to this factory's memory manager (m_MemoryManager).
    /// @param tensorInfo      Description of the tensor to create.
    /// @param isMemoryManaged Ignored; m_MemoryManager is always used.
    ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
                                                      const bool isMemoryManaged = true) const override
    {
        IgnoreUnused(isMemoryManaged);
        return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
    }

    /// Creates a MockTensorHandle from the MemorySource::Malloc flag value rather than from m_MemoryManager.
    /// @param tensorInfo      Description of the tensor to create.
    /// @param dataLayout      Ignored.
    /// @param isMemoryManaged Ignored.
    /// NOTE(review): this overload and the two-argument overload construct their handles differently,
    /// and both ignore isMemoryManaged — confirm the asymmetry is intentional.
    ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
    std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
                                                      DataLayout dataLayout,
                                                      const bool isMemoryManaged = true) const override
    {
        IgnoreUnused(dataLayout, isMemoryManaged);
        const auto mallocFlags = static_cast<unsigned int>(MemorySource::Malloc);
        return std::make_unique<MockTensorHandle>(tensorInfo, mallocFlags);
    }

    /// Validates the tensor infos, then creates a CopyMemGenericWorkload from the input descriptor.
    /// @throws InvalidArgumentException if there is no input tensor info, no output tensor info,
    ///         or the first input and first output differ in byte count (checked in that order).
    /// @return A CopyMemGenericWorkload built from @p descriptor and @p info.
    ARMNN_DEPRECATED_MSG_REMOVAL_DATE(
        "Use ABI stable "
        "CreateWorkload(LayerType, const QueueDescriptor&, const WorkloadInfo& info) instead.",
        "22.11")
    std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
                                           const WorkloadInfo& info) const override
    {
        const auto& inputs  = info.m_InputTensorInfos;
        const auto& outputs = info.m_OutputTensorInfos;

        if (inputs.empty())
        {
            throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Input cannot be zero length");
        }
        if (outputs.empty())
        {
            throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Output cannot be zero length");
        }

        const bool byteCountsMatch = inputs.front().GetNumBytes() == outputs.front().GetNumBytes();
        if (!byteCountsMatch)
        {
            throw InvalidArgumentException(
                "MockWorkloadFactory::CreateInput: data input and output differ in byte count.");
        }

        return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
    }

    /// ABI-stable workload creation entry point (defined out of line).
    std::unique_ptr<IWorkload>
        CreateWorkload(LayerType type, const QueueDescriptor& descriptor, const WorkloadInfo& info) const override;

private:
    // Declared mutable so the const CreateTensorHandle() can hand it to MockTensorHandle
    // (presumably that constructor takes the shared_ptr by non-const reference — confirm).
    mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
};

}    // namespace armnn