aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-11-12 16:15:11 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-11-15 11:57:50 +0000
commitdb1a2834b4e9d74ed538943634212eccbd4a789b (patch)
tree10162ed219468bc043ed66e40ea657fb867b0ee1
parent60538ada2b90704abcf6473144639103d80287a5 (diff)
downloadarmnn-db1a2834b4e9d74ed538943634212eccbd4a789b.tar.gz
Add FP16 support to DebugWorkload
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: Ia879f2d84a1b977474ee0dafa976f2aab32bd3ae
-rw-r--r--src/backends/reference/RefLayerSupport.cpp3
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp5
-rw-r--r--src/backends/reference/workloads/Debug.cpp10
-rw-r--r--src/backends/reference/workloads/RefDebugWorkload.cpp1
-rw-r--r--src/backends/reference/workloads/RefDebugWorkload.hpp1
5 files changed, 19 insertions, 1 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index ef0cc8c363..5a84d8ac78 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -495,8 +495,9 @@ bool RefLayerSupport::IsDebugSupported(const TensorInfo& input,
{
bool supported = true;
- std::array<DataType,3> supportedTypes =
+ std::array<DataType, 4> supportedTypes =
{
+ DataType::Float16,
DataType::Float32,
DataType::QuantisedAsymm8,
DataType::QuantisedSymm16
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index c2cb51abf3..7fd93435e7 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -172,10 +172,15 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvolution2d(const Convolu
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
+ if (IsFloat16(info))
+ {
+ return std::make_unique<RefDebugFloat16Workload>(descriptor, info);
+ }
if (IsQSymm16(info))
{
return std::make_unique<RefDebugQSymm16Workload>(descriptor, info);
}
+
return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymm8Workload>(descriptor, info);
}
diff --git a/src/backends/reference/workloads/Debug.cpp b/src/backends/reference/workloads/Debug.cpp
index 09a0dfc03b..b7d0911ef3 100644
--- a/src/backends/reference/workloads/Debug.cpp
+++ b/src/backends/reference/workloads/Debug.cpp
@@ -2,8 +2,11 @@
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
+
#include "Debug.hpp"
+#include <Half.hpp>
+
#include <boost/numeric/conversion/cast.hpp>
#include <algorithm>
@@ -85,6 +88,12 @@ void Debug(const TensorInfo& inputInfo,
std::cout << " }" << std::endl;
}
+template void Debug<Half>(const TensorInfo& inputInfo,
+ const Half* inputData,
+ LayerGuid guid,
+ const std::string& layerName,
+ unsigned int slotIndex);
+
template void Debug<float>(const TensorInfo& inputInfo,
const float* inputData,
LayerGuid guid,
@@ -102,4 +111,5 @@ template void Debug<int16_t>(const TensorInfo& inputInfo,
LayerGuid guid,
const std::string& layerName,
unsigned int slotIndex);
+
} // namespace armnn
diff --git a/src/backends/reference/workloads/RefDebugWorkload.cpp b/src/backends/reference/workloads/RefDebugWorkload.cpp
index 325817b19f..2a3883f8f7 100644
--- a/src/backends/reference/workloads/RefDebugWorkload.cpp
+++ b/src/backends/reference/workloads/RefDebugWorkload.cpp
@@ -44,6 +44,7 @@ void RefDebugWorkload<DataType>::RegisterDebugCallback(const DebugCallbackFuncti
m_Callback = func;
}
+template class RefDebugWorkload<DataType::Float16>;
template class RefDebugWorkload<DataType::Float32>;
template class RefDebugWorkload<DataType::QuantisedAsymm8>;
template class RefDebugWorkload<DataType::QuantisedSymm16>;
diff --git a/src/backends/reference/workloads/RefDebugWorkload.hpp b/src/backends/reference/workloads/RefDebugWorkload.hpp
index 6a1fceba0a..0964515b2c 100644
--- a/src/backends/reference/workloads/RefDebugWorkload.hpp
+++ b/src/backends/reference/workloads/RefDebugWorkload.hpp
@@ -37,6 +37,7 @@ private:
DebugCallbackFunction m_Callback;
};
+using RefDebugFloat16Workload = RefDebugWorkload<DataType::Float16>;
using RefDebugFloat32Workload = RefDebugWorkload<DataType::Float32>;
using RefDebugQAsymm8Workload = RefDebugWorkload<DataType::QuantisedAsymm8>;
using RefDebugQSymm16Workload = RefDebugWorkload<DataType::QuantisedSymm16>;