diff options
Diffstat (limited to 'src/armnn/backends/RefWorkloadFactory.cpp')
-rw-r--r-- | src/armnn/backends/RefWorkloadFactory.cpp | 61 |
1 files changed, 28 insertions, 33 deletions
diff --git a/src/armnn/backends/RefWorkloadFactory.cpp b/src/armnn/backends/RefWorkloadFactory.cpp index d7d498e89e..9294c5accc 100644 --- a/src/armnn/backends/RefWorkloadFactory.cpp +++ b/src/armnn/backends/RefWorkloadFactory.cpp @@ -18,22 +18,15 @@ template <typename F32Workload, typename U8Workload, typename QueueDescriptorTyp std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const { - if (!IsOperationQueueDescriptor(descriptor) || m_OperationWorkloadsAllowed) - { - return armnn::MakeWorkload<F32Workload, U8Workload>(descriptor, info); - } - else - { - return std::unique_ptr<IWorkload>(); - } + return armnn::MakeWorkload<NullWorkload, F32Workload, U8Workload>(descriptor, info); } -RefWorkloadFactory::RefWorkloadFactory(bool operationWorkloadsAllowed) - : m_OperationWorkloadsAllowed(operationWorkloadsAllowed) +RefWorkloadFactory::RefWorkloadFactory() { } -bool RefWorkloadFactory::IsLayerSupported(const Layer& layer, DataType dataType, std::string& outReasonIfUnsupported) +bool RefWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType, + std::string& outReasonIfUnsupported) { return IWorkloadFactory::IsLayerSupported(Compute::CpuRef, layer, dataType, outReasonIfUnsupported); } @@ -60,7 +53,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescr throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count."); } - return MakeWorkload<CopyFromCpuToCpuFloat32Workload, CopyFromCpuToCpuUint8Workload>(descriptor, info); + return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor, @@ -79,7 +72,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDes throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count."); } - return MakeWorkload<CopyFromCpuToCpuFloat32Workload, CopyFromCpuToCpuUint8Workload>(descriptor, info); + return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor, @@ -168,25 +161,7 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCop { throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor."); } - // Create a workload that will copy tensor data from the inputs, which can have a number of different formats, - // to CPU tensors. - switch (descriptor.m_Inputs[0]->GetType()) - { -#if ARMCOMPUTECL_ENABLED - case ITensorHandle::CL: - { - return MakeWorkload<CopyFromClToCpuFloat32Workload, CopyFromClToCpuUint8Workload>(descriptor, info); - } -#endif -#if ARMCOMPUTENEON_ENABLED - case ITensorHandle::Neon: - { - return MakeWorkload<CopyFromNeonToCpuFloat32Workload, CopyFromNeonToCpuUint8Workload>(descriptor, info); - } -#endif - default: - throw InvalidArgumentException("RefWorkloadFactory: Destination type not supported for MemCopy Workload."); - } + return std::make_unique<CopyMemGenericWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor, @@ -221,9 +196,29 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueD } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor, - const WorkloadInfo& info) const + const WorkloadInfo& info) const { return MakeWorkload<RefFloorFloat32Workload, NullWorkload>(descriptor, info); } +std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info); +} + +std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32( + const ConvertFp16ToFp32QueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique<RefConvertFp16ToFp32Workload>(descriptor, info); +} + +std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16( + const ConvertFp32ToFp16QueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique<RefConvertFp32ToFp16Workload>(descriptor, info); +} + } // namespace armnn |