diff options
Diffstat (limited to 'src/backends/cl/ClWorkloadFactory.cpp')
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 58e17df5b8..4acfa570f2 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -4,6 +4,7 @@ // #include "ClWorkloadFactory.hpp" #include "ClBackendId.hpp" +#include "ClBackendModelContext.hpp" #include <Layer.hpp> @@ -42,6 +43,14 @@ bool ClWorkloadFactory::IsLayerSupported(const Layer& layer, return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported); } +bool ClWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer, + Optional<DataType> dataType, + std::string& outReasonIfUnsupported, + const ModelOptions& modelOptions) +{ + return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions); +} + const BackendId& ClWorkloadFactory::GetBackendId() const { return s_Id; @@ -78,7 +87,13 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::MakeWorkload(const QueueDescriptor } ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager) - : m_MemoryManager(memoryManager) + : m_MemoryManager(memoryManager), m_ModelContextPtr(IBackendInternal::IBackendSpecificModelContextPtr{}) +{ +} + +ClWorkloadFactory::ClWorkloadFactory(const std::shared_ptr<ClMemoryManager>& memoryManager, + const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr) + : m_MemoryManager(memoryManager), m_ModelContextPtr(modelContextPtr) { } @@ -205,7 +220,22 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvertFp32ToFp16( std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<ClConvolution2dWorkload>(descriptor, info, m_MemoryManager->GetIntraLayerManager()); + bool isFastMathEnabled = false; + if (m_ModelContextPtr) + { + if (m_ModelContextPtr.get() != nullptr) + { + auto modelOptions = dynamic_cast<ClBackendModelContext*>(m_ModelContextPtr.get()); + if (modelOptions) + { + isFastMathEnabled = modelOptions->IsFastMathEnabled(); + } + } + } + return MakeWorkload<ClConvolution2dWorkload>(descriptor, + info, + m_MemoryManager->GetIntraLayerManager(), + isFastMathEnabled); } std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor, |