aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp')
-rw-r--r--src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp63
1 files changed, 33 insertions, 30 deletions
diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp
index 5fa151560b..9814405ff7 100644
--- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp
+++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp
@@ -6,49 +6,52 @@
#include <armnn/backends/IMemoryOptimizerStrategy.hpp>
#include "MemoryOptimizerStrategyFactory.hpp"
-#include <algorithm>
#include "strategies/ConstantMemoryStrategy.hpp"
#include "strategies/StrategyValidator.hpp"
#include "strategies/SingleAxisPriorityList.hpp"
-namespace
+#include <map>
+
+namespace armnn
{
-// Default Memory Optimizer Strategies
-static const std::vector<std::string> memoryOptimizationStrategies(
+namespace
{
- "ConstantMemoryStrategy",
- "SingleAxisPriorityList"
- "StrategyValidator"
-});
-
-#define CREATE_MEMORY_OPTIMIZER_STRATEGY(strategyName, memoryOptimizerStrategy) \
-{ \
- MemoryOptimizerStrategyFactory memoryOptimizerStrategyFactory; \
- memoryOptimizerStrategy = memoryOptimizerStrategyFactory.CreateMemoryOptimizerStrategy<strategyName>(); \
-} \
-} // anonymous namespace
-namespace armnn
+static std::map<std::string, std::unique_ptr<IMemoryOptimizerStrategyFactory>>& GetStrategyFactories()
{
- std::unique_ptr<IMemoryOptimizerStrategy> GetMemoryOptimizerStrategy(const std::string& strategyName)
+ static std::map<std::string, std::unique_ptr<IMemoryOptimizerStrategyFactory>> strategies;
+
+ if (strategies.size() == 0)
{
- auto doesStrategyExist = std::find(memoryOptimizationStrategies.begin(),
- memoryOptimizationStrategies.end(),
- strategyName) != memoryOptimizationStrategies.end();
- if (doesStrategyExist)
- {
- std::unique_ptr<IMemoryOptimizerStrategy> memoryOptimizerStrategy = nullptr;
- CREATE_MEMORY_OPTIMIZER_STRATEGY(armnn::ConstantMemoryStrategy,
- memoryOptimizerStrategy);
- return memoryOptimizerStrategy;
- }
- return nullptr;
+ strategies["ConstantMemoryStrategy"] = std::make_unique<StrategyFactory<ConstantMemoryStrategy>>();
+ strategies["SingleAxisPriorityList"] = std::make_unique<StrategyFactory<SingleAxisPriorityList>>();
+ strategies["StrategyValidator"] = std::make_unique<StrategyFactory<StrategyValidator>>();
}
+ return strategies;
+}
+
+} // anonymous namespace
+std::unique_ptr<IMemoryOptimizerStrategy> GetMemoryOptimizerStrategy(const std::string& strategyName)
+{
+ const auto& strategyFactoryMap = GetStrategyFactories();
+ auto strategyFactory = strategyFactoryMap.find(strategyName);
+ if (strategyFactory != GetStrategyFactories().end())
+ {
+ return strategyFactory->second->CreateMemoryOptimizerStrategy();
+ }
+ return nullptr;
+}
- const std::vector<std::string>& GetMemoryOptimizerStrategyNames()
+const std::vector<std::string> GetMemoryOptimizerStrategyNames()
+{
+ const auto& strategyFactoryMap = GetStrategyFactories();
+ std::vector<std::string> strategyNames;
+ for (const auto& strategyFactory : strategyFactoryMap)
{
- return memoryOptimizationStrategies;
+ strategyNames.emplace_back(strategyFactory.first);
}
+ return strategyNames;
+}
} // namespace armnn \ No newline at end of file