aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Runtime.cpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
committertelsoa01 <telmo.soares@arm.com>2018-03-09 14:13:49 +0000
commit4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch)
treec9a70aeb2887006160c1b3d265c27efadb7bdbae /src/armnn/Runtime.cpp
downloadarmnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'src/armnn/Runtime.cpp')
-rw-r--r--src/armnn/Runtime.cpp118
1 files changed, 118 insertions, 0 deletions
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
new file mode 100644
index 0000000000..ea6d19bd31
--- /dev/null
+++ b/src/armnn/Runtime.cpp
@@ -0,0 +1,118 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+#include "Runtime.hpp"
+
+#include "armnn/Version.hpp"
+
+#ifdef ARMCOMPUTECL_ENABLED
+#include <arm_compute/core/CL/OpenCL.h>
+#include <arm_compute/core/CL/CLKernelLibrary.h>
+#endif
+
+#include <boost/log/trivial.hpp>
+#include <boost/polymorphic_cast.hpp>
+
+using namespace armnn;
+using namespace std;
+
+namespace armnn
+{
+
+IRuntime* IRuntime::CreateRaw(const CreationOptions& options)
+{
+ return new Runtime(options);
+}
+
+IRuntimePtr IRuntime::Create(const CreationOptions& options)
+{
+ return IRuntimePtr(CreateRaw(options), &IRuntime::Destroy);
+}
+
+void IRuntime::Destroy(IRuntime* runtime)
+{
+ delete boost::polymorphic_downcast<Runtime*>(runtime);
+}
+
+int Runtime::GenerateNetworkId()
+{
+ return m_NetworkIdCounter++;
+}
+
+Status Runtime::LoadNetwork(NetworkId& networkIdOut, IOptimizedNetworkPtr inNetwork)
+{
+ IOptimizedNetwork* rawNetwork = inNetwork.release();
+ unique_ptr<LoadedNetwork> loadedNetwork = LoadedNetwork::MakeLoadedNetwork(
+ std::unique_ptr<OptimizedNetwork>(boost::polymorphic_downcast<OptimizedNetwork*>(rawNetwork)),
+ m_WorkloadFactories);
+
+ if (!loadedNetwork)
+ {
+ return Status::Failure;
+ }
+
+ networkIdOut = GenerateNetworkId();
+
+ // store the network
+ m_LoadedNetworks[networkIdOut] = std::move(loadedNetwork);
+
+ return Status::Success;
+
+}
+
+Status Runtime::UnloadNetwork(NetworkId networkId)
+{
+ if (m_LoadedNetworks.erase(networkId) == 0)
+ {
+ BOOST_LOG_TRIVIAL(warning) << "WARNING: Runtime::UnloadNetwork(): " << networkId << " not found!";
+ return Status::Failure;
+ }
+#ifdef ARMCOMPUTECL_ENABLED
+ arm_compute::CLKernelLibrary::get().clear_programs_cache();
+#endif
+ BOOST_LOG_TRIVIAL(debug) << "Runtime::UnloadNetwork(): Unloaded network with ID: " << networkId;
+ return Status::Success;
+}
+
+Runtime::Runtime(const CreationOptions& options)
+: m_NetworkIdCounter(0)
+{
+ BOOST_LOG_TRIVIAL(info) << "ArmNN v" << ARMNN_VERSION << "\n";
+ BOOST_LOG_TRIVIAL(info) << "Using compute device: " << options.m_DefaultComputeDevice << "\n";
+ m_DeviceSpec.DefaultComputeDevice = options.m_DefaultComputeDevice;
+
+ // If useCpuRefAsFallback is false, the reference workload factory will be prevented from creating
+ // operation workloads, unless the default compute device is precisely the reference backend.
+ m_WorkloadFactories.m_CpuRef = make_shared<RefWorkloadFactory>(
+ options.m_DefaultComputeDevice == Compute::CpuRef ? true : options.m_UseCpuRefAsFallback);
+ m_WorkloadFactories.m_CpuAcc = make_shared<NeonWorkloadFactory>();
+ m_WorkloadFactories.m_GpuAcc = make_shared<ClWorkloadFactory>();
+
+ if (options.m_DefaultComputeDevice == Compute::GpuAcc)
+ {
+ m_WorkloadFactories.m_GpuAcc.get()->LoadOpenClRuntime(options.m_ClTunedParameters);
+ }
+}
+
+TensorInfo Runtime::GetInputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
+{
+ LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
+ return net->GetInputTensorInfo(layerId);
+}
+
+TensorInfo Runtime::GetOutputTensorInfo(NetworkId networkId, LayerBindingId layerId) const
+{
+ const LoadedNetwork* net = m_LoadedNetworks.at(networkId).get();
+ return net->GetOutputTensorInfo(layerId);
+}
+
+Status Runtime::EnqueueWorkload(NetworkId networkId,
+ const InputTensors& inputTensors,
+ const OutputTensors& outputTensors)
+{
+ LoadedNetwork* loadedNetwork = m_LoadedNetworks.at(networkId).get();
+ return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors, m_WorkloadFactories);
+}
+
+}