diff options
Diffstat (limited to 'src/core/CL/kernels/CLChannelExtractKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLChannelExtractKernel.cpp | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/src/core/CL/kernels/CLChannelExtractKernel.cpp b/src/core/CL/kernels/CLChannelExtractKernel.cpp index d2a0f984da..8df162c4ee 100644 --- a/src/core/CL/kernels/CLChannelExtractKernel.cpp +++ b/src/core/CL/kernels/CLChannelExtractKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2019 ARM Limited. + * Copyright (c) 2016-2020 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -49,6 +49,11 @@ CLChannelExtractKernel::CLChannelExtractKernel() void CLChannelExtractKernel::configure(const ICLTensor *input, Channel channel, ICLTensor *output) { + configure(CLKernelLibrary::get().get_compile_context(), input, channel, output); +} + +void CLChannelExtractKernel::configure(CLCompileContext &compile_context, const ICLTensor *input, Channel channel, ICLTensor *output) +{ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); ARM_COMPUTE_ERROR_ON(input == output); @@ -89,7 +94,7 @@ void CLChannelExtractKernel::configure(const ICLTensor *input, Channel channel, // Create kernel std::string kernel_name = "channel_extract_" + string_from_format(format); std::set<std::string> build_opts = { ("-DCHANNEL_" + string_from_channel(channel)) }; - _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts)); + _kernel = create_kernel(compile_context, kernel_name, build_opts); // Configure window Window win = calculate_max_window(*input->info(), Steps(_num_elems_processed_per_iteration)); @@ -106,6 +111,11 @@ void CLChannelExtractKernel::configure(const ICLTensor *input, Channel channel, void CLChannelExtractKernel::configure(const ICLMultiImage *input, Channel channel, ICLImage *output) { + configure(CLKernelLibrary::get().get_compile_context(), input, channel, output); +} + +void CLChannelExtractKernel::configure(CLCompileContext &compile_context, const ICLMultiImage *input, Channel channel, ICLImage *output) +{ ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); ARM_COMPUTE_ERROR_ON_TENSOR_NOT_2D(output); @@ -151,7 +161,7 @@ void CLChannelExtractKernel::configure(const ICLMultiImage *input, Channel chann kernel_name = "channel_extract_" + string_from_format(format); build_opts.insert(("-DCHANNEL_" + string_from_channel(channel))); } - _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts)); + _kernel = create_kernel(compile_context, kernel_name, build_opts); // Configure window Window win = calculate_max_window(*input_plane->info(), Steps(_num_elems_processed_per_iteration)); |