aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/GLES_COMPUTE/functions/GCDirectConvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/GLES_COMPUTE/functions/GCDirectConvolutionLayer.cpp')
-rw-r--r--src/runtime/GLES_COMPUTE/functions/GCDirectConvolutionLayer.cpp11
1 files changed, 7 insertions, 4 deletions
diff --git a/src/runtime/GLES_COMPUTE/functions/GCDirectConvolutionLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCDirectConvolutionLayer.cpp
index a2607d4c2d..c0cf09836f 100644
--- a/src/runtime/GLES_COMPUTE/functions/GCDirectConvolutionLayer.cpp
+++ b/src/runtime/GLES_COMPUTE/functions/GCDirectConvolutionLayer.cpp
@@ -39,26 +39,27 @@ GCDirectConvolutionLayer::GCDirectConvolutionLayer()
{
}
-void GCDirectConvolutionLayer::configure(IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info)
+void GCDirectConvolutionLayer::configure(IGCTensor *input, const IGCTensor *weights, const IGCTensor *biases, IGCTensor *output, const PadStrideInfo &conv_info,
+ const ActivationLayerInfo &act_info)
{
int kernel_size = weights->info()->dimension(0);
if(kernel_size == 1)
{
auto k = arm_compute::support::cpp14::make_unique<GCDirectConvolutionLayer1x1Kernel>();
- k->configure(input, weights, biases, output, conv_info);
+ k->configure(input, weights, biases, output, conv_info, act_info);
_kernel = std::move(k);
}
else if(kernel_size == 3)
{
auto k = arm_compute::support::cpp14::make_unique<GCDirectConvolutionLayer3x3Kernel>();
- k->configure(input, weights, biases, output, conv_info);
+ k->configure(input, weights, biases, output, conv_info, act_info);
_kernel = std::move(k);
}
else if(kernel_size == 5)
{
auto k = arm_compute::support::cpp14::make_unique<GCDirectConvolutionLayer5x5Kernel>();
- k->configure(input, weights, biases, output, conv_info);
+ k->configure(input, weights, biases, output, conv_info, act_info);
_kernel = std::move(k);
}
else
@@ -79,4 +80,6 @@ void GCDirectConvolutionLayer::run()
GCScheduler::get().dispatch(_border_handler, false);
GCScheduler::get().memory_barrier();
GCScheduler::get().dispatch(*_kernel);
+ GCScheduler::get().memory_barrier();
+ GCScheduler::get().dispatch(_shift_handler);
}