aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/graph/nodes/DepthwiseConvolutionLayer.cpp6
1 files changed, 2 insertions, 4 deletions
diff --git a/src/graph/nodes/DepthwiseConvolutionLayer.cpp b/src/graph/nodes/DepthwiseConvolutionLayer.cpp
index 1209d0376e..e5101cc33c 100644
--- a/src/graph/nodes/DepthwiseConvolutionLayer.cpp
+++ b/src/graph/nodes/DepthwiseConvolutionLayer.cpp
@@ -40,10 +40,8 @@ std::unique_ptr<arm_compute::IFunction> DepthwiseConvolutionLayer::instantiate_n
if(_weights.tensor() == nullptr)
{
- TensorShape shape = in->info()->tensor_shape();
- shape.set(Window::DimX, _conv_width);
- shape.set(Window::DimY, _conv_height);
- TensorInfo info = TensorInfo(TensorShape(shape), in->info()->num_channels(), in->info()->data_type(), in->info()->fixed_point_position());
+ TensorShape weights_shape(_conv_width, _conv_height, input->tensor()->info()->tensor_shape().z());
+ TensorInfo info = TensorInfo(TensorShape(weights_shape), in->info()->num_channels(), in->info()->data_type(), in->info()->fixed_point_position());
info.set_quantization_info(_quant_info);
_weights.set_info(std::move(info));
}