diff options
author | Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> | 2019-11-12 16:15:11 +0000 |
---|---|---|
committer | Matteo Martincigh <matteo.martincigh@arm.com> | 2019-11-15 11:57:50 +0000 |
commit | db1a2834b4e9d74ed538943634212eccbd4a789b (patch) | |
tree | 10162ed219468bc043ed66e40ea657fb867b0ee1 | |
parent | 60538ada2b90704abcf6473144639103d80287a5 (diff) | |
download | armnn-db1a2834b4e9d74ed538943634212eccbd4a789b.tar.gz |
Add FP16 support to DebugWorkload
Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ia879f2d84a1b977474ee0dafa976f2aab32bd3ae
-rw-r--r-- | src/backends/reference/RefLayerSupport.cpp | 3 | ||||
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 5 | ||||
-rw-r--r-- | src/backends/reference/workloads/Debug.cpp | 10 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefDebugWorkload.cpp | 1 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefDebugWorkload.hpp | 1 |
5 files changed, 19 insertions, 1 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index ef0cc8c363..5a84d8ac78 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -495,8 +495,9 @@ bool RefLayerSupport::IsDebugSupported(const TensorInfo& input, { bool supported = true; - std::array<DataType,3> supportedTypes = + std::array<DataType, 4> supportedTypes = { + DataType::Float16, DataType::Float32, DataType::QuantisedAsymm8, DataType::QuantisedSymm16 diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index c2cb51abf3..7fd93435e7 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -172,10 +172,15 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvolution2d(const Convolu std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor, const WorkloadInfo& info) const { + if (IsFloat16(info)) + { + return std::make_unique<RefDebugFloat16Workload>(descriptor, info); + } if (IsQSymm16(info)) { return std::make_unique<RefDebugQSymm16Workload>(descriptor, info); } + return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymm8Workload>(descriptor, info); } diff --git a/src/backends/reference/workloads/Debug.cpp b/src/backends/reference/workloads/Debug.cpp index 09a0dfc03b..b7d0911ef3 100644 --- a/src/backends/reference/workloads/Debug.cpp +++ b/src/backends/reference/workloads/Debug.cpp @@ -2,8 +2,11 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #include "Debug.hpp" +#include <Half.hpp> + #include <boost/numeric/conversion/cast.hpp> #include <algorithm> @@ -85,6 +88,12 @@ void Debug(const TensorInfo& inputInfo, std::cout << " }" << std::endl; } +template void Debug<Half>(const TensorInfo& inputInfo, + const Half* inputData, + LayerGuid guid, + const std::string& layerName, + unsigned int slotIndex); + template void Debug<float>(const TensorInfo& inputInfo, const float* inputData, LayerGuid guid, @@ -102,4 +111,5 @@ template void Debug<int16_t>(const TensorInfo& inputInfo, LayerGuid guid, const std::string& layerName, unsigned int slotIndex); + } // namespace armnn diff --git a/src/backends/reference/workloads/RefDebugWorkload.cpp b/src/backends/reference/workloads/RefDebugWorkload.cpp index 325817b19f..2a3883f8f7 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.cpp +++ b/src/backends/reference/workloads/RefDebugWorkload.cpp @@ -44,6 +44,7 @@ void RefDebugWorkload<DataType>::RegisterDebugCallback(const DebugCallbackFuncti m_Callback = func; } +template class RefDebugWorkload<DataType::Float16>; template class RefDebugWorkload<DataType::Float32>; template class RefDebugWorkload<DataType::QuantisedAsymm8>; template class RefDebugWorkload<DataType::QuantisedSymm16>; diff --git a/src/backends/reference/workloads/RefDebugWorkload.hpp b/src/backends/reference/workloads/RefDebugWorkload.hpp index 6a1fceba0a..0964515b2c 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.hpp +++ b/src/backends/reference/workloads/RefDebugWorkload.hpp @@ -37,6 +37,7 @@ private: DebugCallbackFunction m_Callback; }; +using RefDebugFloat16Workload = RefDebugWorkload<DataType::Float16>; using RefDebugFloat32Workload = RefDebugWorkload<DataType::Float32>; using RefDebugQAsymm8Workload = RefDebugWorkload<DataType::QuantisedAsymm8>; using RefDebugQSymm16Workload = RefDebugWorkload<DataType::QuantisedSymm16>; |