aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/BatchMatMulLayer.cpp
blob: cafb051c7b1409423efe417bd04b453ad90b3c3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//
// Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "BatchMatMulLayer.hpp"

#include <armnn/backends/WorkloadFactory.hpp>
#include <armnnUtils/Permute.hpp>
#include "layers/LayerCloneBase.hpp"

namespace armnn
{

/// Constructs a BatchMatMul layer with two input slots (X and Y) and a single output slot.
/// @param param Descriptor holding the transpose/adjoint flags and data layouts for both inputs.
/// @param name  Optional layer name, forwarded unchanged to the base class.
BatchMatMulLayer::BatchMatMulLayer(const BatchMatMulDescriptor& param, const char* name)
    : LayerWithParameters(/* numInputSlots */ 2,
                          /* numOutputSlots */ 1,
                          LayerType::BatchMatMul,
                          param,
                          name)
{
}

std::unique_ptr<IWorkload> BatchMatMulLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
    BatchMatMulQueueDescriptor descriptor;
    SetAdditionalInfo(descriptor);

    return factory.CreateWorkload(LayerType::BatchMatMul, descriptor, PrepInfoAndDesc(descriptor));
}

/// Creates a copy of this layer, with the same descriptor and name, inside the given graph.
/// @param graph Graph in which the clone is created (presumably it owns the new layer — confirm against CloneBase).
/// @return Pointer to the newly created BatchMatMulLayer.
BatchMatMulLayer* BatchMatMulLayer::Clone(Graph& graph) const
{
    // The result is a plain pointer, so it is returned directly: std::move on it would be a no-op.
    return CloneBase<BatchMatMulLayer>(graph, m_Param, GetName());
}

/// Infers the output shape of the batch matrix multiplication from the shapes of its two inputs.
///
/// If a transpose or adjoint is requested for an input, that input's shape is first permuted with the
/// vector returned by BatchMatMulDescriptor::GetPermuteVec for its data layout. The output has the rank
/// of the higher-rank input (X wins a tie). The two axes returned by BatchMatMulDescriptor::GetAxesToMul
/// for the higher-rank input's layout are filled from X (first axis) and Y (second axis). Every other axis
/// takes the larger of the two right-aligned extents, or the longer input's extent where the shorter input
/// has no corresponding axis. Broadcast compatibility is NOT checked here.
///
/// @param inputShapes Exactly two shapes: [0] is input X, [1] is input Y.
/// @return A vector containing the single inferred output shape.
/// @throws armnn::LayerValidationException if inputShapes does not contain exactly two shapes.
std::vector<TensorShape> BatchMatMulLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
    if (inputShapes.size() != 2)
    {
        throw armnn::LayerValidationException("inputShapes' size is \"" + std::to_string(inputShapes.size()) +
                                              "\" - should be \"2\".");
    }

    TensorShape inputXShape = inputShapes[0];
    TensorShape inputYShape = inputShapes[1];

    // Adjoint is assumed to be square, but we will apply the permute anyway
    if (m_Param.m_TransposeX || m_Param.m_AdjointX)
    {
        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutX, inputXShape);
        inputXShape = armnnUtils::Permuted(inputXShape, permuteVec);
    }
    if (m_Param.m_TransposeY || m_Param.m_AdjointY)
    {
        auto permuteVec = BatchMatMulDescriptor::GetPermuteVec(m_Param.m_DataLayoutY, inputYShape);
        inputYShape = armnnUtils::Permuted(inputYShape, permuteVec);
    }

    // Decide once which input has the higher rank; on a tie X is treated as the longer one.
    const bool xIsLonger = inputXShape.GetNumDimensions() >= inputYShape.GetNumDimensions();

    // Kept as non-const references so indexing goes through the same operator[] as the X/Y copies.
    TensorShape& longerInput  = xIsLonger ? inputXShape : inputYShape;
    TensorShape& shorterInput = xIsLonger ? inputYShape : inputXShape;

    const unsigned int outputNumDimensions = longerInput.GetNumDimensions();
    // Number of leading axes of the longer input that have no counterpart in the shorter input.
    const unsigned int inputNumDimsOffset = outputNumDimensions - shorterInput.GetNumDimensions();

    const auto& longerInputDataLayout = xIsLonger ? m_Param.m_DataLayoutX : m_Param.m_DataLayoutY;
    const auto longerAxesToMul = BatchMatMulDescriptor::GetAxesToMul(longerInputDataLayout, longerInput);

    std::vector<unsigned int> tensorDimensions(outputNumDimensions, 0);
    for (unsigned int i = 0; i < outputNumDimensions; ++i)
    {
        if (i == longerAxesToMul.first)
        {
            // The first multiplied axis takes X's extent; shift the index when X is the shorter input.
            tensorDimensions[i] = xIsLonger ? inputXShape[i] : inputXShape[i - inputNumDimsOffset];
        }
        else if (i == longerAxesToMul.second)
        {
            // The second multiplied axis takes Y's extent; shift the index when Y is the shorter input.
            tensorDimensions[i] = xIsLonger ? inputYShape[i - inputNumDimsOffset] : inputYShape[i];
        }
        else if (i < inputNumDimsOffset)
        {
            // Leading axis that exists only in the longer input.
            tensorDimensions[i] = longerInput[i];
        }
        else // The other dimensions not to be multiplied (but may be broadcasted)
        {
            // Does NOT validate whether it's a valid broadcast - that's done in the validate func in WorkloadData.cpp
            tensorDimensions[i] = std::max(longerInput[i], shorterInput[i - inputNumDimsOffset]);
        }
    }

    return std::vector<TensorShape>({ TensorShape(outputNumDimensions, tensorDimensions.data()) });
}

void BatchMatMulLayer::ValidateTensorShapesFromInputs()
{
    VerifyLayerConnections(2, CHECK_LOCATION());

    const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();

    VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);

    auto inferredShapes = InferOutputShapes({
        GetInputSlot(0).GetTensorInfo().GetShape(),
        GetInputSlot(1).GetTensorInfo().GetShape() });

    if (inferredShapes.size() != 1)
    {
        throw armnn::LayerValidationException("inferredShapes has "
                                              + std::to_string(inferredShapes.size()) +
                                              " elements - should only have 1.");
    }

    ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "BatchMatMulLayer");
}

} // namespace armnn