aboutsummaryrefslogtreecommitdiff
path: root/src/backends/neon/NeonWorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/neon/NeonWorkloadFactory.cpp')
-rw-r--r--src/backends/neon/NeonWorkloadFactory.cpp14
1 files changed, 13 insertions, 1 deletions
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp
index 08168eca2f..c78b58d21d 100644
--- a/src/backends/neon/NeonWorkloadFactory.cpp
+++ b/src/backends/neon/NeonWorkloadFactory.cpp
@@ -155,7 +155,19 @@ std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateWorkload(LayerType type,
case LayerType::BatchMatMul :
{
auto batchMatMulQueueDescriptor = PolymorphicDowncast<const BatchMatMulQueueDescriptor*>(&descriptor);
- return std::make_unique<NeonBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info);
+ bool isFastMathEnabled = false;
+ if (m_ModelContextPtr)
+ {
+ if (m_ModelContextPtr.get() != nullptr)
+ {
+ auto modelOptions = dynamic_cast<NeonBackendModelContext*>(m_ModelContextPtr.get());
+ if (modelOptions)
+ {
+ isFastMathEnabled = modelOptions->IsFastMathEnabled();
+ }
+ }
+ }
+ return std::make_unique<NeonBatchMatMulWorkload>(*batchMatMulQueueDescriptor, info, isFastMathEnabled);
}
case LayerType::BatchNormalization :
{