diff options
Diffstat (limited to 'src/backends/reference/workloads/RefDebugWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefDebugWorkload.cpp | 21 |
1 files changed, 19 insertions, 2 deletions
diff --git a/src/backends/reference/workloads/RefDebugWorkload.cpp b/src/backends/reference/workloads/RefDebugWorkload.cpp index 412d399adc..3da925913a 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.cpp +++ b/src/backends/reference/workloads/RefDebugWorkload.cpp @@ -9,6 +9,8 @@ #include <TypeUtils.hpp> +#include <cstring> + namespace armnn { @@ -20,12 +22,27 @@ void RefDebugWorkload<DataType>::Execute() const ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, GetName() + "_Execute"); const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]); - const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); const T* inputData = GetInputTensorData<T>(0, m_Data); T* outputData = GetOutputTensorData<T>(0, m_Data); - Debug(inputInfo, outputInfo, inputData, outputData, m_Data.m_Guid, m_Data.m_LayerName, m_Data.m_SlotIndex); + if (m_Callback) + { + m_Callback(m_Data.m_Guid, m_Data.m_SlotIndex, m_Data.m_Inputs[0]); + } + else + { + Debug(inputInfo, inputData, m_Data.m_Guid, m_Data.m_LayerName, m_Data.m_SlotIndex); + } + + std::memcpy(outputData, inputData, inputInfo.GetNumElements()*sizeof(T)); + +} + +template<armnn::DataType DataType> +void RefDebugWorkload<DataType>::RegisterDebugCallback(const DebugCallbackFunction& func) +{ + m_Callback = func; } template class RefDebugWorkload<DataType::Float32>; |