diff options
Diffstat (limited to 'include/armnnTestUtils/MockBackend.hpp')
-rw-r--r-- | include/armnnTestUtils/MockBackend.hpp | 229 |
1 files changed, 220 insertions, 9 deletions
diff --git a/include/armnnTestUtils/MockBackend.hpp b/include/armnnTestUtils/MockBackend.hpp index 8bc41b3f3f..425062ac28 100644 --- a/include/armnnTestUtils/MockBackend.hpp +++ b/include/armnnTestUtils/MockBackend.hpp @@ -4,9 +4,12 @@ // #pragma once +#include <atomic> + #include <armnn/backends/IBackendInternal.hpp> #include <armnn/backends/MemCopyWorkload.hpp> #include <armnnTestUtils/MockTensorHandle.hpp> +#include <backendsCommon/LayerSupportBase.hpp> namespace armnn { @@ -26,16 +29,20 @@ public: return GetIdStatic(); } IBackendInternal::IWorkloadFactoryPtr - CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override - { - IgnoreUnused(memoryManager); - return nullptr; - } + CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override; - IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override - { - return nullptr; - }; + IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override; + + IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override; + + IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; + IBackendInternal::IBackendProfilingContextPtr + CreateBackendProfilingContext(const IRuntime::CreationOptions& creationOptions, + IBackendProfilingPtr& backendProfiling) override; + + OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const override; + + std::unique_ptr<ICustomAllocator> GetDefaultAllocator() const override; }; class MockWorkloadFactory : public IWorkloadFactory @@ -112,4 +119,208 @@ private: mutable std::shared_ptr<MockMemoryManager> m_MemoryManager; }; +class MockBackendInitialiser +{ +public: + MockBackendInitialiser(); + ~MockBackendInitialiser(); +}; + +class MockBackendProfilingContext : public arm::pipe::IBackendProfilingContext +{ +public: + MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling) + : m_BackendProfiling(std::move(backendProfiling)) + , m_CapturePeriod(0) + , m_IsTimelineEnabled(true) + {} + + ~MockBackendProfilingContext() = default; + + IBackendInternal::IBackendProfilingPtr& GetBackendProfiling() + { + return m_BackendProfiling; + } + + uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId) + { + std::unique_ptr<arm::pipe::IRegisterBackendCounters> counterRegistrar = + m_BackendProfiling->GetCounterRegistrationInterface(static_cast<uint16_t>(currentMaxGlobalCounterId)); + + std::string categoryName("MockCounters"); + counterRegistrar->RegisterCategory(categoryName); + + counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter"); + + counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two", + "Another notional counter"); + + std::string units("microseconds"); + uint16_t nextMaxGlobalCounterId = + counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter", + "A dummy four core counter", units, 4); + return nextMaxGlobalCounterId; + } + + Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds) + { + if (capturePeriod == 0 || counterIds.size() == 0) + { + m_ActiveCounters.clear(); + } + else if (capturePeriod == 15939u) + { + return armnn::Optional<std::string>("ActivateCounters example test error"); + } + m_CapturePeriod = capturePeriod; + m_ActiveCounters = counterIds; + return armnn::Optional<std::string>(); + } + + std::vector<arm::pipe::Timestamp> ReportCounterValues() + { + std::vector<arm::pipe::CounterValue> counterValues; + + for (auto counterId : m_ActiveCounters) + { + counterValues.emplace_back(arm::pipe::CounterValue{ counterId, counterId + 1u }); + } + + uint64_t timestamp = m_CapturePeriod; + return { arm::pipe::Timestamp{ timestamp, counterValues } }; + } + + bool EnableProfiling(bool) + { + auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket(); + sendTimelinePacket->SendTimelineEntityBinaryPacket(4256); + sendTimelinePacket->Commit(); + return true; + } + + bool EnableTimelineReporting(bool isEnabled) + { + m_IsTimelineEnabled = isEnabled; + return isEnabled; + } + + bool TimelineReportingEnabled() + { + return m_IsTimelineEnabled; + } + +private: + IBackendInternal::IBackendProfilingPtr m_BackendProfiling; + uint32_t m_CapturePeriod; + std::vector<uint16_t> m_ActiveCounters; + std::atomic<bool> m_IsTimelineEnabled; +}; + +class MockBackendProfilingService +{ +public: + // Getter for the singleton instance + static MockBackendProfilingService& Instance() + { + static MockBackendProfilingService instance; + return instance; + } + + MockBackendProfilingContext* GetContext() + { + return m_sharedContext.get(); + } + + void SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared) + { + m_sharedContext = shared; + } + +private: + std::shared_ptr<MockBackendProfilingContext> m_sharedContext; +}; + +class MockLayerSupport : public LayerSupportBase +{ +public: + bool IsLayerSupported(const LayerType& type, + const std::vector<TensorInfo>& infos, + const BaseDescriptor& descriptor, + const Optional<LstmInputParamsInfo>& /*lstmParamsInfo*/, + const Optional<QuantizedLstmInputParamsInfo>& /*quantizedLstmParamsInfo*/, + Optional<std::string&> reasonIfUnsupported) const override + { + switch(type) + { + case LayerType::Input: + return IsInputSupported(infos[0], reasonIfUnsupported); + case LayerType::Output: + return IsOutputSupported(infos[0], reasonIfUnsupported); + case LayerType::Addition: + return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported); + case LayerType::Convolution2d: + { + if (infos.size() != 4) + { + throw InvalidArgumentException("Invalid number of TransposeConvolution2d " + "TensorInfos. TensorInfos should be of format: " + "{input, output, weights, biases}."); + } + + auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor)); + if (infos[3] == TensorInfo()) + { + return IsConvolution2dSupported(infos[0], + infos[1], + desc, + infos[2], + EmptyOptional(), + reasonIfUnsupported); + } + else + { + return IsConvolution2dSupported(infos[0], + infos[1], + desc, + infos[2], + infos[3], + reasonIfUnsupported); + } + } + default: + return false; + } + } + + bool IsInputSupported(const TensorInfo& /*input*/, + Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override + { + return true; + } + + bool IsOutputSupported(const TensorInfo& /*input*/, + Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override + { + return true; + } + + bool IsAdditionSupported(const TensorInfo& /*input0*/, + const TensorInfo& /*input1*/, + const TensorInfo& /*output*/, + Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override + { + return true; + } + + bool IsConvolution2dSupported(const TensorInfo& /*input*/, + const TensorInfo& /*output*/, + const Convolution2dDescriptor& /*descriptor*/, + const TensorInfo& /*weights*/, + const Optional<TensorInfo>& /*biases*/, + Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override + { + return true; + } +}; + } // namespace armnn |