aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.hpp
blob: 1dd7ec52840545c1d714bac16666c6b2c463033b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "armnn/INetwork.hpp"
#include "armnnDeserializer/IDeserializer.hpp"
#include <ArmnnSchema_generated.h>

#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <string>
#include <utility>
#include <vector>

namespace armnnDeserializer
{
/// Concrete IDeserializer implementation that builds an armnn::INetwork from a
/// serialized armnnSerializer::SerializedGraph (see ArmnnSchema_generated.h).
class Deserializer : public IDeserializer
{
public:
    // ---------------------------------------------------------------------
    // Aliases for the raw pointers into the serialized (flatbuffers) data
    // ---------------------------------------------------------------------
    using ConstTensorRawPtr     = const armnnSerializer::ConstTensor *;
    using GraphPtr              = const armnnSerializer::SerializedGraph *;
    using TensorRawPtr          = const armnnSerializer::TensorInfo *;
    using PoolingDescriptor     = const armnnSerializer::Pooling2dDescriptor *;
    using TensorRawPtrVector    = std::vector<TensorRawPtr>;
    // LayerRawPtr and LayerBaseRawPtr name the same type; both spellings are kept.
    using LayerRawPtr           = const armnnSerializer::LayerBase *;
    using LayerBaseRawPtr       = const armnnSerializer::LayerBase *;
    using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;

    // ---------------------------------------------------------------------
    // Construction / destruction
    // ---------------------------------------------------------------------
    Deserializer();
    ~Deserializer() = default;

    // ---------------------------------------------------------------------
    // IDeserializer interface
    // ---------------------------------------------------------------------

    /// Build a network from the contents of a binary file held in memory.
    armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) override;

    /// Build a network from a binary input stream.
    armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent) override;

    /// Look up the binding info (layer id and tensor info) of a network input,
    /// given a layer id and a layer name.
    BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const override;

    /// Look up the binding info (layer id and tensor info) of a network output,
    /// given a layer id and a layer name.
    BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const override;

    // ---------------------------------------------------------------------
    // Helpers exposed publicly so that they can be tested
    // ---------------------------------------------------------------------
    static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len);

    static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex);
    static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex);

    static LayerBaseRawPtrVector GetGraphInputs(const GraphPtr& graphPtr);
    static LayerBaseRawPtrVector GetGraphOutputs(const GraphPtr& graphPtr);

    static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex);
    static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex);
    static std::string GetLayerName(const GraphPtr& graph, unsigned int index);

    // NOTE(review): unlike the other helpers in this group, this one is declared
    // as a non-static member function.
    armnn::Pooling2dDescriptor GetPoolingDescriptor(PoolingDescriptor pooling2dDescriptor,
                                                    unsigned int layerIndex);

    static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
                                                  const std::vector<uint32_t> & targetDimsIn);

private:
    // Copying is disabled until it is wanted and properly implemented.
    Deserializer(const Deserializer&) = delete;
    Deserializer& operator=(const Deserializer&) = delete;

    /// Build the network from a flatbuffers graph that has already been loaded.
    armnn::INetworkPtr CreateNetworkFromGraph(GraphPtr graph);

    /// Pointer-to-member type shared by every Parse* function declared below.
    using LayerParsingFunction = void(Deserializer::*)(GraphPtr graph, unsigned int layerIndex);

    // Layer parsers; each one matches the LayerParsingFunction signature.
    void ParseUnsupportedLayer(GraphPtr graph, unsigned int layerIndex);
    void ParseActivation(GraphPtr graph, unsigned int layerIndex);
    void ParseAdd(GraphPtr graph, unsigned int layerIndex);
    void ParseBatchToSpaceNd(GraphPtr graph, unsigned int layerIndex);
    void ParseConstant(GraphPtr graph, unsigned int layerIndex);
    void ParseConvolution2d(GraphPtr graph, unsigned int layerIndex);
    void ParseDepthwiseConvolution2d(GraphPtr graph, unsigned int layerIndex);
    void ParseDivision(GraphPtr graph, unsigned int layerIndex);
    void ParseEqual(GraphPtr graph, unsigned int layerIndex);
    void ParseFullyConnected(GraphPtr graph, unsigned int layerIndex);
    void ParseMinimum(GraphPtr graph, unsigned int layerIndex);
    void ParseMaximum(GraphPtr graph, unsigned int layerIndex);
    void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
    void ParsePermute(GraphPtr graph, unsigned int layerIndex);
    void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);
    void ParseReshape(GraphPtr graph, unsigned int layerIndex);
    void ParseSoftmax(GraphPtr graph, unsigned int layerIndex);
    void ParseSpaceToBatchNd(GraphPtr graph, unsigned int layerIndex);

    // Slot bookkeeping; the recorded slots are kept in m_GraphConnections
    // (presumably - the definitions are not visible from this header).
    void RegisterOutputSlotOfConnection(uint32_t connectionIndex, armnn::IOutputSlot* slot);
    void RegisterInputSlotOfConnection(uint32_t connectionIndex, armnn::IInputSlot* slot);
    void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex,
                            armnn::IConnectableLayer* layer);
    void RegisterOutputSlots(GraphPtr graph, uint32_t layerIndex,
                             armnn::IConnectableLayer* layer);

    void ResetParser();

    void SetupInputLayers(GraphPtr graphPtr);
    void SetupOutputLayers(GraphPtr graphPtr);

    /// The network under construction. Gets cleared after it is passed to the user.
    armnn::INetworkPtr                    m_Network;

    /// Collection of layer parsing member functions
    /// (presumably indexed by serialized layer type - TODO confirm).
    std::vector<LayerParsingFunction>     m_ParserFunctions;

    /// A name paired with the binding info recorded for it.
    using NameToBindingInfo = std::pair<std::string, BindingPointInfo >;
    std::vector<NameToBindingInfo>    m_InputBindings;
    std::vector<NameToBindingInfo>    m_OutputBindings;

    /// Associates one output slot with every input slot it should be connected to.
    /// The outputSlot belongs to the layer that produces the tensor as one of its outputs.
    /// The inputSlots belong to the layers that consume the tensor as one of their inputs.
    struct Slots
    {
        armnn::IOutputSlot* outputSlot = nullptr;   // no producer registered yet
        std::vector<armnn::IInputSlot*> inputSlots;

        Slots() = default;
    };
    using Connection = std::vector<Slots>;
    std::vector<Connection>   m_GraphConnections;
};

} //namespace armnnDeserializer