aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp35
1 files changed, 29 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 09d7c2d568..0bafda257c 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -39,10 +39,11 @@ const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> typ
} // anonymous namespace
-bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
- const IConnectableLayer& connectableLayer,
- Optional<DataType> dataType,
- std::string& outReasonIfUnsupported)
+bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
+ const IConnectableLayer& connectableLayer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported,
+ const ModelOptions& modelOptions)
{
Optional<std::string&> reason = outReasonIfUnsupported;
bool result;
@@ -61,7 +62,7 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
auto backendFactory = backendRegistry.GetFactory(backendId);
auto backendObject = backendFactory();
- auto layerSupportObject = backendObject->GetLayerSupport();
+ auto layerSupportObject = backendObject->GetLayerSupport(modelOptions);
switch(layer.GetType())
{
@@ -1212,12 +1213,34 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
return result;
}
+bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
+ const IConnectableLayer& connectableLayer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported)
+{
+ return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
+}
+
bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
Optional<DataType> dataType,
std::string& outReasonIfUnsupported)
{
auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
- return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
+ return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
+}
+
+// TODO merge with defaulted modelOptions above
+bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported,
+ const ModelOptions& modelOptions)
+{
+ auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
+ return IsLayerConfigurationSupported(layer->GetBackendId(),
+ connectableLayer,
+ dataType,
+ outReasonIfUnsupported,
+ modelOptions);
}
// Default Implementations