aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNattapat Chaimanowong <nattapat.chaimanowong@arm.com>2019-03-22 14:01:46 +0000
committernattapat.chaimanowong <nattapat.chaimanowong@arm.com>2019-03-22 14:38:31 +0000
commit6e9482013f41725ccca0767c0c5db9b29f77d981 (patch)
tree7b33e2cb17009c3103b5731f8d8ffa2b37484347
parent2a304ede65afb3426dccbd742ac3f5b42e8fb04a (diff)
downloadarmnn-6e9482013f41725ccca0767c0c5db9b29f77d981.tar.gz
IVGCVSW-2865 Extend IRuntime to add a new method RegisterDebugCallback(...)
* Made changes to LoadedNetwork and IWorkload to pass on the registered callback function Change-Id: I6ea10f2a299d6de8bf681c8ff36d3fbed1d6d887 Signed-off-by: Nattapat Chaimanowong <nattapat.chaimanowong@arm.com>
-rw-r--r--include/armnn/IRuntime.hpp5
-rw-r--r--include/armnn/Types.hpp5
-rw-r--r--src/armnn/LoadedNetwork.cpp8
-rw-r--r--src/armnn/LoadedNetwork.hpp2
-rw-r--r--src/armnn/Runtime.cpp6
-rw-r--r--src/armnn/Runtime.hpp5
-rw-r--r--src/backends/backendsCommon/Workload.hpp2
7 files changed, 33 insertions, 0 deletions
diff --git a/include/armnn/IRuntime.hpp b/include/armnn/IRuntime.hpp
index b977afe5e7..44864ce86d 100644
--- a/include/armnn/IRuntime.hpp
+++ b/include/armnn/IRuntime.hpp
@@ -83,6 +83,11 @@ public:
/// @return A pointer to the requested profiler, or nullptr if not found.
virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const = 0;
+ /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
+ /// @param networkId The id of the network to register the callback.
+ /// @param func callback function to pass to the debug layer.
+ virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) = 0;
+
protected:
~IRuntime() {}
};
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index 36e3c5b52a..693a050586 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -180,4 +180,9 @@ private:
/// Define LayerGuid type.
using LayerGuid = unsigned int;
+class ITensorHandle;
+
+/// Define the type of callback for the debug layer to call
+using DebugCallbackFunction = std::function<void(LayerGuid, unsigned int, ITensorHandle*)>;
+
} // namespace armnn
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 4221d36036..9263f1a6e9 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -485,4 +485,12 @@ bool LoadedNetwork::Execute()
return success;
}
+void LoadedNetwork::RegisterDebugCallback(const DebugCallbackFunction& func)
+{
+ for (auto&& workloadPtr: m_WorkloadQueue)
+ {
+ workloadPtr.get()->RegisterDebugCallback(func);
+ }
+}
+
}
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index 9c0fe0b108..75af4a4e28 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -49,6 +49,8 @@ public:
void FreeWorkingMemory();
+ void RegisterDebugCallback(const DebugCallbackFunction& func);
+
private:
void AllocateWorkingMemory();
diff --git a/src/armnn/Runtime.cpp b/src/armnn/Runtime.cpp
index 09be92c709..f8b2462f96 100644
--- a/src/armnn/Runtime.cpp
+++ b/src/armnn/Runtime.cpp
@@ -231,4 +231,10 @@ Status Runtime::EnqueueWorkload(NetworkId networkId,
return loadedNetwork->EnqueueWorkload(inputTensors, outputTensors);
}
+void Runtime::RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func)
+{
+ LoadedNetwork* loadedNetwork = GetLoadedNetworkPtr(networkId);
+ loadedNetwork->RegisterDebugCallback(func);
+}
+
}
diff --git a/src/armnn/Runtime.hpp b/src/armnn/Runtime.hpp
index a3f4a3930b..10383bc970 100644
--- a/src/armnn/Runtime.hpp
+++ b/src/armnn/Runtime.hpp
@@ -59,6 +59,11 @@ public:
/// @return A pointer to the requested profiler, or nullptr if not found.
virtual const std::shared_ptr<IProfiler> GetProfiler(NetworkId networkId) const override;
+ /// Registers a callback function to debug layers performing custom computations on intermediate tensors.
+ /// @param networkId The id of the network to register the callback.
+ /// @param func callback function to pass to the debug layer.
+ virtual void RegisterDebugCallback(NetworkId networkId, const DebugCallbackFunction& func) override;
+
/// Creates a runtime for workload execution.
/// May throw a ClRuntimeUnavailableException if @a defaultComputeDevice requires a CL runtime but
/// it cannot be setup for some reason.
diff --git a/src/backends/backendsCommon/Workload.hpp b/src/backends/backendsCommon/Workload.hpp
index 7fb26f8b56..447ec1b4d6 100644
--- a/src/backends/backendsCommon/Workload.hpp
+++ b/src/backends/backendsCommon/Workload.hpp
@@ -21,6 +21,8 @@ public:
virtual ~IWorkload() {}
virtual void Execute() const = 0;
+
+ virtual void RegisterDebugCallback(const DebugCallbackFunction& func) {}
};
// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template