diff options
Diffstat (limited to 'src/backends/reference/RefWorkloadFactory.cpp')
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 52d71df936..1d82421490 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -50,6 +50,11 @@ bool IsSigned32(const WorkloadInfo& info) return IsDataType<DataType::Signed32>(info); } +bool IsBFloat16(const WorkloadInfo& info) +{ + return IsDataType<DataType::BFloat16>(info); +} + bool IsFloat16(const WorkloadInfo& info) { return IsDataType<DataType::Float16>(info); @@ -441,6 +446,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescripto { return std::make_unique<RefPadFloat16Workload>(descriptor, info); } + else if (IsBFloat16(info)) + { + return std::make_unique<RefPadBFloat16Workload>(descriptor, info); + } return MakeWorkload<RefPadFloat32Workload, RefPadQAsymm8Workload>(descriptor, info); } @@ -451,6 +460,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueD { return std::make_unique<RefPermuteQSymm16Workload>(descriptor, info); } + else if (IsBFloat16(info)) + { + return std::make_unique<RefPermuteBFloat16Workload>(descriptor, info); + } return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteQAsymm8Workload, NullWorkload, NullWorkload, NullWorkload>(descriptor, info); } @@ -568,6 +581,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQu { return std::make_unique<RefTransposeQSymm16Workload>(descriptor, info); } + else if (IsBFloat16(info)) + { + return std::make_unique<RefTransposeBFloat16Workload>(descriptor, info); + } return MakeWorkloadHelper<RefTransposeFloat16Workload, RefTransposeFloat32Workload, RefTransposeQAsymm8Workload, NullWorkload, NullWorkload, NullWorkload>(descriptor, info); } |