diff options
author | Finn Williams <finn.williams@arm.com> | 2021-11-08 15:22:45 +0000 |
---|---|---|
committer | David Monahan <david.monahan@arm.com> | 2021-11-08 18:11:37 +0000 |
commit | b03e8ffd1a895b680dca1ce90c049fa7a9a40cb0 (patch) | |
tree | 03a97f1cb5d4d1d88324c059f782d840b7b0522d /src/backends/backendsCommon | |
parent | f8fb46df602c72d62defe82e3283b33f9eeccdd3 (diff) | |
download | armnn-b03e8ffd1a895b680dca1ce90c049fa7a9a40cb0.tar.gz |
Fix MemoryOptimizerStrategyLibrary search
Signed-off-by: Finn Williams <finn.williams@arm.com>
Change-Id: I4ca8d9196abd0e116d420a36c780e39edbca0eb3
Diffstat (limited to 'src/backends/backendsCommon')
4 files changed, 70 insertions, 36 deletions
diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp index aff0995266..7b04f442d2 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp @@ -12,17 +12,19 @@ namespace armnn { -class MemoryOptimizerStrategyFactory +struct IMemoryOptimizerStrategyFactory { -public: - MemoryOptimizerStrategyFactory() {} + virtual ~IMemoryOptimizerStrategyFactory() = default; + virtual std::unique_ptr<IMemoryOptimizerStrategy> CreateMemoryOptimizerStrategy() = 0; +}; - template <typename T> - std::unique_ptr<IMemoryOptimizerStrategy> CreateMemoryOptimizerStrategy() +template <typename T> +struct StrategyFactory : public IMemoryOptimizerStrategyFactory +{ + std::unique_ptr<IMemoryOptimizerStrategy> CreateMemoryOptimizerStrategy() override { return std::make_unique<T>(); } - }; } // namespace armnn
\ No newline at end of file diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp index 5fa151560b..9814405ff7 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp @@ -6,49 +6,52 @@ #include <armnn/backends/IMemoryOptimizerStrategy.hpp> #include "MemoryOptimizerStrategyFactory.hpp" -#include <algorithm> #include "strategies/ConstantMemoryStrategy.hpp" #include "strategies/StrategyValidator.hpp" #include "strategies/SingleAxisPriorityList.hpp" -namespace +#include <map> + +namespace armnn { -// Default Memory Optimizer Strategies -static const std::vector<std::string> memoryOptimizationStrategies( +namespace { - "ConstantMemoryStrategy", - "SingleAxisPriorityList" - "StrategyValidator" -}); - -#define CREATE_MEMORY_OPTIMIZER_STRATEGY(strategyName, memoryOptimizerStrategy) \ -{ \ - MemoryOptimizerStrategyFactory memoryOptimizerStrategyFactory; \ - memoryOptimizerStrategy = memoryOptimizerStrategyFactory.CreateMemoryOptimizerStrategy<strategyName>(); \ -} \ -} // anonymous namespace -namespace armnn +static std::map<std::string, std::unique_ptr<IMemoryOptimizerStrategyFactory>>& GetStrategyFactories() { - std::unique_ptr<IMemoryOptimizerStrategy> GetMemoryOptimizerStrategy(const std::string& strategyName) + static std::map<std::string, std::unique_ptr<IMemoryOptimizerStrategyFactory>> strategies; + + if (strategies.size() == 0) { - auto doesStrategyExist = std::find(memoryOptimizationStrategies.begin(), - memoryOptimizationStrategies.end(), - strategyName) != memoryOptimizationStrategies.end(); - if (doesStrategyExist) - { - std::unique_ptr<IMemoryOptimizerStrategy> memoryOptimizerStrategy = nullptr; - CREATE_MEMORY_OPTIMIZER_STRATEGY(armnn::ConstantMemoryStrategy, - memoryOptimizerStrategy); - return memoryOptimizerStrategy; - } - return nullptr; + strategies["ConstantMemoryStrategy"] = std::make_unique<StrategyFactory<ConstantMemoryStrategy>>(); + strategies["SingleAxisPriorityList"] = std::make_unique<StrategyFactory<SingleAxisPriorityList>>(); + strategies["StrategyValidator"] = std::make_unique<StrategyFactory<StrategyValidator>>(); } + return strategies; +} + +} // anonymous namespace +std::unique_ptr<IMemoryOptimizerStrategy> GetMemoryOptimizerStrategy(const std::string& strategyName) +{ + const auto& strategyFactoryMap = GetStrategyFactories(); + auto strategyFactory = strategyFactoryMap.find(strategyName); + if (strategyFactory != GetStrategyFactories().end()) + { + return strategyFactory->second->CreateMemoryOptimizerStrategy(); + } + return nullptr; +} - const std::vector<std::string>& GetMemoryOptimizerStrategyNames() +const std::vector<std::string> GetMemoryOptimizerStrategyNames() +{ + const auto& strategyFactoryMap = GetStrategyFactories(); + std::vector<std::string> strategyNames; + for (const auto& strategyFactory : strategyFactoryMap) { - return memoryOptimizationStrategies; + strategyNames.emplace_back(strategyFactory.first); } + return strategyNames; +} } // namespace armnn
\ No newline at end of file diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt index 3068b609f6..a82f718862 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND armnnMemoryOptimizationStrategiesUnitTests_sources ConstMemoryStrategyTests.cpp ValidatorStrategyTests.cpp SingleAxisPriorityListTests.cpp + MemoryOptimizerStrategyLibraryTests.cpp ) add_library(armnnMemoryOptimizationStrategiesUnitTests OBJECT ${armnnMemoryOptimizationStrategiesUnitTests_sources}) diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp new file mode 100644 index 0000000000..482bc7d0bf --- /dev/null +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp @@ -0,0 +1,28 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp> + +#include <doctest/doctest.h> + +using namespace armnn; + +TEST_SUITE("StrategyLibraryTestSuite") +{ + +TEST_CASE("StrategyLibraryTest") +{ + std::vector<std::string> strategyNames = GetMemoryOptimizerStrategyNames(); + CHECK(strategyNames.size() != 0); + for (const auto& strategyName: strategyNames) + { + auto strategy = GetMemoryOptimizerStrategy(strategyName); + CHECK(strategy); + CHECK(strategy->GetName() == strategyName); + } +} + +} + |