diff options
Diffstat (limited to 'src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp')
-rw-r--r-- | src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp | 63 |
1 files changed, 33 insertions, 30 deletions
diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp index 5fa151560b..9814405ff7 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp @@ -6,49 +6,52 @@ #include <armnn/backends/IMemoryOptimizerStrategy.hpp> #include "MemoryOptimizerStrategyFactory.hpp" -#include <algorithm> #include "strategies/ConstantMemoryStrategy.hpp" #include "strategies/StrategyValidator.hpp" #include "strategies/SingleAxisPriorityList.hpp" -namespace +#include <map> + +namespace armnn { -// Default Memory Optimizer Strategies -static const std::vector<std::string> memoryOptimizationStrategies( +namespace { - "ConstantMemoryStrategy", - "SingleAxisPriorityList" - "StrategyValidator" -}); - -#define CREATE_MEMORY_OPTIMIZER_STRATEGY(strategyName, memoryOptimizerStrategy) \ -{ \ - MemoryOptimizerStrategyFactory memoryOptimizerStrategyFactory; \ - memoryOptimizerStrategy = memoryOptimizerStrategyFactory.CreateMemoryOptimizerStrategy<strategyName>(); \ -} \ -} // anonymous namespace -namespace armnn +static std::map<std::string, std::unique_ptr<IMemoryOptimizerStrategyFactory>>& GetStrategyFactories() { - std::unique_ptr<IMemoryOptimizerStrategy> GetMemoryOptimizerStrategy(const std::string& strategyName) + static std::map<std::string, std::unique_ptr<IMemoryOptimizerStrategyFactory>> strategies; + + if (strategies.size() == 0) { - auto doesStrategyExist = std::find(memoryOptimizationStrategies.begin(), - memoryOptimizationStrategies.end(), - strategyName) != memoryOptimizationStrategies.end(); - if (doesStrategyExist) - { - std::unique_ptr<IMemoryOptimizerStrategy> memoryOptimizerStrategy = nullptr; - CREATE_MEMORY_OPTIMIZER_STRATEGY(armnn::ConstantMemoryStrategy, - memoryOptimizerStrategy); - return memoryOptimizerStrategy; - } - return nullptr; + strategies["ConstantMemoryStrategy"] = std::make_unique<StrategyFactory<ConstantMemoryStrategy>>(); + strategies["SingleAxisPriorityList"] = std::make_unique<StrategyFactory<SingleAxisPriorityList>>(); + strategies["StrategyValidator"] = std::make_unique<StrategyFactory<StrategyValidator>>(); } + return strategies; +} + +} // anonymous namespace +std::unique_ptr<IMemoryOptimizerStrategy> GetMemoryOptimizerStrategy(const std::string& strategyName) +{ + const auto& strategyFactoryMap = GetStrategyFactories(); + auto strategyFactory = strategyFactoryMap.find(strategyName); + if (strategyFactory != GetStrategyFactories().end()) + { + return strategyFactory->second->CreateMemoryOptimizerStrategy(); + } + return nullptr; +} - const std::vector<std::string>& GetMemoryOptimizerStrategyNames() +const std::vector<std::string> GetMemoryOptimizerStrategyNames() +{ + const auto& strategyFactoryMap = GetStrategyFactories(); + std::vector<std::string> strategyNames; + for (const auto& strategyFactory : strategyFactoryMap) { - return memoryOptimizationStrategies; + strategyNames.emplace_back(strategyFactory.first); } + return strategyNames; +} } // namespace armnn
\ No newline at end of file |