aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils/misc/ShapeCalculator.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h18
1 files changed, 16 insertions, 2 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 2919625511..354f60d016 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -174,7 +174,6 @@ inline TensorShape compute_interleave_custom_shape(const TensorShape &input, con
return output_shape;
}
-
inline TensorShape compute_fully_connected_reshaped_weights_shape(const ITensorInfo *input, bool transpose_weights, bool is_batched_fc_layer, const int interleave)
{
TensorShape output_shape{ input->tensor_shape() };
@@ -194,7 +193,6 @@ inline TensorShape compute_fully_connected_reshaped_weights_shape(const ITensorI
return output_shape;
}
-
inline TensorShape compute_winograd_input_transform_shape(const ITensorInfo &input, const PadStrideInfo &conv_info, const Size2D &kernel_size)
{
// Compute height
@@ -212,6 +210,22 @@ inline TensorShape compute_winograd_input_transform_shape(const ITensorInfo &inp
return output_shape;
}
+inline TensorShape compute_deep_convolution_shape(const ITensorInfo &input, const ITensorInfo &weights, PadStrideInfo conv_info)
+{
+ const TensorShape input_shape{ input.tensor_shape() };
+ const TensorShape weights_shape{ weights.tensor_shape() };
+
+ unsigned int output_width = 0;
+ unsigned int output_height = 0;
+ std::tie(output_width, output_height) = scaled_dimensions(input_shape.x(), input_shape.y(), weights_shape.x(), weights_shape.y(), conv_info);
+
+ TensorShape output_shape{ input_shape };
+ output_shape.set(0, output_width);
+ output_shape.set(1, output_height);
+ output_shape.set(2, weights_shape[3]);
+
+ return output_shape;
+}
} // namespace shape_calculator
} // namespace misc
} // namespace arm_compute