diff options
Diffstat (limited to 'src/dynamic/sample')
-rw-r--r-- | src/dynamic/sample/SampleDynamicLayerSupport.hpp | 8 | ||||
-rw-r--r-- | src/dynamic/sample/SampleDynamicWorkloadFactory.cpp | 30 | ||||
-rw-r--r-- | src/dynamic/sample/SampleDynamicWorkloadFactory.hpp | 10 |
3 files changed, 41 insertions, 7 deletions
diff --git a/src/dynamic/sample/SampleDynamicLayerSupport.hpp b/src/dynamic/sample/SampleDynamicLayerSupport.hpp index 2f0744aab7..3881ad760e 100644 --- a/src/dynamic/sample/SampleDynamicLayerSupport.hpp +++ b/src/dynamic/sample/SampleDynamicLayerSupport.hpp @@ -16,20 +16,20 @@ public: bool IsAdditionSupported(const armnn::TensorInfo& input0, const armnn::TensorInfo& input1, const armnn::TensorInfo& output, - armnn::Optional<std::string&> reasonIfUnsupported = armnn::EmptyOptional()) const override; + armnn::Optional<std::string&> reasonIfUnsupported = armnn::EmptyOptional()) const; bool IsInputSupported(const armnn::TensorInfo& input, - armnn::Optional<std::string&> reasonIfUnsupported) const override; + armnn::Optional<std::string&> reasonIfUnsupported) const; bool IsOutputSupported(const armnn::TensorInfo& output, - armnn::Optional<std::string&> reasonIfUnsupported) const override; + armnn::Optional<std::string&> reasonIfUnsupported) const; bool IsLayerSupported(const armnn::LayerType& type, const std::vector<armnn::TensorInfo>& infos, const armnn::BaseDescriptor& descriptor, const armnn::Optional<armnn::LstmInputParamsInfo>& lstmParamsInfo, const armnn::Optional<armnn::QuantizedLstmInputParamsInfo>& quantizedLstmParamsInfo, - armnn::Optional<std::string&> reasonIfUnsupported = armnn::EmptyOptional()) const override; + armnn::Optional<std::string&> reasonIfUnsupported = armnn::EmptyOptional()) const; }; } // namespace sdb diff --git a/src/dynamic/sample/SampleDynamicWorkloadFactory.cpp b/src/dynamic/sample/SampleDynamicWorkloadFactory.cpp index 8796716c98..d4be0fcb3e 100644 --- a/src/dynamic/sample/SampleDynamicWorkloadFactory.cpp +++ b/src/dynamic/sample/SampleDynamicWorkloadFactory.cpp @@ -5,6 +5,7 @@ #include <armnn/backends/MemCopyWorkload.hpp> #include <armnn/backends/TensorHandle.hpp> +#include <armnn/utility/PolymorphicDowncast.hpp> #include "SampleDynamicAdditionWorkload.hpp" #include "SampleDynamicBackend.hpp" @@ -77,4 +78,33 @@ std::unique_ptr<armnn::IWorkload> SampleDynamicWorkloadFactory::CreateOutput( return std::make_unique<armnn::CopyMemGenericWorkload>(descriptor, info); } +std::unique_ptr<armnn::IWorkload> SampleDynamicWorkloadFactory::CreateWorkload( + armnn::LayerType type, + const armnn::QueueDescriptor& descriptor, + const armnn::WorkloadInfo& info) const +{ + using namespace armnn; + using namespace sdb; + switch(type) + { + case LayerType::Addition: + { + auto additionQueueDescriptor = PolymorphicDowncast<const AdditionQueueDescriptor*>(&descriptor); + return std::make_unique<SampleDynamicAdditionWorkload>(*additionQueueDescriptor, info); + } + case LayerType::Input: + { + auto inputQueueDescriptor = PolymorphicDowncast<const InputQueueDescriptor*>(&descriptor); + return std::make_unique<CopyMemGenericWorkload>(*inputQueueDescriptor, info); + } + case LayerType::Output: + { + auto outputQueueDescriptor = PolymorphicDowncast<const OutputQueueDescriptor*>(&descriptor); + return std::make_unique<CopyMemGenericWorkload>(*outputQueueDescriptor, info); + } + default: + return nullptr; + } +} + } // namespace sdb diff --git a/src/dynamic/sample/SampleDynamicWorkloadFactory.hpp b/src/dynamic/sample/SampleDynamicWorkloadFactory.hpp index a5a31e5d6e..8cd36c5518 100644 --- a/src/dynamic/sample/SampleDynamicWorkloadFactory.hpp +++ b/src/dynamic/sample/SampleDynamicWorkloadFactory.hpp @@ -52,14 +52,18 @@ public: std::unique_ptr<armnn::IWorkload> CreateAddition( const armnn::AdditionQueueDescriptor& descriptor, - const armnn::WorkloadInfo& info) const override; + const armnn::WorkloadInfo& info) const; std::unique_ptr<armnn::IWorkload> CreateInput(const armnn::InputQueueDescriptor& descriptor, - const armnn::WorkloadInfo& info) const override; + const armnn::WorkloadInfo& info) const; std::unique_ptr<armnn::IWorkload> CreateOutput(const armnn::OutputQueueDescriptor& descriptor, - const armnn::WorkloadInfo& info) const override; + const armnn::WorkloadInfo& info) const; + + std::unique_ptr<armnn::IWorkload> CreateWorkload(armnn::LayerType type, + const armnn::QueueDescriptor& descriptor, + const armnn::WorkloadInfo& info) const override; private: mutable std::shared_ptr<SampleMemoryManager> m_MemoryManager; |