diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadFactory.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadFactory.cpp | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 1dc96a5ec3..209ba6a4ed 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -190,15 +190,6 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } - case LayerType::MemCopy: - { - // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends, - // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests). - result = backendId == Compute::CpuRef || backendId == Compute::Undefined - || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc; - reason.value() = "Unsupported backend type"; - break; - } case LayerType::Debug: { auto cLayer = boost::polymorphic_downcast<const DebugLayer*>(&layer); @@ -487,6 +478,16 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, reason); break; } + case LayerType::MemCopy: + { + const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); + const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo(); + + result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType), + OverrideDataType(output, dataType), + reason); + break; + } case LayerType::Merger: { auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer); @@ -590,8 +591,11 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, } case LayerType::Reshape: { + auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer); const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo(); - result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason); + result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), + cLayer->GetParameters(), + reason); break; } case LayerType::ResizeBilinear: |