aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/functions/CLConvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/functions/CLConvolutionLayer.cpp')
-rw-r--r--src/runtime/CL/functions/CLConvolutionLayer.cpp8
1 files changed, 5 insertions, 3 deletions
diff --git a/src/runtime/CL/functions/CLConvolutionLayer.cpp b/src/runtime/CL/functions/CLConvolutionLayer.cpp
index a0bee520a6..1a486ce5c7 100644
--- a/src/runtime/CL/functions/CLConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLConvolutionLayer.cpp
@@ -26,6 +26,8 @@
#include "arm_compute/core/PixelValue.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/CL/CLScheduler.h"
#include <cmath>
@@ -33,6 +35,7 @@
#include <tuple>
using namespace arm_compute;
+using namespace arm_compute::misc::shape_calculator;
CLConvolutionLayer::CLConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
: _memory_manager(std::move(memory_manager)), _function()
@@ -70,7 +73,7 @@ void CLConvolutionLayer::configure(ICLTensor *input, const ICLTensor *weights, c
Status CLConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
const WeightsInfo &weights_info)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
//Configure if the parameters match the direct convolution or the gemm-based
const GPUTarget gpu_target = CLScheduler::get().target();
@@ -86,8 +89,7 @@ Status CLConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo
case ConvolutionMethod::GEMM:
{
// Validate gemm-based convolution layer
- /* TODO COMPMID-754: Add validation methods for CLGEMMConvolutionLayer
- CLGEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, weights_info); */
+ CLGEMMConvolutionLayer::validate(input, weights, biases, output, conv_info, weights_info);
break;
}
default: