aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/RefWorkloadFactory.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/RefWorkloadFactory.cpp')
-rw-r--r--src/armnn/backends/RefWorkloadFactory.cpp61
1 files changed, 28 insertions, 33 deletions
diff --git a/src/armnn/backends/RefWorkloadFactory.cpp b/src/armnn/backends/RefWorkloadFactory.cpp
index d7d498e89e..9294c5accc 100644
--- a/src/armnn/backends/RefWorkloadFactory.cpp
+++ b/src/armnn/backends/RefWorkloadFactory.cpp
@@ -18,22 +18,15 @@ template <typename F32Workload, typename U8Workload, typename QueueDescriptorTyp
std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
const WorkloadInfo& info) const
{
- if (!IsOperationQueueDescriptor(descriptor) || m_OperationWorkloadsAllowed)
- {
- return armnn::MakeWorkload<F32Workload, U8Workload>(descriptor, info);
- }
- else
- {
- return std::unique_ptr<IWorkload>();
- }
+ return armnn::MakeWorkload<NullWorkload, F32Workload, U8Workload>(descriptor, info);
}
-RefWorkloadFactory::RefWorkloadFactory(bool operationWorkloadsAllowed)
- : m_OperationWorkloadsAllowed(operationWorkloadsAllowed)
+RefWorkloadFactory::RefWorkloadFactory()
{
}
-bool RefWorkloadFactory::IsLayerSupported(const Layer& layer, DataType dataType, std::string& outReasonIfUnsupported)
+bool RefWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
+ std::string& outReasonIfUnsupported)
{
return IWorkloadFactory::IsLayerSupported(Compute::CpuRef, layer, dataType, outReasonIfUnsupported);
}
@@ -60,7 +53,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescr
throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count.");
}
- return MakeWorkload<CopyFromCpuToCpuFloat32Workload, CopyFromCpuToCpuUint8Workload>(descriptor, info);
+ return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
@@ -79,7 +72,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDes
throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count.");
}
- return MakeWorkload<CopyFromCpuToCpuFloat32Workload, CopyFromCpuToCpuUint8Workload>(descriptor, info);
+ return MakeWorkload<CopyMemGenericWorkload, CopyMemGenericWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
@@ -168,25 +161,7 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCop
{
throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor.");
}
- // Create a workload that will copy tensor data from the inputs, which can have a number of different formats,
- // to CPU tensors.
- switch (descriptor.m_Inputs[0]->GetType())
- {
-#if ARMCOMPUTECL_ENABLED
- case ITensorHandle::CL:
- {
- return MakeWorkload<CopyFromClToCpuFloat32Workload, CopyFromClToCpuUint8Workload>(descriptor, info);
- }
-#endif
-#if ARMCOMPUTENEON_ENABLED
- case ITensorHandle::Neon:
- {
- return MakeWorkload<CopyFromNeonToCpuFloat32Workload, CopyFromNeonToCpuUint8Workload>(descriptor, info);
- }
-#endif
- default:
- throw InvalidArgumentException("RefWorkloadFactory: Destination type not supported for MemCopy Workload.");
- }
+ return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
@@ -221,9 +196,29 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueD
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
- const WorkloadInfo& info) const
+ const WorkloadInfo& info) const
{
return MakeWorkload<RefFloorFloat32Workload, NullWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return MakeWorkload<RefLstmFloat32Workload, NullWorkload>(descriptor, info);
+}
+
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
+ const ConvertFp16ToFp32QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::make_unique<RefConvertFp16ToFp32Workload>(descriptor, info);
+}
+
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16(
+ const ConvertFp32ToFp16QueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::make_unique<RefConvertFp32ToFp16Workload>(descriptor, info);
+}
+
} // namespace armnn