diff options
Diffstat (limited to 'src/backends/reference/RefWorkloadFactory.cpp')
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index e7a9c19fc7..792bd7d3ad 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -24,7 +24,8 @@ template <typename F32Workload, typename U8Workload, typename QueueDescriptorTyp std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const { - return MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload>(descriptor, info); + return MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload, NullWorkload> + (descriptor, info); } template <DataType ArmnnType> @@ -54,6 +55,11 @@ bool IsQSymm16(const WorkloadInfo& info) return IsDataType<DataType::QSymmS16>(info); } +bool IsQSymm8(const WorkloadInfo& info) +{ + return IsDataType<DataType::QSymmS8>(info); +} + RefWorkloadFactory::RefWorkloadFactory(const std::shared_ptr<RefMemoryManager>& memoryManager) : m_MemoryManager(memoryManager) { @@ -185,6 +191,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescr { return std::make_unique<RefDebugQSymm16Workload>(descriptor, info); } + if (IsQSymm8(info)) + { + return std::make_unique<RefDebugQSymm8Workload>(descriptor, info); + } if (IsDataType<DataType::Signed32>(info)) { return std::make_unique<RefDebugSigned32Workload>(descriptor, info); @@ -419,7 +429,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueD return std::make_unique<RefPermuteQSymm16Workload>(descriptor, info); } return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteQAsymm8Workload, - NullWorkload, NullWorkload>(descriptor, info); + NullWorkload, NullWorkload, NullWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor, |