From b03e8ffd1a895b680dca1ce90c049fa7a9a40cb0 Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Mon, 8 Nov 2021 15:22:45 +0000 Subject: Fix MemoryOptimizerStrategyLibrary search Signed-off-by: Finn Williams Change-Id: I4ca8d9196abd0e116d420a36c780e39edbca0eb3 --- .../MemoryOptimizerStrategyFactory.hpp | 14 ++--- .../MemoryOptimizerStrategyLibrary.hpp | 63 +++++++++++----------- .../test/CMakeLists.txt | 1 + .../test/MemoryOptimizerStrategyLibraryTests.cpp | 28 ++++++++++ 4 files changed, 70 insertions(+), 36 deletions(-) create mode 100644 src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp index aff0995266..7b04f442d2 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyFactory.hpp @@ -12,17 +12,19 @@ namespace armnn { -class MemoryOptimizerStrategyFactory +struct IMemoryOptimizerStrategyFactory { -public: - MemoryOptimizerStrategyFactory() {} + virtual ~IMemoryOptimizerStrategyFactory() = default; + virtual std::unique_ptr CreateMemoryOptimizerStrategy() = 0; +}; - template - std::unique_ptr CreateMemoryOptimizerStrategy() +template +struct StrategyFactory : public IMemoryOptimizerStrategyFactory +{ + std::unique_ptr CreateMemoryOptimizerStrategy() override { return std::make_unique(); } - }; } // namespace armnn \ No newline at end of file diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp index 5fa151560b..9814405ff7 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/MemoryOptimizerStrategyLibrary.hpp @@ -6,49 +6,52 @@ #include #include "MemoryOptimizerStrategyFactory.hpp" -#include #include "strategies/ConstantMemoryStrategy.hpp" #include "strategies/StrategyValidator.hpp" #include "strategies/SingleAxisPriorityList.hpp" -namespace +#include + +namespace armnn { -// Default Memory Optimizer Strategies -static const std::vector memoryOptimizationStrategies( +namespace { - "ConstantMemoryStrategy", - "SingleAxisPriorityList" - "StrategyValidator" -}); - -#define CREATE_MEMORY_OPTIMIZER_STRATEGY(strategyName, memoryOptimizerStrategy) \ -{ \ - MemoryOptimizerStrategyFactory memoryOptimizerStrategyFactory; \ - memoryOptimizerStrategy = memoryOptimizerStrategyFactory.CreateMemoryOptimizerStrategy(); \ -} \ -} // anonymous namespace -namespace armnn +static std::map>& GetStrategyFactories() { - std::unique_ptr GetMemoryOptimizerStrategy(const std::string& strategyName) + static std::map> strategies; + + if (strategies.size() == 0) { - auto doesStrategyExist = std::find(memoryOptimizationStrategies.begin(), - memoryOptimizationStrategies.end(), - strategyName) != memoryOptimizationStrategies.end(); - if (doesStrategyExist) - { - std::unique_ptr memoryOptimizerStrategy = nullptr; - CREATE_MEMORY_OPTIMIZER_STRATEGY(armnn::ConstantMemoryStrategy, - memoryOptimizerStrategy); - return memoryOptimizerStrategy; - } - return nullptr; + strategies["ConstantMemoryStrategy"] = std::make_unique>(); + strategies["SingleAxisPriorityList"] = std::make_unique>(); + strategies["StrategyValidator"] = std::make_unique>(); } + return strategies; +} + +} // anonymous namespace +std::unique_ptr GetMemoryOptimizerStrategy(const std::string& strategyName) +{ + const auto& strategyFactoryMap = GetStrategyFactories(); + auto strategyFactory = strategyFactoryMap.find(strategyName); + if (strategyFactory != GetStrategyFactories().end()) + { + return strategyFactory->second->CreateMemoryOptimizerStrategy(); + } + return nullptr; +} - const std::vector& GetMemoryOptimizerStrategyNames() +const std::vector GetMemoryOptimizerStrategyNames() +{ + const auto& strategyFactoryMap = GetStrategyFactories(); + std::vector strategyNames; + for (const auto& strategyFactory : strategyFactoryMap) { - return memoryOptimizationStrategies; + strategyNames.emplace_back(strategyFactory.first); } + return strategyNames; +} } // namespace armnn \ No newline at end of file diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt index 3068b609f6..a82f718862 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/CMakeLists.txt @@ -7,6 +7,7 @@ list(APPEND armnnMemoryOptimizationStrategiesUnitTests_sources ConstMemoryStrategyTests.cpp ValidatorStrategyTests.cpp SingleAxisPriorityListTests.cpp + MemoryOptimizerStrategyLibraryTests.cpp ) add_library(armnnMemoryOptimizationStrategiesUnitTests OBJECT ${armnnMemoryOptimizationStrategiesUnitTests_sources}) diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp new file mode 100644 index 0000000000..482bc7d0bf --- /dev/null +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/MemoryOptimizerStrategyLibraryTests.cpp @@ -0,0 +1,28 @@ +// +// Copyright © 2021 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include + +#include + +using namespace armnn; + +TEST_SUITE("StrategyLibraryTestSuite") +{ + +TEST_CASE("StrategyLibraryTest") +{ + std::vector strategyNames = GetMemoryOptimizerStrategyNames(); + CHECK(strategyNames.size() != 0); + for (const auto& strategyName: strategyNames) + { + auto strategy = GetMemoryOptimizerStrategy(strategyName); + CHECK(strategy); + CHECK(strategy->GetName() == strategyName); + } +} + +} + -- cgit v1.2.1