aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/tensor_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r--reference_model/src/ops/tensor_ops.cc9
1 files changed, 2 insertions, 7 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index be4e4aa..059638a 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -92,13 +92,8 @@ int check_pool2d_attribute_common(tosa::TosaPoolAttribute* attribute,
return 1;
}
- int32_t allowed_min_input_height = (OH * stride_y) - pad_top - pad_bottom - stride_y + kernel_y;
- int32_t allowed_min_input_width = (OW * stride_x) - pad_left - pad_right - stride_x + kernel_x;
-
- int32_t d_height = IH - allowed_min_input_height;
- int32_t d_width = IW - allowed_min_input_width;
-
- if (d_height < 0 || d_height > stride_y || d_width < 0 || d_width > stride_x)
+ if ( OH != (IH + pad_top + pad_bottom + stride_y - kernel_y) / stride_y ||
+ OW != (IW + pad_left + pad_right + stride_x - kernel_x) / stride_x )
{
msg = "Mismatch between output shape provided and expected output shape";
return 1;