diff options
Diffstat (limited to 'src/backends/neon/NeonWorkloadFactory.cpp')
-rw-r--r-- | src/backends/neon/NeonWorkloadFactory.cpp | 52 |
1 files changed, 40 insertions, 12 deletions
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp index 709dd93e9b..3077ae0a8c 100644 --- a/src/backends/neon/NeonWorkloadFactory.cpp +++ b/src/backends/neon/NeonWorkloadFactory.cpp @@ -260,25 +260,27 @@ std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateElementwiseUnary( switch(descriptor.m_Parameters.m_Operation) { case UnaryOperation::Abs: - { - AbsQueueDescriptor absQueueDescriptor; - absQueueDescriptor.m_Inputs = descriptor.m_Inputs; - absQueueDescriptor.m_Outputs = descriptor.m_Outputs; + { + AbsQueueDescriptor absQueueDescriptor; + absQueueDescriptor.m_Inputs = descriptor.m_Inputs; + absQueueDescriptor.m_Outputs = descriptor.m_Outputs; - return std::make_unique<NeonAbsWorkload>(absQueueDescriptor, info); - } + return std::make_unique<NeonAbsWorkload>(absQueueDescriptor, info); + } case UnaryOperation::Rsqrt: - { - RsqrtQueueDescriptor rsqrtQueueDescriptor; - rsqrtQueueDescriptor.m_Inputs = descriptor.m_Inputs; - rsqrtQueueDescriptor.m_Outputs = descriptor.m_Outputs; + { + RsqrtQueueDescriptor rsqrtQueueDescriptor; + rsqrtQueueDescriptor.m_Inputs = descriptor.m_Inputs; + rsqrtQueueDescriptor.m_Outputs = descriptor.m_Outputs; - return std::make_unique<NeonRsqrtWorkload>(rsqrtQueueDescriptor, info); - } + return std::make_unique<NeonRsqrtWorkload>(rsqrtQueueDescriptor, info); + } case UnaryOperation::Neg: return std::make_unique<NeonNegWorkload>(descriptor, info); case UnaryOperation::Exp: return std::make_unique<NeonExpWorkload>(descriptor, info); + case UnaryOperation::LogicalNot: + return std::make_unique<NeonLogicalNotWorkload>(descriptor, info); default: return nullptr; } @@ -356,6 +358,32 @@ std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateLogSoftmax(const LogSoftma return std::make_unique<NeonLogSoftmaxWorkload>(descriptor, info, m_MemoryManager->GetIntraLayerManager()); } +std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + switch(descriptor.m_Parameters.m_Operation) + { + case LogicalBinaryOperation::LogicalAnd: + return std::make_unique<NeonLogicalAndWorkload>(descriptor, info); + case LogicalBinaryOperation::LogicalOr: + return std::make_unique<NeonLogicalOrWorkload>(descriptor, info); + default: + return nullptr; + } +} + +std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateLogicalUnary(const ElementwiseUnaryQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + switch(descriptor.m_Parameters.m_Operation) + { + case UnaryOperation::LogicalNot: + return std::make_unique<NeonLogicalNotWorkload>(descriptor, info); + default: + return nullptr; + } +} + std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info) const { |