Diffstat (limited to 'src/armnn/Network.cpp')
-rw-r--r--  src/armnn/Network.cpp | 32 ++++++++++++++++++++++++++++----
1 file changed, 28 insertions(+), 4 deletions(-)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index 339da0d1b8..a3655509fb 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -1658,7 +1658,7 @@ OptimizationResult SelectTensorHandleStrategy(Graph& optGraph,
     return result;
 }
 
-IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
+IOptimizedNetworkPtr Optimize(const Graph& inGraph,
                               const std::vector<BackendId>& backendPreferences,
                               const IDeviceSpec& deviceSpec,
                               const OptimizerOptions& options,
@@ -1667,7 +1667,7 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
     ARMNN_LOG(debug) << options.ToString();
 
     // Enable profiling
-    auto profiler = inNetwork.pNetworkImpl->GetGraph().GetProfiler();
+    auto profiler = inGraph.GetProfiler();
     ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
     profiler->EnableProfiling(options.m_ProfilingEnabled);
 
@@ -1683,9 +1683,9 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
     }
 
     // Ensure TensorInfo is set on all output slots of ConstantLayers in the graph
-    inNetwork.pNetworkImpl->GetGraph().VerifyConstantLayerSetTensorInfo();
+    inGraph.VerifyConstantLayerSetTensorInfo();
 
-    std::unique_ptr<Graph> graph = std::make_unique<Graph>(inNetwork.pNetworkImpl->GetGraph());
+    std::unique_ptr<Graph> graph = std::make_unique<Graph>(inGraph);
 
     auto optNet = IOptimizedNetworkPtr(new IOptimizedNetwork(std::move(graph), options.m_ModelOptions),
                                        &IOptimizedNetwork::Destroy);
@@ -1827,6 +1827,20 @@ IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
     }
     return optNet;
 }
+
+IOptimizedNetworkPtr Optimize(const INetwork& inNetwork,
+                              const std::vector<BackendId>& backendPreferences,
+                              const IDeviceSpec& deviceSpec,
+                              const OptimizerOptions& options,
+                              Optional<std::vector<std::string>&> messages)
+{
+    return Optimize(inNetwork.pNetworkImpl->GetGraph(),
+                    backendPreferences,
+                    deviceSpec,
+                    options,
+                    messages);
+}
+
 bool NetworkImpl::GetShapeInferenceMethod()
 {
     if (m_NetworkOptions.size() > 0 && m_NetworkOptions[0].GetBackendId().Get() == "ShapeInferenceMethod")
@@ -2000,6 +2014,16 @@ IConnectableLayer* NetworkImpl::AddConvolution2dLayerImpl(const Convolution2dDes
     return layer;
 }
 
+IConnectableLayer* NetworkImpl::AddConvertFp16ToFp32Layer(const char* name)
+{
+    return m_Graph->AddLayer<ConvertFp16ToFp32Layer>(name);
+}
+
+IConnectableLayer* NetworkImpl::AddConvertFp32ToFp16Layer(const char* name)
+{
+    return m_Graph->AddLayer<ConvertFp32ToFp16Layer>(name);
+}
+
 IConnectableLayer* NetworkImpl::AddConvolution2dLayer(const Convolution2dDescriptor& convolution2dDescriptor,
                                                       const ConstTensor& weights,
                                                       const Optional<ConstTensor>& biases,
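For context, a minimal caller-side sketch of how the public entry point is exercised after this change. It is not part of the commit: main(), the identity network, the {1, 4} tensor shape, and the CpuRef backend choice are illustrative assumptions; only the Optimize() signature shown in the diff and the standard Arm NN creation calls are taken as given. The INetwork overload added above simply forwards to the new Graph-based overload, so code like this compiles and behaves the same before and after the refactor. Enabling OptimizerOptions::m_ReduceFp32ToFp16 is the path that leads the optimizer to insert the ConvertFp32ToFp16/ConvertFp16ToFp32 layers whose NetworkImpl helpers appear at the end of the diff.

#include <armnn/ArmNN.hpp>
#include <string>
#include <vector>

int main()
{
    using namespace armnn;

    IRuntimePtr runtime = IRuntime::Create(IRuntime::CreationOptions());

    // Trivial identity network: one input connected straight to one output.
    INetworkPtr network = INetwork::Create();
    IConnectableLayer* input  = network->AddInputLayer(0);
    IConnectableLayer* output = network->AddOutputLayer(0);
    input->GetOutputSlot(0).Connect(output->GetInputSlot(0));
    input->GetOutputSlot(0).SetTensorInfo(TensorInfo({1, 4}, DataType::Float32));

    OptimizerOptions optOptions;
    optOptions.m_ReduceFp32ToFp16 = true; // exercises the FP16 conversion layers

    std::vector<BackendId> backends = {Compute::CpuRef};
    std::vector<std::string> messages; // optimizer warnings/errors accumulate here

    // Resolves to the INetwork overload, which forwards to Optimize(Graph&, ...).
    IOptimizedNetworkPtr optNet = Optimize(*network,
                                           backends,
                                           runtime->GetDeviceSpec(),
                                           optOptions,
                                           Optional<std::vector<std::string>&>(messages));
    return optNet ? 0 : 1;
}

Splitting Optimize() this way appears intended to let internal callers that already hold a Graph bypass the INetwork wrapper, while the existing public overload keeps its signature and behaviour.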