diff options
author | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2020-03-11 14:51:27 +0000 |
---|---|---|
committer | Narumol Prangnawarat <narumol.prangnawarat@arm.com> | 2020-03-13 09:49:42 +0000 |
commit | 44179c372eea9f17c96cbf50ee383e57e14d70a6 (patch) | |
tree | 2a2971c2db67426107b21d9a045cfa46a4a1663a /src/backends/reference/RefWorkloadFactory.cpp | |
parent | e9b5d2989abc8008df7ff3ea287ee896ee1121a6 (diff) | |
download | armnn-44179c372eea9f17c96cbf50ee383e57e14d70a6.tar.gz |
IVGCVSW-4511 Add BFloat16 to RefLayerSupport and unit tests
Signed-off-by: Narumol Prangnawarat <narumol.prangnawarat@arm.com>
Change-Id: Ifaae4d5aac468ba927b2c6a4bf31b8c8522aeb2e
Diffstat (limited to 'src/backends/reference/RefWorkloadFactory.cpp')
-rw-r--r-- | src/backends/reference/RefWorkloadFactory.cpp | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 52d71df936..1d82421490 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -50,6 +50,11 @@ bool IsSigned32(const WorkloadInfo& info) return IsDataType<DataType::Signed32>(info); } +bool IsBFloat16(const WorkloadInfo& info) +{ + return IsDataType<DataType::BFloat16>(info); +} + bool IsFloat16(const WorkloadInfo& info) { return IsDataType<DataType::Float16>(info); @@ -441,6 +446,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescripto { return std::make_unique<RefPadFloat16Workload>(descriptor, info); } + else if (IsBFloat16(info)) + { + return std::make_unique<RefPadBFloat16Workload>(descriptor, info); + } return MakeWorkload<RefPadFloat32Workload, RefPadQAsymm8Workload>(descriptor, info); } @@ -451,6 +460,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueD { return std::make_unique<RefPermuteQSymm16Workload>(descriptor, info); } + else if (IsBFloat16(info)) + { + return std::make_unique<RefPermuteBFloat16Workload>(descriptor, info); + } return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteQAsymm8Workload, NullWorkload, NullWorkload, NullWorkload>(descriptor, info); } @@ -568,6 +581,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQu { return std::make_unique<RefTransposeQSymm16Workload>(descriptor, info); } + else if (IsBFloat16(info)) + { + return std::make_unique<RefTransposeBFloat16Workload>(descriptor, info); + } return MakeWorkloadHelper<RefTransposeFloat16Workload, RefTransposeFloat32Workload, RefTransposeQAsymm8Workload, NullWorkload, NullWorkload, NullWorkload>(descriptor, info); } |