aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/NeonInterceptorScheduler.cpp
blob: fc95ef439e7b9cf2d9ff15113c3f9dfb1eb17ded (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#include "NeonInterceptorScheduler.hpp"

#include <boost/assert.hpp>

namespace armnn{

NeonInterceptorScheduler::NeonInterceptorScheduler(NeonTimer::KernelMeasurements& kernels,
                                                   arm_compute::IScheduler &realScheduler)
        : m_Kernels(kernels), m_RealScheduler(realScheduler)
{
}

void NeonInterceptorScheduler::set_num_threads(unsigned int numThreads)
{
    m_RealScheduler.set_num_threads(numThreads);
}

unsigned int NeonInterceptorScheduler::num_threads() const
{
    return m_RealScheduler.num_threads();
}

void NeonInterceptorScheduler::schedule(arm_compute::ICPPKernel* kernel, const Hints& hints)
{
    m_Timer.Start();
    m_RealScheduler.schedule(kernel, hints.split_dimension());
    m_Timer.Stop();

    std::vector<Measurement> measurements = m_Timer.GetMeasurements();
    BOOST_ASSERT(!measurements.empty());

    Measurement measurement(measurements.front()); // NOTE: 1st measurement is delta
    measurement.m_Name = kernel->name();
    m_Kernels.push_back(std::move(measurement));
}

void NeonInterceptorScheduler::run_workloads(std::vector <Workload>& workloads)
{
    m_Timer.Start();
    m_RealScheduler.run_workloads(workloads);
    m_Timer.Stop();

    std::vector<Measurement> measurements = m_Timer.GetMeasurements();
    BOOST_ASSERT_MSG(measurements.size() == 3, "WallClockTimer does not have correct amount of measurements.");

    // WallClockTimer has 3 measurements, duration always being the first.
    Measurement measurement(measurements.front());
    measurement.m_Name = "Workload";
    m_Kernels.push_back(std::move(measurement));
}

} // namespace armnn