aboutsummaryrefslogtreecommitdiff
path: root/1.0/FullyConnected.hpp
blob: 0fb029dea01b6340baae34d04b4e62cf6a741267 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/Tensor.hpp>

#include "../ConversionUtils.hpp"

#include <stdexcept>

namespace armnn_driver
{

/// Flattens the input shape of a fully connected layer to 2-D when its rank is greater than 2.
///
/// The inner (second) dimension of the result is weightsShape[1]. The outer dimension is the
/// total number of input elements divided by it, i.e. [totalElements / inputSize, inputSize].
/// Shapes of rank 2 or lower are returned unchanged.
/// NOTE(review): assumes the weights are laid out as [numUnits, inputSize] - confirm against callers.
///
/// @param inputShape   Shape of the fully connected layer's input tensor.
/// @param weightsShape Shape of the weights tensor; dimension 1 gives the flattened inner size.
/// @return The flattened 2-D shape, or inputShape itself if its rank is <= 2.
/// @throws std::runtime_error if weightsShape has fewer than 2 dimensions, weightsShape[1] is 0,
///         or the number of input elements is not a multiple of weightsShape[1].
inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape &inputShape,
                                                     const armnn::TensorShape &weightsShape)
{
    if (inputShape.GetNumDimensions() <= 2U)
    {
        return inputShape;
    }

    // weightsShape[1] is read below; reject malformed weights explicitly rather than relying on
    // whatever TensorShape's indexing does for an out-of-range dimension.
    if (weightsShape.GetNumDimensions() < 2U)
    {
        throw std::runtime_error("Failed to deduce tensor shape");
    }

    // Flattening must preserve the element count exactly, so work from the total rather than
    // from dimension 0 alone: the inner input dims may be larger than inputSize
    // (e.g. [1, 2, 3, 4] with inputSize 12 flattens to [2, 12]).
    unsigned int totalElements = 1U;
    for (unsigned int i = 0U; i < inputShape.GetNumDimensions(); ++i)
    {
        totalElements *= inputShape[i];
    }

    const unsigned int inputSize = weightsShape[1];

    // Guard the division: a zero inputSize would be undefined behaviour, and a remainder would mean
    // the input cannot be reshaped to [batch, inputSize] at all.
    if (inputSize == 0U || totalElements % inputSize != 0U)
    {
        throw std::runtime_error("Failed to deduce tensor shape");
    }

    return armnn::TensorShape({totalElements / inputSize, inputSize});
}

} // namespace armnn_driver