blob: dbc7574d0e0f914272af41bc380453ed69fad963 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
|
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#pragma once
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"
#include <algorithm>
#include "Profiling.hpp"
namespace armnn
{
// Workload interface to enqueue a layer computation.
//
// Workloads are handled through IWorkload pointers/references; the virtual destructor makes it safe
// to destroy a derived workload through a pointer to this interface.
class IWorkload
{
public:
    virtual ~IWorkload() = default;

    // Runs the workload. Declared const, so implementations cannot modify their own non-mutable state here.
    virtual void Execute() const = 0;
};
// NullWorkload is used to denote an unsupported workload when used by the MakeWorkload<> template
// in the various workload factories.
// There should never be an instantiation of a NullWorkload: its default constructor is deleted, and
// it does not override Execute(), so it is abstract as well.
class NullWorkload : public IWorkload
{
public:
    // Deleted functions are declared public so that any attempted use is reported as a
    // "use of deleted function" error rather than an access-control error.
    NullWorkload() = delete;
};
// Base class for workloads that are configured by a queue descriptor.
// The constructor stores its own copy of the descriptor, calls Validate() on that copy with the
// supplied WorkloadInfo, and the copy is afterwards exposed read-only through GetData().
template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:
    BaseWorkload(const QueueDescriptor& queueDescriptor, const WorkloadInfo& workloadInfo)
        : m_Data(queueDescriptor)
    {
        // Validation runs on the stored (const) copy, so Validate() must be a const member of
        // QueueDescriptor. NOTE(review): presumably it reports an inconsistent descriptor/info pair
        // by throwing — confirm in WorkloadData.hpp.
        m_Data.Validate(workloadInfo);
    }

    // Read-only access to the descriptor this workload was constructed with.
    auto GetData() const -> const QueueDescriptor&
    {
        return m_Data;
    }

protected:
    // The workload's own immutable copy of the descriptor passed to the constructor.
    const QueueDescriptor m_Data;
};
template <typename QueueDescriptor, armnn::DataType DataType>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:
TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
: BaseWorkload<QueueDescriptor>(descriptor, info)
{
BOOST_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
info.m_InputTensorInfos.end(),
[&](auto it){
return it.GetDataType() == DataType;
}),
"Trying to create workload with incorrect type");
BOOST_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
info.m_OutputTensorInfos.end(),
[&](auto it){
return it.GetDataType() == DataType;
}),
"Trying to create workload with incorrect type");
}
static constexpr armnn::DataType ms_DataType = DataType;
};
// Shorthand for a TypedWorkload whose tensors are all DataType::Float32.
template <typename Descriptor>
using Float32Workload = TypedWorkload<Descriptor, DataType::Float32>;

// Shorthand for a TypedWorkload whose tensors are all DataType::QuantisedAsymm8.
template <typename Descriptor>
using Uint8Workload = TypedWorkload<Descriptor, DataType::QuantisedAsymm8>;
} //namespace armnn
|