// path: root/src/armnn/backends/Workload.hpp
// blob: dbc7574d0e0f914272af41bc380453ed69fad963 (plain)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once

#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"
#include <algorithm>
#include "Profiling.hpp"

namespace armnn
{

// Interface through which a layer computation is enqueued.
// Concrete workloads derive from it and supply Execute().
class IWorkload
{
public:
    // Virtual so that a workload can be destroyed through an IWorkload pointer.
    virtual ~IWorkload() = default;

    // Performs the workload's computation; pure virtual, so every concrete workload overrides it.
    virtual void Execute() const = 0;
};

// Placeholder type: the workload factories hand NullWorkload to the MakeWorkload<> template
// to denote a workload that is not supported.
// It must never be instantiated, which is why its default constructor is deleted.
class NullWorkload : public IWorkload
{
private:
    NullWorkload() = delete;
};

// Common base for workloads that are configured by a queue descriptor.
// QueueDescriptor has to be copy-constructible and provide a Validate member that can be
// called on a const object with a WorkloadInfo, because the stored copy is const.
template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:
    // Keeps its own copy of descriptor and then asks that copy to validate itself against info.
    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : m_Data(descriptor)
    {
        m_Data.Validate(info);
    }

    // Read-only access to the descriptor copy held by this workload.
    const QueueDescriptor& GetData() const
    {
        return m_Data;
    }

protected:
    // Descriptor copied at construction time; const, so it cannot be modified afterwards.
    const QueueDescriptor m_Data;
};

// Workload tied to a single tensor data type.
// DataType is the data type that every input and output tensor info in the WorkloadInfo
// is expected to report.
template <typename QueueDescriptor, armnn::DataType DataType>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    // Forwards descriptor and info to BaseWorkload, then asserts that all entries of
    // info.m_InputTensorInfos and info.m_OutputTensorInfos have the data type DataType.
    // The checks go through BOOST_ASSERT_MSG, so whether they are evaluated depends on how
    // that macro is configured for the build.
    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        // The predicates take each tensor info by const reference: taking it by value would
        // copy every element only to read its data type. No capture is needed because
        // DataType is a template parameter, not a local variable.
        BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                     info.m_InputTensorInfos.end(),
                                     [](const auto& tensorInfo){
                                         return tensorInfo.GetDataType() == DataType;
                                     }),
                         "Trying to create workload with incorrect type");
        BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [](const auto& tensorInfo){
                                         return tensorInfo.GetDataType() == DataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }

    // The DataType template argument, exposed as a compile-time constant.
    static constexpr armnn::DataType ms_DataType = DataType;
};

// Shorthand for a TypedWorkload whose data type is fixed to armnn::DataType::Float32.
template <typename Descriptor>
using Float32Workload =
    TypedWorkload<Descriptor, armnn::DataType::Float32>;

// Shorthand for a TypedWorkload whose data type is fixed to armnn::DataType::QuantisedAsymm8.
template <typename Descriptor>
using Uint8Workload =
    TypedWorkload<Descriptor, armnn::DataType::QuantisedAsymm8>;

} //namespace armnn