diff options
Diffstat (limited to 'src/profiling/test/ProfilingTests.hpp')
-rw-r--r-- | src/profiling/test/ProfilingTests.hpp | 200 |
1 files changed, 200 insertions, 0 deletions
diff --git a/src/profiling/test/ProfilingTests.hpp b/src/profiling/test/ProfilingTests.hpp new file mode 100644 index 0000000000..3e6cf63efe --- /dev/null +++ b/src/profiling/test/ProfilingTests.hpp @@ -0,0 +1,200 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "SendCounterPacketTests.hpp" + +#include <CommandHandlerFunctor.hpp> +#include <IProfilingConnection.hpp> +#include <IProfilingConnectionFactory.hpp> +#include <Logging.hpp> +#include <ProfilingService.hpp> + +#include <boost/test/unit_test.hpp> + +#include <chrono> +#include <iostream> +#include <thread> + +namespace armnn +{ + +namespace profiling +{ + +struct LogLevelSwapper +{ +public: + LogLevelSwapper(armnn::LogSeverity severity) + { + // Set the new log level + armnn::ConfigureLogging(true, true, severity); + } + ~LogLevelSwapper() + { + // The default log level for unit tests is "Fatal" + armnn::ConfigureLogging(true, true, armnn::LogSeverity::Fatal); + } +}; + +struct CoutRedirect +{ +public: + CoutRedirect(std::streambuf* newStreamBuffer) + : m_Old(std::cout.rdbuf(newStreamBuffer)) {} + ~CoutRedirect() { std::cout.rdbuf(m_Old); } + +private: + std::streambuf* m_Old; +}; + +struct StreamRedirector +{ +public: + StreamRedirector(std::ostream& stream, std::streambuf* newStreamBuffer) + : m_Stream(stream) + , m_BackupBuffer(m_Stream.rdbuf(newStreamBuffer)) + {} + ~StreamRedirector() { m_Stream.rdbuf(m_BackupBuffer); } + +private: + std::ostream& m_Stream; + std::streambuf* m_BackupBuffer; +}; + +class TestProfilingConnectionBase : public IProfilingConnection +{ +public: + TestProfilingConnectionBase() = default; + ~TestProfilingConnectionBase() = default; + + bool IsOpen() const override { return true; } + + void Close() override {} + + bool WritePacket(const unsigned char* buffer, uint32_t length) override { return false; } + + Packet ReadPacket(uint32_t timeout) override + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + // Return connection acknowledged packet + std::unique_ptr<char[]> packetData; + return Packet(65536, 0, packetData); + } +}; + +class TestProfilingConnectionTimeoutError : public TestProfilingConnectionBase +{ +public: + TestProfilingConnectionTimeoutError() + : m_ReadRequests(0) + {} + + Packet ReadPacket(uint32_t timeout) override + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + if (m_ReadRequests < 3) + { + m_ReadRequests++; + throw armnn::TimeoutException("Simulate a timeout error\n"); + } + + // Return connection acknowledged packet after three timeouts + std::unique_ptr<char[]> packetData; + return Packet(65536, 0, packetData); + } + +private: + int m_ReadRequests; +}; + +class TestProfilingConnectionArmnnError : public TestProfilingConnectionBase +{ +public: + Packet ReadPacket(uint32_t timeout) override + { + std::this_thread::sleep_for(std::chrono::milliseconds(timeout)); + + throw armnn::Exception("Simulate a non-timeout error"); + } +}; + +class TestFunctorA : public CommandHandlerFunctor +{ +public: + using CommandHandlerFunctor::CommandHandlerFunctor; + + int GetCount() { return m_Count; } + + void operator()(const Packet& packet) override + { + m_Count++; + } + +private: + int m_Count = 0; +}; + +class TestFunctorB : public TestFunctorA +{ + using TestFunctorA::TestFunctorA; +}; + +class TestFunctorC : public TestFunctorA +{ + using TestFunctorA::TestFunctorA; +}; + +class MockProfilingConnectionFactory : public IProfilingConnectionFactory +{ +public: + IProfilingConnectionPtr GetProfilingConnection(const ExternalProfilingOptions& options) const override + { + return std::make_unique<MockProfilingConnection>(); + } +}; + +class SwapProfilingConnectionFactoryHelper : public ProfilingService +{ +public: + using MockProfilingConnectionFactoryPtr = std::unique_ptr<MockProfilingConnectionFactory>; + + SwapProfilingConnectionFactoryHelper() + : ProfilingService() + , m_MockProfilingConnectionFactory(new MockProfilingConnectionFactory()) + , m_BackupProfilingConnectionFactory(nullptr) + { + BOOST_CHECK(m_MockProfilingConnectionFactory); + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_MockProfilingConnectionFactory.get(), + m_BackupProfilingConnectionFactory); + BOOST_CHECK(m_BackupProfilingConnectionFactory); + } + ~SwapProfilingConnectionFactoryHelper() + { + BOOST_CHECK(m_BackupProfilingConnectionFactory); + IProfilingConnectionFactory* temp = nullptr; + SwapProfilingConnectionFactory(ProfilingService::Instance(), + m_BackupProfilingConnectionFactory, + temp); + } + + MockProfilingConnection* GetMockProfilingConnection() + { + IProfilingConnection* profilingConnection = GetProfilingConnection(ProfilingService::Instance()); + return boost::polymorphic_downcast<MockProfilingConnection*>(profilingConnection); + } + +private: + MockProfilingConnectionFactoryPtr m_MockProfilingConnectionFactory; + IProfilingConnectionFactory* m_BackupProfilingConnectionFactory; +}; + +} // namespace profiling + +} // namespace armnn |