aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp
blob: 5e20a9f218257a70ce9e5892106f30939faa5743 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include <armnn/backends/IMemoryOptimizerStrategy.hpp>
#include "MemoryOptimizerStrategyFactory.hpp"
#include <algorithm>

#include "strategies/ConstantMemoryStrategy.hpp"
#include "strategies/StrategyValidator.hpp"

namespace
{
// Names of the memory optimizer strategies this library offers by default.
// GetMemoryOptimizerStrategyNames() hands this list back to callers unchanged.
static const std::vector<std::string> memoryOptimizationStrategies =
{
    "ConstantMemoryStrategy",
    "StrategyValidator"
};

// Instantiates the strategy type `strategyName` through a temporary
// MemoryOptimizerStrategyFactory and assigns the result to `memoryOptimizerStrategy`.
//
//   strategyName            - despite its name this is a C++ type, not a string: it is used as
//                             the template argument of CreateMemoryOptimizerStrategy<>.
//   memoryOptimizerStrategy - an assignable lvalue that receives the factory's return value.
//
// The expansion is a braced compound statement, so the factory object is scoped to it.
// The final line intentionally has no line continuation: the definition ends at the closing
// brace and cannot absorb whatever line happens to follow it.
#define CREATE_MEMORY_OPTIMIZER_STRATEGY(strategyName, memoryOptimizerStrategy)                                  \
{                                                                                                                \
    MemoryOptimizerStrategyFactory memoryOptimizerStrategyFactory;                                               \
    memoryOptimizerStrategy = memoryOptimizerStrategyFactory.CreateMemoryOptimizerStrategy<strategyName>();      \
}

} // anonymous namespace
namespace armnn
{
    /// Creates the memory optimizer strategy registered under @p strategyName.
    ///
    /// Each supported name is mapped to its own strategy type, so requesting
    /// "StrategyValidator" yields a StrategyValidator rather than a ConstantMemoryStrategy.
    /// The names handled here must be kept in sync with memoryOptimizationStrategies.
    ///
    /// @param strategyName Name of the strategy to create (compared case-sensitively).
    /// @return The value produced by MemoryOptimizerStrategyFactory for the matching type,
    ///         or nullptr when @p strategyName is not a known strategy.
    ///
    /// NOTE(review): assumes armnn::StrategyValidator can be created by
    /// MemoryOptimizerStrategyFactory::CreateMemoryOptimizerStrategy<>() in the same way as
    /// armnn::ConstantMemoryStrategy - confirm against the factory and the strategy headers.
    /// NOTE(review): this is a non-inline function defined in a header; including the header
    /// from more than one translation unit would presumably give duplicate definitions.
    std::unique_ptr<IMemoryOptimizerStrategy> GetMemoryOptimizerStrategy(const std::string& strategyName)
    {
        std::unique_ptr<IMemoryOptimizerStrategy> memoryOptimizerStrategy = nullptr;

        if (strategyName == "ConstantMemoryStrategy")
        {
            CREATE_MEMORY_OPTIMIZER_STRATEGY(armnn::ConstantMemoryStrategy,
                                             memoryOptimizerStrategy);
        }
        else if (strategyName == "StrategyValidator")
        {
            CREATE_MEMORY_OPTIMIZER_STRATEGY(armnn::StrategyValidator,
                                             memoryOptimizerStrategy);
        }

        // Unknown names fall through with the pointer still null.
        return memoryOptimizerStrategy;
    }


    /// Lists the names of the default memory optimizer strategies.
    ///
    /// @return A const reference to the file-level memoryOptimizationStrategies vector;
    ///         no copy is made.
    const std::vector<std::string>& GetMemoryOptimizerStrategyNames()
    {
        const std::vector<std::string>& strategyNames = memoryOptimizationStrategies;
        return strategyNames;
    }
} // namespace armnn