aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/strategies/StrategyValidator.cpp
blob: 173ced4cac602cd2f733c79531e509fedec9d111 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <cstddef>
#include <iostream>
#include <unordered_map>
#include "StrategyValidator.hpp"

namespace armnn
{

std::vector<MemBin> StrategyValidator::Optimize(std::vector<MemBlock>& memBlocks)
{
    // Condition #1: All Memblocks have been assigned to a MemBin

    // Condition #2: No Memblock is assigned to multiple MemBins

    // Condition #3: No two Memblocks in a MemBin overlap in both the X and Y axis
    //               Memblocks in a MemBin can overlap on the X axis for SingleAxisPacking
    //               Memblocks in a MemBin can overlap on the Y axis or the X for MultiAxisPacking but not both

    std::unordered_map<unsigned int, bool> validationMap;

    for (auto memBlock : memBlocks)
    {
        validationMap[memBlock.m_Index] = false;
    }

    auto memBinVect = m_Strategy->Optimize(memBlocks);

    // Compare each of the input memblocks against every assignedBlock in each bin
    // if we get through all bins without finding a block return
    // if at any stage the block is found twice return

    for (auto memBin : memBinVect)
    {
        for (auto block : memBin.m_MemBlocks)
        {
            try
            {
                if (!validationMap.at(block.m_Index))
                {
                    validationMap.at(block.m_Index) = true;
                }
                else
                {
                    throw MemoryValidationException("Condition #2: Memblock is assigned to multiple MemBins");
                }
            }
            catch (const std::out_of_range&)
            {
                throw MemoryValidationException("Unknown index ");
            }
        }
    }

    for (auto memBlock : memBlocks)
    {
        if (!validationMap.at(memBlock.m_Index))
        {
            throw MemoryValidationException("Condition #1: Block not found in any bin");
        }
    }

    // Check for overlaps once we know blocks are all assigned and no duplicates
    for (auto bin : memBinVect)
    {
        for (unsigned int i = 0; i < bin.m_MemBlocks.size(); ++i)
        {
            auto Block1 = bin.m_MemBlocks[i];
            auto B1Left = Block1.m_Offset;
            auto B1Right = Block1.m_Offset + Block1.m_MemSize;

            auto B1Top = Block1.m_StartOfLife;
            auto B1Bottom = Block1.m_EndOfLife;

            // Only compare with blocks after the current one as previous have already been checked
            for (unsigned int j = i + 1; j < bin.m_MemBlocks.size(); ++j)
            {
                auto Block2 = bin.m_MemBlocks[j];
                auto B2Left = Block2.m_Offset;
                auto B2Right = Block2.m_Offset + Block2.m_MemSize;

                auto B2Top = Block2.m_StartOfLife;
                auto B2Bottom = Block2.m_EndOfLife;

                switch (m_Strategy->GetMemBlockStrategyType())
                {
                    case (MemBlockStrategyType::SingleAxisPacking):
                    {
                        if (B1Top <= B2Bottom && B1Bottom >= B2Top)
                        {
                            throw MemoryValidationException("Condition #3: "
                                        "invalid as two Memblocks overlap on the Y axis for SingleAxisPacking");

                        }
                        break;
                    }
                    case (MemBlockStrategyType::MultiAxisPacking):
                    {
                        // If overlapping on both X and Y then invalid
                        if (B1Left <= B2Right && B1Right >= B2Left &&
                            B1Top <= B2Bottom && B1Bottom >= B2Top)
                        {
                            // Condition #3: two Memblocks overlap on both the X and Y axis
                            throw MemoryValidationException("Condition #3: "
                                                            "two Memblocks overlap on both the X and Y axis");
                        }
                        break;
                    }
                    default:
                        throw MemoryValidationException("Unknown MemBlockStrategyType");
                }
            }
        }
    }

    // None of the conditions broken so return true
    return memBinVect;
}

} // namespace armnn