author     Cathal Corbett <cathal.corbett@arm.com>   2022-03-04 11:36:39 +0000
committer  Jim Flynn <jim.flynn@arm.com>             2022-03-08 21:26:31 +0000
commit     3464ba127b83cd36d65cdc7ee9f5dd7b3715a18e
tree       16a6592067a297dd8f45fec9edf4ff02ebe67112
parent     fbe4594f9700ca13177fb1a36b82ede539f31e2f
IVGCVSW-6772 Eliminate armnn/src/backends/backendsCommon/test/MockBackend.hpp
Signed-off-by: Cathal Corbett <cathal.corbett@arm.com>
Change-Id: Ie99fe9786eb5e30585f437d0c6362c73688148db
Diffstat (limited to 'include')
-rw-r--r--  include/armnnTestUtils/MockBackend.hpp | 229
1 file changed, 220 insertions(+), 9 deletions(-)
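The patch publishes the mock backend through the public armnnTestUtils headers, so tests consume it with an ordinary include rather than reaching into armnn/src/backends. Below is a minimal sketch of that usage, assuming the RAII MockBackendInitialiser declared in the new header registers the mock backend on construction and removes it on destruction; the test body is illustrative only and not part of this patch:

    #include <armnnTestUtils/MockBackend.hpp>

    void ExampleMockBackendTest()
    {
        // Assumption: constructing the initialiser registers the mock backend
        // with Arm NN's backend registry for the lifetime of this scope.
        armnn::MockBackendInitialiser initialiser;

        // ... build, optimize and run a test network against the mock backend ...
    }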
diff --git a/include/armnnTestUtils/MockBackend.hpp b/include/armnnTestUtils/MockBackend.hpp
index 8bc41b3f3f..425062ac28 100644
--- a/include/armnnTestUtils/MockBackend.hpp
+++ b/include/armnnTestUtils/MockBackend.hpp
@@ -4,9 +4,12 @@
//
#pragma once
+#include <atomic>
+
#include <armnn/backends/IBackendInternal.hpp>
#include <armnn/backends/MemCopyWorkload.hpp>
#include <armnnTestUtils/MockTensorHandle.hpp>
+#include <backendsCommon/LayerSupportBase.hpp>
namespace armnn
{
@@ -26,16 +29,20 @@ public:
        return GetIdStatic();
    }
    IBackendInternal::IWorkloadFactoryPtr
-    CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override
-    {
-        IgnoreUnused(memoryManager);
-        return nullptr;
-    }
+        CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override;
-    IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
-    {
-        return nullptr;
-    };
+    IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override;
+
+    IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override;
+
+    IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
+    IBackendInternal::IBackendProfilingContextPtr
+        CreateBackendProfilingContext(const IRuntime::CreationOptions& creationOptions,
+                                      IBackendProfilingPtr& backendProfiling) override;
+
+    OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const override;
+
+    std::unique_ptr<ICustomAllocator> GetDefaultAllocator() const override;
};
class MockWorkloadFactory : public IWorkloadFactory
@@ -112,4 +119,208 @@ private:
    mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
};
+class MockBackendInitialiser
+{
+public:
+    MockBackendInitialiser();
+    ~MockBackendInitialiser();
+};
+
+class MockBackendProfilingContext : public arm::pipe::IBackendProfilingContext
+{
+public:
+    MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling)
+        : m_BackendProfiling(std::move(backendProfiling))
+        , m_CapturePeriod(0)
+        , m_IsTimelineEnabled(true)
+    {}
+
+    ~MockBackendProfilingContext() = default;
+
+    IBackendInternal::IBackendProfilingPtr& GetBackendProfiling()
+    {
+        return m_BackendProfiling;
+    }
+
+    uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId)
+    {
+        std::unique_ptr<arm::pipe::IRegisterBackendCounters> counterRegistrar =
+            m_BackendProfiling->GetCounterRegistrationInterface(static_cast<uint16_t>(currentMaxGlobalCounterId));
+
+        std::string categoryName("MockCounters");
+        counterRegistrar->RegisterCategory(categoryName);
+
+        counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter");
+
+        counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two",
+                                          "Another notional counter");
+
+        std::string units("microseconds");
+        uint16_t nextMaxGlobalCounterId =
+            counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter",
+                                              "A dummy four core counter", units, 4);
+        return nextMaxGlobalCounterId;
+    }
+
+    Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
+    {
+        if (capturePeriod == 0 || counterIds.size() == 0)
+        {
+            m_ActiveCounters.clear();
+        }
+        else if (capturePeriod == 15939u)
+        {
+            return armnn::Optional<std::string>("ActivateCounters example test error");
+        }
+        m_CapturePeriod  = capturePeriod;
+        m_ActiveCounters = counterIds;
+        return armnn::Optional<std::string>();
+    }
+
+    std::vector<arm::pipe::Timestamp> ReportCounterValues()
+    {
+        std::vector<arm::pipe::CounterValue> counterValues;
+
+        for (auto counterId : m_ActiveCounters)
+        {
+            counterValues.emplace_back(arm::pipe::CounterValue{ counterId, counterId + 1u });
+        }
+
+        uint64_t timestamp = m_CapturePeriod;
+        return { arm::pipe::Timestamp{ timestamp, counterValues } };
+    }
+
+    bool EnableProfiling(bool)
+    {
+        auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket();
+        sendTimelinePacket->SendTimelineEntityBinaryPacket(4256);
+        sendTimelinePacket->Commit();
+        return true;
+    }
+
+    bool EnableTimelineReporting(bool isEnabled)
+    {
+        m_IsTimelineEnabled = isEnabled;
+        return isEnabled;
+    }
+
+    bool TimelineReportingEnabled()
+    {
+        return m_IsTimelineEnabled;
+    }
+
+private:
+    IBackendInternal::IBackendProfilingPtr m_BackendProfiling;
+    uint32_t m_CapturePeriod;
+    std::vector<uint16_t> m_ActiveCounters;
+    std::atomic<bool> m_IsTimelineEnabled;
+};
+
+class MockBackendProfilingService
+{
+public:
+    // Getter for the singleton instance
+    static MockBackendProfilingService& Instance()
+    {
+        static MockBackendProfilingService instance;
+        return instance;
+    }
+
+    MockBackendProfilingContext* GetContext()
+    {
+        return m_sharedContext.get();
+    }
+
+    void SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)
+    {
+        m_sharedContext = shared;
+    }
+
+private:
+    std::shared_ptr<MockBackendProfilingContext> m_sharedContext;
+};
+
+class MockLayerSupport : public LayerSupportBase
+{
+public:
+    bool IsLayerSupported(const LayerType& type,
+                          const std::vector<TensorInfo>& infos,
+                          const BaseDescriptor& descriptor,
+                          const Optional<LstmInputParamsInfo>& /*lstmParamsInfo*/,
+                          const Optional<QuantizedLstmInputParamsInfo>& /*quantizedLstmParamsInfo*/,
+                          Optional<std::string&> reasonIfUnsupported) const override
+    {
+        switch(type)
+        {
+            case LayerType::Input:
+                return IsInputSupported(infos[0], reasonIfUnsupported);
+            case LayerType::Output:
+                return IsOutputSupported(infos[0], reasonIfUnsupported);
+            case LayerType::Addition:
+                return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
+            case LayerType::Convolution2d:
+            {
+                if (infos.size() != 4)
+                {
+                    throw InvalidArgumentException("Invalid number of Convolution2d "
+                                                   "TensorInfos. TensorInfos should be of format: "
+                                                   "{input, output, weights, biases}.");
+                }
+
+                auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
+                if (infos[3] == TensorInfo())
+                {
+                    return IsConvolution2dSupported(infos[0],
+                                                    infos[1],
+                                                    desc,
+                                                    infos[2],
+                                                    EmptyOptional(),
+                                                    reasonIfUnsupported);
+                }
+                else
+                {
+                    return IsConvolution2dSupported(infos[0],
+                                                    infos[1],
+                                                    desc,
+                                                    infos[2],
+                                                    infos[3],
+                                                    reasonIfUnsupported);
+                }
+            }
+            default:
+                return false;
+        }
+    }
+
+    bool IsInputSupported(const TensorInfo& /*input*/,
+                          Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
+    {
+        return true;
+    }
+
+    bool IsOutputSupported(const TensorInfo& /*input*/,
+                           Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
+    {
+        return true;
+    }
+
+    bool IsAdditionSupported(const TensorInfo& /*input0*/,
+                             const TensorInfo& /*input1*/,
+                             const TensorInfo& /*output*/,
+                             Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
+    {
+        return true;
+    }
+
+    bool IsConvolution2dSupported(const TensorInfo& /*input*/,
+                                  const TensorInfo& /*output*/,
+                                  const Convolution2dDescriptor& /*descriptor*/,
+                                  const TensorInfo& /*weights*/,
+                                  const Optional<TensorInfo>& /*biases*/,
+                                  Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
+    {
+        return true;
+    }
+};
+
} // namespace armnn
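
For reference, a short sketch (not part of the patch) of how the MockLayerSupport declared above might be queried from a unit test. The tensor shape, data type, the main() wrapper and the exit-code handling are illustrative assumptions; only the IsLayerSupported signature comes from the header:

    #include <string>

    #include <armnn/Descriptors.hpp>
    #include <armnn/Optional.hpp>
    #include <armnn/Tensor.hpp>
    #include <armnn/Types.hpp>
    #include <armnnTestUtils/MockBackend.hpp>

    int main()
    {
        armnn::MockLayerSupport layerSupport;

        // The Addition case expects TensorInfos as {input0, input1, output};
        // the mock reports the layer as supported and leaves 'reason' untouched.
        armnn::TensorInfo info({ 1, 4 }, armnn::DataType::Float32);
        std::string reason;
        bool supported = layerSupport.IsLayerSupported(armnn::LayerType::Addition,
                                                       { info, info, info },
                                                       armnn::BaseDescriptor(),
                                                       armnn::EmptyOptional(),
                                                       armnn::EmptyOptional(),
                                                       armnn::Optional<std::string&>(reason));
        return supported ? 0 : 1;
    }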