aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadFactory.cpp
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2020-09-10 13:37:32 +0100
committerRyan O'Shea <ryan.oshea2@arm.com>2020-09-10 18:04:17 +0000
commit045f6be924240a560293a3a7a0ecae49bcf0d1fa (patch)
tree3193fb35288ad8011cdfb9082d82085f48b6792b /src/backends/backendsCommon/WorkloadFactory.cpp
parent08f4016b8ae8ee836fc813abcbc7db826924f3ec (diff)
downloadarmnn-045f6be924240a560293a3a7a0ecae49bcf0d1fa.tar.gz
IVGCVSW-5156 Introduce ModelOptions to OptimizedNetwork
* Introduced ModelOptions to IBackendInternal * Introduced ModelOptions to Network * Added FastMathEnabled parameter to Conv2d Validate function in CL and NEON * Added Optimizer tests Signed-off-by: Ryan OShea <Ryan.OShea2@arm.com> Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: Ib54c1e82cb3d89a52756ed499cf91b6a7fdb2063
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadFactory.cpp35
1 files changed, 29 insertions, 6 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index 09d7c2d568..0bafda257c 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -39,10 +39,11 @@ const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> typ
} // anonymous namespace
-bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
- const IConnectableLayer& connectableLayer,
- Optional<DataType> dataType,
- std::string& outReasonIfUnsupported)
+bool IWorkloadFactory::IsLayerConfigurationSupported(const BackendId& backendId,
+ const IConnectableLayer& connectableLayer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported,
+ const ModelOptions& modelOptions)
{
Optional<std::string&> reason = outReasonIfUnsupported;
bool result;
@@ -61,7 +62,7 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
auto backendFactory = backendRegistry.GetFactory(backendId);
auto backendObject = backendFactory();
- auto layerSupportObject = backendObject->GetLayerSupport();
+ auto layerSupportObject = backendObject->GetLayerSupport(modelOptions);
switch(layer.GetType())
{
@@ -1212,12 +1213,34 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
return result;
}
+bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
+ const IConnectableLayer& connectableLayer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported)
+{
+ return IsLayerConfigurationSupported(backendId, connectableLayer, dataType, outReasonIfUnsupported);
+}
+
bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
Optional<DataType> dataType,
std::string& outReasonIfUnsupported)
{
auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
- return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
+ return IsLayerConfigurationSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
+}
+
+// TODO merge with defaulted modelOptions above
+bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
+ Optional<DataType> dataType,
+ std::string& outReasonIfUnsupported,
+ const ModelOptions& modelOptions)
+{
+ auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
+ return IsLayerConfigurationSupported(layer->GetBackendId(),
+ connectableLayer,
+ dataType,
+ outReasonIfUnsupported,
+ modelOptions);
}
// Default Implementations