aboutsummaryrefslogtreecommitdiff
path: root/src/backends/gpuFsa/GpuFsaBackend.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/gpuFsa/GpuFsaBackend.cpp')
-rw-r--r--src/backends/gpuFsa/GpuFsaBackend.cpp9
1 files changed, 9 insertions, 0 deletions
diff --git a/src/backends/gpuFsa/GpuFsaBackend.cpp b/src/backends/gpuFsa/GpuFsaBackend.cpp
index de0d01973c..72f8af7b76 100644
--- a/src/backends/gpuFsa/GpuFsaBackend.cpp
+++ b/src/backends/gpuFsa/GpuFsaBackend.cpp
@@ -20,6 +20,7 @@
#include <arm_compute/core/CL/CLKernelLibrary.h>
#include <arm_compute/runtime/CL/CLBufferAllocator.h>
+#include "layers/GpuFsaBatchMatMul.hpp"
#include "layers/GpuFsaCast.hpp"
#include "layers/GpuFsaConvolution2d.hpp"
#include "layers/GpuFsaDepthwiseConvolution2d.hpp"
@@ -280,6 +281,14 @@ OptimizationViews GpuFsaBackend::OptimizeSubgraphView(const SubgraphView& subgra
}
break;
}
+ case (LayerType::BatchMatMul):
+ {
+ auto input0 = base.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
+ auto input1 = base.GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
+ auto desc = PolymorphicDowncast<const BatchMatMulDescriptor*>(&base.GetParameters());
+ GpuFsaBatchMatMulCreateOp(preCompiledBlobPtr, input0, input1, *desc);
+ break;
+ }
case (LayerType::DepthwiseConvolution2d):
{
auto input = base.GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();