aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/ClContextDeserializer.cpp
blob: 35a8afafad69745663df2003e4ff3e4f60cfc785 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ClContextDeserializer.hpp"
#include "ClContextSchema_generated.h"

#include <armnn/Exceptions.hpp>
#include <armnn/utility/NumericCast.hpp>

#include <flatbuffers/flexbuffers.h>

#include <fmt/format.h>

#include <algorithm>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iterator>
#include <vector>

namespace armnn
{

/// Loads a serialized ClContext from a file and hands its bytes to DeserializeFromBinary.
///
/// @param clCompileContext Compile context that receives the built programs.
/// @param context          OpenCL context the programs are created in.
/// @param device           Device the program binaries are loaded for.
/// @param filePath         Path of the file holding the serialized ClContext.
/// @throws InvalidArgumentException if the file cannot be opened. Anything thrown by
///         DeserializeFromBinary (e.g. for empty or malformed content) propagates unchanged.
void ClContextDeserializer::Deserialize(arm_compute::CLCompileContext& clCompileContext,
                                        cl::Context& context,
                                        cl::Device& device,
                                        const std::string& filePath)
{
    std::ifstream inputFileStream(filePath, std::ios::binary);
    // Without this check a missing/unreadable file would yield an empty buffer and surface later as
    // an empty-content error that does not mention which path was at fault.
    if (!inputFileStream.is_open())
    {
        throw InvalidArgumentException(fmt::format("Unable to open file '{0}' for reading {1}",
                                                   filePath,
                                                   CHECK_LOCATION().AsString()));
    }

    // Read the whole file through the stream buffer in one pass instead of one formatted get() per byte.
    std::vector<std::uint8_t> binaryContent;
    std::transform(std::istreambuf_iterator<char>(inputFileStream),
                   std::istreambuf_iterator<char>(),
                   std::back_inserter(binaryContent),
                   [](char byte) { return static_cast<std::uint8_t>(byte); });

    DeserializeFromBinary(clCompileContext, context, device, binaryContent);
}

/// Rebuilds OpenCL programs from a serialized ClContext flatbuffer and registers them with a compile context.
///
/// Each Program entry in the buffer is turned into a cl::Program created from its stored binary for
/// @p device within @p context. The program is built and then added to @p clCompileContext under the
/// entry's name.
///
/// @param clCompileContext Compile context that receives the built programs.
/// @param context          OpenCL context the programs are created in.
/// @param device           Device the program binaries are loaded for.
/// @param binaryContent    Serialized ClContext flatbuffer.
/// @throws InvalidArgumentException if @p binaryContent is empty, or if building a program reports a
///         status other than CL_SUCCESS.
/// @throws ParseException if the buffer fails flatbuffers verification, or if a program entry lacks
///         its name or binary.
void ClContextDeserializer::DeserializeFromBinary(arm_compute::CLCompileContext& clCompileContext,
                                                  cl::Context& context,
                                                  cl::Device& device,
                                                  const std::vector<uint8_t>& binaryContent)
{
    // data() of an empty vector is not guaranteed to be nullptr, so test emptiness directly:
    // an empty buffer cannot contain a flatbuffer either way.
    if (binaryContent.empty())
    {
        throw InvalidArgumentException(fmt::format("Invalid (null or empty) binary content {}",
                                                   CHECK_LOCATION().AsString()));
    }

    const size_t binaryContentSize = binaryContent.size();
    flatbuffers::Verifier verifier(binaryContent.data(), binaryContentSize);
    if (!verifier.VerifyBuffer<ClContext>())
    {
        throw ParseException(fmt::format("Buffer doesn't conform to the expected Armnn "
                                         "flatbuffers format. size:{0} {1}",
                                         binaryContentSize,
                                         CHECK_LOCATION().AsString()));
    }
    const auto* clContext = GetClContext(binaryContent.data());

    // Verification accepts buffers in which non-required table fields are absent; their accessors then
    // return nullptr. A context without a programs vector simply has nothing to load.
    const auto* programs = clContext->programs();
    if (programs == nullptr)
    {
        return;
    }

    // Every program is loaded for the same single device, so the device list is built once.
    const std::vector<cl::Device> devices{ device };

    for (const Program* program : *programs)
    {
        const auto* programName   = program->name();
        const auto* programBinary = program->binary();
        if (programName == nullptr || programBinary == nullptr)
        {
            throw ParseException(fmt::format("Serialized CL program is missing its name or binary {}",
                                             CHECK_LOCATION().AsString()));
        }

        // Construct the binary in place rather than copying it through an initializer list.
        cl::Program::Binaries binaries;
        binaries.emplace_back(programBinary->begin(), programBinary->end());

        cl::Program theProgram(context, devices, binaries);
        // When the OpenCL C++ bindings are built without exceptions, failures are only reported here;
        // registering an unbuilt program would defer the error to an unrelated kernel-creation site.
        const cl_int buildStatus = theProgram.build();
        if (buildStatus != CL_SUCCESS)
        {
            throw InvalidArgumentException(fmt::format("Failed to build serialized CL program '{0}' (error {1}) {2}",
                                                       programName->c_str(),
                                                       buildStatus,
                                                       CHECK_LOCATION().AsString()));
        }
        clCompileContext.add_built_program(programName->c_str(), theProgram);
    }
}

} // namespace armnn