diff options
Diffstat (limited to 'src/graph2/backends/CL/CLFunctionsFactory.cpp')
-rw-r--r-- | src/graph2/backends/CL/CLFunctionsFactory.cpp | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/src/graph2/backends/CL/CLFunctionsFactory.cpp b/src/graph2/backends/CL/CLFunctionsFactory.cpp index bba0cce025..5a51b19e18 100644 --- a/src/graph2/backends/CL/CLFunctionsFactory.cpp +++ b/src/graph2/backends/CL/CLFunctionsFactory.cpp @@ -165,7 +165,13 @@ std::unique_ptr<IFunction> create_convolution_layer(ConvolutionLayerNode &node, std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, Target::CL); std::unique_ptr<IFunction> func; std::string func_name; - if(conv_algorithm == ConvolutionMethod::DIRECT) + + if(conv_algorithm == ConvolutionMethod::WINOGRAD) + { + std::tie(func, func_name) = create_named_function<CLWinogradConvolutionLayer>( + std::string("CLWinogradConvolutionLayer"), input, weights, biases, output, conv_info); + } + else if(conv_algorithm == ConvolutionMethod::DIRECT) { std::tie(func, func_name) = create_named_function<CLDirectConvolutionLayer>( std::string("CLDirectConvolutionLayer"), input, weights, biases, output, conv_info); |