diff options
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 8 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefDebugWorkload.cpp | 3 | ||||
-rw-r--r-- | src/backends/reference/workloads/RefDebugWorkload.hpp | 3 |
3 files changed, 12 insertions, 2 deletions
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index df458c1a6d..086f8eea8d 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -77,6 +77,10 @@ bool IsQAsymmU8(const WorkloadInfo& info) { return IsDataType<DataType::QAsymmU8>(info); } +bool IsBoolean(const WorkloadInfo& info) +{ + return IsDataType<DataType::Boolean>(info); +} RefWorkloadFactory::RefWorkloadFactory(const std::shared_ptr<RefMemoryManager>& memoryManager) : m_MemoryManager(memoryManager) @@ -271,6 +275,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateWorkload(LayerType type, { return std::make_unique<RefDebugSigned64Workload>(*debugQueueDescriptor, info); } + if (IsBoolean(info)) + { + return std::make_unique<RefDebugBooleanWorkload>(*debugQueueDescriptor, info); + } return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymmU8Workload>(*debugQueueDescriptor, info); } case LayerType::DepthToSpace: diff --git a/src/backends/reference/workloads/RefDebugWorkload.cpp b/src/backends/reference/workloads/RefDebugWorkload.cpp index 94eed4ff4f..23df873063 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.cpp +++ b/src/backends/reference/workloads/RefDebugWorkload.cpp @@ -1,5 +1,5 @@ // -// Copyright © 2018-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2018-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -66,5 +66,6 @@ template class RefDebugWorkload<DataType::QSymmS16>; template class RefDebugWorkload<DataType::QSymmS8>; template class RefDebugWorkload<DataType::Signed32>; template class RefDebugWorkload<DataType::Signed64>; +template class RefDebugWorkload<DataType::Boolean>; } // namespace armnn diff --git a/src/backends/reference/workloads/RefDebugWorkload.hpp b/src/backends/reference/workloads/RefDebugWorkload.hpp index 4c99990ec4..16457820c4 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.hpp +++ b/src/backends/reference/workloads/RefDebugWorkload.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2018-2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2018-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -48,5 +48,6 @@ using RefDebugQSymmS16Workload = RefDebugWorkload<DataType::QSymmS16>; using RefDebugQSymmS8Workload = RefDebugWorkload<DataType::QSymmS8>; using RefDebugSigned32Workload = RefDebugWorkload<DataType::Signed32>; using RefDebugSigned64Workload = RefDebugWorkload<DataType::Signed64>; +using RefDebugBooleanWorkload = RefDebugWorkload<DataType::Boolean>; } // namespace armnn |