diff options
Diffstat (limited to 'src/backends/neon/NeonWorkloadFactory.cpp')
-rw-r--r-- | src/backends/neon/NeonWorkloadFactory.cpp | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp index 08168eca2f..c78b58d21d 100644 --- a/src/backends/neon/NeonWorkloadFactory.cpp +++ b/src/backends/neon/NeonWorkloadFactory.cpp @@ -155,7 +155,19 @@ std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateWorkload(LayerType type, case LayerType::BatchMatMul : { auto batchMatMulQueueDescriptor = PolymorphicDowncast<const BatchMatMulQueueDescriptor*>(&descriptor); - return std::make_unique<NeonBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info); + bool isFastMathEnabled = false; + if (m_ModelContextPtr) + { + if (m_ModelContextPtr.get() != nullptr) + { + auto modelOptions = dynamic_cast<NeonBackendModelContext*>(m_ModelContextPtr.get()); + if (modelOptions) + { + isFastMathEnabled = modelOptions->IsFastMathEnabled(); + } + } + } + return std::make_unique<NeonBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info, isFastMathEnabled); } case LayerType::BatchNormalization : { |