aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/OptimizerTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/OptimizerTests.cpp')
-rw-r--r--src/armnn/test/OptimizerTests.cpp346
1 files changed, 275 insertions, 71 deletions
diff --git a/src/armnn/test/OptimizerTests.cpp b/src/armnn/test/OptimizerTests.cpp
index d4e2d499d5..19bd58193a 100644
--- a/src/armnn/test/OptimizerTests.cpp
+++ b/src/armnn/test/OptimizerTests.cpp
@@ -139,6 +139,153 @@ void CreateLSTMLayerHelper(Graph &graph, bool CifgEnabled)
Connect(layer, output, lstmTensorInfo3, 3, 0);
}
+
+class MockLayerSupport : public LayerSupportBase
+{
+public:
+ bool IsInputSupported(const TensorInfo& /*input*/,
+ Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
+ {
+ return true;
+ }
+
+ bool IsOutputSupported(const TensorInfo& /*input*/,
+ Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
+ {
+ return true;
+ }
+
+ bool IsActivationSupported(const TensorInfo& /*input0*/,
+ const TensorInfo& /*output*/,
+ const ActivationDescriptor& /*descriptor*/,
+ Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
+ {
+ return true;
+ }
+};
+
+template <typename NamePolicy>
+class MockBackend : public IBackendInternal
+{
+public:
+ MockBackend() :
+ m_BackendCapabilities(NamePolicy::GetIdStatic(), {{"NullCapability", false}}),
+ m_CustomAllocator(false) {};
+ MockBackend(const BackendCapabilities& capabilities) :
+ m_BackendCapabilities(capabilities),
+ m_CustomAllocator(false) {};
+ ~MockBackend() = default;
+
+ static const BackendId& GetIdStatic()
+ {
+ return NamePolicy::GetIdStatic();
+ }
+ const BackendId& GetId() const override
+ {
+ return GetIdStatic();
+ }
+
+ IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override
+ {
+ return nullptr;
+ };
+
+ IBackendInternal::IWorkloadFactoryPtr
+ CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr&) const override
+ {
+ return nullptr;
+ }
+
+ IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
+ {
+ return nullptr;
+ }
+
+ IBackendInternal::Optimizations GetOptimizations() const override
+ {
+ return {};
+ }
+ IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
+ {
+ return std::make_shared<MockLayerSupport>();
+ }
+
+ OptimizationViews OptimizeSubgraphView(const SubgraphView&) const override
+ {
+ return {};
+ };
+
+ BackendCapabilities GetCapabilities() const override
+ {
+ return m_BackendCapabilities;
+ };
+
+ virtual bool UseCustomMemoryAllocator(armnn::Optional<std::string&> errMsg) override
+ {
+ IgnoreUnused(errMsg);
+ m_CustomAllocator = true;
+ return m_CustomAllocator;
+ }
+
+ BackendCapabilities m_BackendCapabilities;
+ bool m_CustomAllocator;
+};
+
+template <typename NamePolicy>
+class NoProtectedModeMockBackend : public IBackendInternal
+{
+public:
+ NoProtectedModeMockBackend() : m_BackendCapabilities(NamePolicy::GetIdStatic(), {{"NullCapability", false}}) {};
+ NoProtectedModeMockBackend(const BackendCapabilities& capabilities) : m_BackendCapabilities(capabilities) {};
+ ~NoProtectedModeMockBackend() = default;
+
+ static const BackendId& GetIdStatic()
+ {
+ return NamePolicy::GetIdStatic();
+ }
+ const BackendId& GetId() const override
+ {
+ return GetIdStatic();
+ }
+
+ IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override
+ {
+ return nullptr;
+ };
+
+ IBackendInternal::IWorkloadFactoryPtr
+ CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr&) const override
+ {
+ return nullptr;
+ }
+
+ IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
+ {
+ return nullptr;
+ }
+
+ IBackendInternal::Optimizations GetOptimizations() const override
+ {
+ return {};
+ }
+ IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
+ {
+ return std::make_shared<MockLayerSupport>();
+ }
+
+ OptimizationViews OptimizeSubgraphView(const SubgraphView&) const override
+ {
+ return {};
+ };
+
+ BackendCapabilities GetCapabilities() const override
+ {
+ return m_BackendCapabilities;
+ };
+
+ BackendCapabilities m_BackendCapabilities;
+};
+
} // namespace
TEST_SUITE("Optimizer")
@@ -543,77 +690,6 @@ TEST_CASE("DetectionPostProcessValidateTensorShapes")
CHECK_NOTHROW(graph.InferTensorInfos());
}
-class MockLayerSupport : public LayerSupportBase
-{
-public:
- bool IsInputSupported(const TensorInfo& /*input*/,
- Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
- {
- return true;
- }
-
- bool IsOutputSupported(const TensorInfo& /*input*/,
- Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
- {
- return true;
- }
-
- bool IsActivationSupported(const TensorInfo& /*input0*/,
- const TensorInfo& /*output*/,
- const ActivationDescriptor& /*descriptor*/,
- Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
- {
- return true;
- }
-};
-
-template <typename NamePolicy>
-class MockBackend : public IBackendInternal
-{
-public:
- MockBackend() = default;
- ~MockBackend() = default;
-
- static const BackendId& GetIdStatic()
- {
- return NamePolicy::GetIdStatic();
- }
- const BackendId& GetId() const override
- {
- return GetIdStatic();
- }
-
- IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override
- {
- return nullptr;
- };
-
- IBackendInternal::IWorkloadFactoryPtr
- CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr&) const override
- {
- return nullptr;
- }
-
- IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
- {
- return nullptr;
- }
-
- IBackendInternal::Optimizations GetOptimizations() const override
- {
- return {};
- }
- IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
- {
- return std::make_shared<MockLayerSupport>();
- }
-
- OptimizationViews OptimizeSubgraphView(const SubgraphView&) const override
- {
- return {};
- };
-};
-
TEST_CASE("BackendCapabilityTest")
{
BackendId backendId = "MockBackend";
@@ -848,4 +924,132 @@ TEST_CASE("OptimizeForExclusiveConnectionsWithoutFuseTest")
&IsLayerOfType<armnn::OutputLayer>,
&IsLayerOfType<armnn::OutputLayer>));
}
+} // Optimizer TestSuite
+
+TEST_SUITE("Runtime")
+{
+// This test really belongs into RuntimeTests.cpp but it requires all sort of MockBackends which are
+// already defined here
+TEST_CASE("RuntimeProtectedModeOption")
+{
+ using namespace armnn;
+
+ struct MockPolicy
+ {
+ static const BackendId& GetIdStatic()
+ {
+ static BackendId id = "MockBackend";
+ return id;
+ }
+ };
+
+ struct ProtectedPolicy
+ {
+ static const BackendId& GetIdStatic()
+ {
+ static BackendId id = "MockBackendProtectedContent";
+ return id;
+ }
+ };
+
+ struct SillyPolicy
+ {
+ static const BackendId& GetIdStatic()
+ {
+ static BackendId id = "SillyMockBackend";
+ return id;
+ }
+ };
+
+ BackendCapabilities mockBackendCapabilities("MockBackend",
+ {
+ {"ProtectedContentAllocation", false}
+ });
+ BackendCapabilities mockProtectedBackendCapabilities("MockBackendProtectedContent",
+ {
+ {"ProtectedContentAllocation", true}
+ });
+
+ auto& backendRegistry = BackendRegistryInstance();
+
+ // clean up from previous test runs
+ std::vector<BackendId> mockBackends = {"MockBackend", "MockBackendProtectedContent", "SillyMockBackend"};
+ for (auto& backend : mockBackends)
+ {
+ backendRegistry.Deregister(backend);
+ }
+
+ // Create a bunch of MockBackends with different capabilities
+ // 1. Doesn't support protected mode even though it knows about this capability
+ backendRegistry.Register("MockBackend", [mockBackendCapabilities]()
+ {
+ return std::make_unique<MockBackend<MockPolicy>>(mockBackendCapabilities);
+ });
+ // 2. Supports protected mode and has it implemented correctly
+ backendRegistry.Register("MockBackendProtectedContent", [mockProtectedBackendCapabilities]()
+ {
+ return std::make_unique<MockBackend<ProtectedPolicy>>(mockProtectedBackendCapabilities);
+ });
+ // 3. Claims to support protected mode but doesn't have the UseCustomMemoryAllocator function implemented
+ backendRegistry.Register("SillyMockBackend", [mockProtectedBackendCapabilities]()
+ {
+ return std::make_unique<NoProtectedModeMockBackend<SillyPolicy>>(mockProtectedBackendCapabilities);
+ });
+
+ // Creates a runtime that is not in protected mode
+ {
+ IRuntime::CreationOptions creationOptions;
+ creationOptions.m_ProtectedMode = false;
+
+ IRuntimePtr run = IRuntime::Create(creationOptions);
+
+ const armnn::BackendIdSet supportedDevices = run->GetDeviceSpec().GetSupportedBackends();
+ // Both MockBackends that are registered should show up in the runtimes supported backends list
+ for (auto& backend : mockBackends)
+ {
+ CHECK(std::find(supportedDevices.cbegin(), supportedDevices.cend(), backend) != supportedDevices.cend());
+ }
+ }
+
+ // If the runtime is in protected mode only backends that support protected content should be added
+ {
+ IRuntime::CreationOptions creationOptions;
+ creationOptions.m_ProtectedMode = true;
+
+ IRuntimePtr run = IRuntime::Create(creationOptions);
+
+ const armnn::BackendIdSet supportedDevices = run->GetDeviceSpec().GetSupportedBackends();
+ // Only the MockBackends that claims support for protected content should show up in the
+ // runtimes supported backends list
+ CHECK(std::find(supportedDevices.cbegin(),
+ supportedDevices.cend(),
+ "MockBackendProtectedContent") != supportedDevices.cend());
+ CHECK(std::find(supportedDevices.cbegin(),
+ supportedDevices.cend(),
+ "MockBackend") == supportedDevices.cend());
+ CHECK(std::find(supportedDevices.cbegin(),
+ supportedDevices.cend(),
+ "SillyMockBackend") == supportedDevices.cend());
+ }
+
+ // If the runtime is in protected mode only backends that support protected content should be added
+ {
+ IRuntime::CreationOptions creationOptions;
+ creationOptions.m_ProtectedMode = true;
+
+ IRuntimePtr run = IRuntime::Create(creationOptions);
+
+ const armnn::BackendIdSet supportedDevices = run->GetDeviceSpec().GetSupportedBackends();
+ // Only the MockBackend that claims support for protected content should show up in the
+ // runtimes supported backends list
+ CHECK(std::find(supportedDevices.cbegin(),
+ supportedDevices.cend(),
+ "MockBackendProtectedContent") != supportedDevices.cend());
+
+ CHECK(std::find(supportedDevices.cbegin(),
+ supportedDevices.cend(),
+ "MockBackend") == supportedDevices.cend());
+ }
+
+}
}