aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/TensorIOUtils.hpp
blob: 55dd3428b82b06657de70434f00e9f2f12fe0acc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Tensor.hpp>

#include <fmt/format.h>
#include <mapbox/variant.hpp>

namespace armnnUtils
{

/// Pairs each input binding with the data container at the same index and wraps the
/// container's memory in an armnn::ConstTensor.
///
/// @tparam TContainer  Type accepted by mapbox::util::apply_visitor (e.g. a mapbox variant) whose
///                     alternatives expose size() and data().
/// @param inputBindings        Binding id + TensorInfo for every network input.
/// @param inputDataContainers  One data container per binding, in the same order.
/// @return The (binding id, ConstTensor) pairs, in binding order.
/// @throws armnn::Exception if the number of bindings and containers differ, or if a container's
///         size() differs from its binding's TensorInfo::GetNumElements().
///
/// NOTE(review): each ConstTensor is built from the container's data() pointer. Presumably no copy
/// is made, so the containers must outlive the returned list. Confirm against armnn::ConstTensor.
template<typename TContainer>
inline armnn::InputTensors MakeInputTensors(const std::vector<armnn::BindingPointInfo>& inputBindings,
                                            const std::vector<TContainer>& inputDataContainers)
{
    const size_t bindingCount   = inputBindings.size();
    const size_t containerCount = inputDataContainers.size();
    if (bindingCount != containerCount)
    {
        throw armnn::Exception(fmt::format("The number of inputs does not match number of "
                                           "tensor data containers: {0} != {1}",
                                           bindingCount,
                                           containerCount));
    }

    armnn::InputTensors result;
    for (size_t idx = 0; idx < bindingCount; ++idx)
    {
        const auto& bindingId   = inputBindings[idx].first;
        const auto& bindingInfo = inputBindings[idx].second;

        // Invoked with whichever alternative the container currently holds.
        auto appendConstTensor = [&](auto&& buffer)
        {
            const auto expectedElements = bindingInfo.GetNumElements();
            if (buffer.size() != expectedElements)
            {
                throw armnn::Exception(fmt::format("The input tensor has incorrect size (expected {0} got {1})",
                                                   expectedElements,
                                                   buffer.size()));
            }

            // Work on a copy so the caller's binding info is not modified. The constant flag is set
            // because the data is wrapped in a ConstTensor (presumably required by its constructor).
            armnn::TensorInfo constInfo(bindingInfo);
            constInfo.SetConstant(true);
            result.emplace_back(bindingId, armnn::ConstTensor(constInfo, buffer.data()));
        };

        mapbox::util::apply_visitor(appendConstTensor, inputDataContainers[idx]);
    }

    return result;
}

/// Pairs each output binding with the data container at the same index and wraps the
/// container's memory in a (writable) armnn::Tensor.
///
/// @tparam TContainer  Type accepted by mapbox::util::apply_visitor (e.g. a mapbox variant) whose
///                     alternatives expose size() and data().
/// @param outputBindings        Binding id + TensorInfo for every network output.
/// @param outputDataContainers  One data container per binding, in the same order. Taken by
///                              non-const reference because the resulting Tensors point into them.
/// @return The (binding id, Tensor) pairs, in binding order.
/// @throws armnn::Exception if the number of bindings and containers differ, or if a container's
///         size() differs from its binding's TensorInfo::GetNumElements().
///
/// NOTE(review): each Tensor is built from the container's data() pointer. Presumably no copy is
/// made, so the containers must outlive the returned list. Confirm against armnn::Tensor.
/// NOTE(review): only element counts are checked. The container's element type is not compared
/// with the TensorInfo's data type.
template<typename TContainer>
inline armnn::OutputTensors MakeOutputTensors(const std::vector<armnn::BindingPointInfo>& outputBindings,
                                              std::vector<TContainer>& outputDataContainers)
{
    const size_t numOutputs = outputBindings.size();
    if (numOutputs != outputDataContainers.size())
    {
        // The trailing space on the first literal matters: adjacent literals are concatenated verbatim.
        throw armnn::Exception(fmt::format("Number of outputs does not match number "
                                           "of tensor data containers: {0} != {1}",
                                           numOutputs,
                                           outputDataContainers.size()));
    }

    armnn::OutputTensors outputTensors;
    // The final count is known up front, so allocate once.
    outputTensors.reserve(numOutputs);

    for (size_t i = 0; i < numOutputs; i++)
    {
        const armnn::BindingPointInfo& outputBinding = outputBindings[i];
        TContainer& outputData = outputDataContainers[i];

        mapbox::util::apply_visitor([&](auto&& value)
        {
            if (value.size() != outputBinding.second.GetNumElements())
            {
                // Report the expected and actual counts, like MakeInputTensors does.
                throw armnn::Exception(fmt::format("Output tensor has incorrect size (expected {0} got {1})",
                                                   outputBinding.second.GetNumElements(),
                                                   value.size()));
            }

            armnn::Tensor outputTensor(outputBinding.second, value.data());
            outputTensors.push_back(std::make_pair(outputBinding.first, outputTensor));
        },
        outputData);
    }

    return outputTensors;
}

} // namespace armnnUtils