aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp23
1 files changed, 17 insertions, 6 deletions
diff --git a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
index 7ad017f918..7af42904e8 100644
--- a/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLWinogradConvolutionLayer.cpp
@@ -28,6 +28,15 @@
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
+#include "src/core/CL/kernels/CLFillBorderKernel.h"
+#include "src/core/CL/kernels/CLGEMMMatrixMultiplyKernel.h"
+#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.h"
+#include "src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedOnlyRHSKernel.h"
+#include "src/core/CL/kernels/CLGEMMReshapeLHSMatrixKernel.h"
+#include "src/core/CL/kernels/CLGEMMReshapeRHSMatrixKernel.h"
+#include "src/core/CL/kernels/CLWinogradFilterTransformKernel.h"
+#include "src/core/CL/kernels/CLWinogradOutputTransformKernel.h"
+#include "support/MemorySupport.h"
using namespace arm_compute;
@@ -90,11 +99,13 @@ bool check_support_fast_math(const Size2D &output_tile, const Size2D &kernel_siz
} // namespace
CLWinogradConvolutionLayer::CLWinogradConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(memory_manager), _batched_mm(memory_manager), _input_transform(), _filter_transform(), _output_transform(), _input0(), _input1(), _batched_mm_output(), _original_weights(nullptr),
- _is_prepared(false)
+ : _memory_group(memory_manager), _batched_mm(memory_manager), _input_transform(), _filter_transform(support::cpp14::make_unique<CLWinogradFilterTransformKernel>()),
+ _output_transform(support::cpp14::make_unique<CLWinogradOutputTransformKernel>()), _input0(), _input1(), _batched_mm_output(), _original_weights(nullptr), _is_prepared(false)
{
}
+CLWinogradConvolutionLayer::~CLWinogradConvolutionLayer() = default;
+
void CLWinogradConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, const ICLTensor *biases, ICLTensor *output, const PadStrideInfo &conv_info, const ActivationLayerInfo &act_info,
bool enable_fast_math)
{
@@ -139,7 +150,7 @@ void CLWinogradConvolutionLayer::configure(const CLCompileContext &compile_conte
_input_transform.configure(compile_context, input, &_input0, winograd_info);
// Configure filter transform
- _filter_transform.configure(compile_context, weights, &_input1, winograd_info);
+ _filter_transform->configure(compile_context, weights, &_input1, winograd_info);
// Configure batched matrix multiply
_batched_mm.configure(compile_context, &_input0, &_input1, nullptr, &_batched_mm_output, 1.0f, 0.0f, GEMMInfo(false, false, true /* Reshape weights only for the first run*/, 0, false, false,
@@ -147,7 +158,7 @@ void CLWinogradConvolutionLayer::configure(const CLCompileContext &compile_conte
(input->info()->data_type() == DataType::F16)));
// Configure output transform
- _output_transform.configure(compile_context, &_batched_mm_output, biases, output, winograd_info, act_info);
+ _output_transform->configure(compile_context, &_batched_mm_output, biases, output, winograd_info, act_info);
// Allocate temporary tensors
_input0.allocator()->allocate();
@@ -218,7 +229,7 @@ void CLWinogradConvolutionLayer::run()
_batched_mm.run();
// Run output transform
- CLScheduler::get().enqueue(_output_transform);
+ CLScheduler::get().enqueue(*_output_transform);
}
void CLWinogradConvolutionLayer::prepare()
@@ -227,7 +238,7 @@ void CLWinogradConvolutionLayer::prepare()
{
// Run filter transform and mark original weights as unused
_input1.allocator()->allocate();
- CLScheduler::get().enqueue(_filter_transform, false);
+ CLScheduler::get().enqueue(*_filter_transform, false);
_original_weights->mark_as_unused();
// Prepare GEMM and release reshaped weights if marked unused by CLGEMM