aboutsummaryrefslogtreecommitdiff
path: root/src/backends/gpuFsa/GpuFsaLayerSupport.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/gpuFsa/GpuFsaLayerSupport.cpp')
-rw-r--r--src/backends/gpuFsa/GpuFsaLayerSupport.cpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/src/backends/gpuFsa/GpuFsaLayerSupport.cpp b/src/backends/gpuFsa/GpuFsaLayerSupport.cpp
index 96c986ab33..18c9ac8f5b 100644
--- a/src/backends/gpuFsa/GpuFsaLayerSupport.cpp
+++ b/src/backends/gpuFsa/GpuFsaLayerSupport.cpp
@@ -11,6 +11,7 @@
#if defined(ARMCOMPUTEGPUFSA_ENABLED)
#include "layers/GpuFsaConvolution2d.hpp"
+#include "layers/GpuFsaDepthwiseConvolution2d.hpp"
#endif
#include <vector>
@@ -98,6 +99,34 @@ bool GpuFsaLayerSupport::IsLayerSupported(const LayerType& type,
infos[3]);
}
}
+ case LayerType::DepthwiseConvolution2d:
+ {
+ if (infos.size() != 4)
+ {
+ throw InvalidArgumentException("Invalid number of DepthwiseConvolution2dDescriptor TensorInfos. "
+ "TensorInfos should be of format: {input, output, weights, biases}.");
+ }
+
+ auto desc = *(PolymorphicDowncast<const DepthwiseConvolution2dDescriptor*>(&descriptor));
+ if (infos[3] == TensorInfo())
+ {
+ FORWARD_LAYER_VALIDATE_FUNC(GpuFsaDepthwiseConvolution2dValidate,
+ reasonIfUnsupported,
+ infos[0],
+ desc,
+ infos[2],
+ EmptyOptional());
+ }
+ else
+ {
+ FORWARD_LAYER_VALIDATE_FUNC(GpuFsaDepthwiseConvolution2dValidate,
+ reasonIfUnsupported,
+ infos[0],
+ desc,
+ infos[2],
+ infos[3]);
+ }
+ }
case LayerType::Constant:
case LayerType::Input:
case LayerType::Output: