aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--reference_model/src/ops/tensor_ops.cc44
-rw-r--r--reference_model/src/ops/tensor_ops.h2
2 files changed, 21 insertions, 25 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 7942a24..02fdf01 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -324,37 +324,33 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes()
return 0;
}
+// This calculates the number of padding elements used for each location along an axis
+// Average pooling only divides by the number of elements used, not including padding.
+// This function uses left/right, but is also used for vertical padding with top/bottom
template <DType Dtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride)
+ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
ETensor1<int32_t> result(out_size);
- int32_t total_pad = (out_size - 1) * stride + kernel_size - in_size;
- total_pad = total_pad < 0 ? 0 : total_pad;
-
- int32_t pad_left = total_pad >> 1;
- int32_t pad_right = total_pad - pad_left;
-
result.setConstant(kernel_size);
- // the index left to 'left_index' and index right to 'right_index' indicates
- // the input window of this output covers a pad bit
- int32_t left_index = pad_left / stride;
- int32_t right_index = pad_right / stride;
-
- // minus the number of pad bit this index cover
- while (left_index >= 0)
- {
- result(left_index) -= (pad_left - left_index * stride);
- left_index--;
+ // adjust divisors on the left side for padding
+ // We start at the leftmost output element, and remove pad_left - (index * stride) elements
+ // until we have no more padding being used
+ for(int index = 0; (index < pad_left / stride) && (index < out_size); index++) {
+ int32_t adjust = pad_left - (index * stride);
+ result(index) -= adjust;
}
- while (right_index >= 0)
- {
- result(out_size - 1 - right_index) -= (pad_right - right_index * stride);
- right_index--;
+ // The process repeats on the right side. Padding starts taking effect as we
+ // near the rightmost input element. The first output element which touches
+ // padding is defined in the initialization of index below. Then we keep moving
+ // to the right, increasing padding until we get to the last output element.
+ int index = std::max(0, ((pad_left + in_size - kernel_size) / stride) + 1);
+ for (; index < out_size; index++) {
+ int32_t adjust = ((index * stride) + kernel_size) - (pad_left + in_size);
+ result(index) -= adjust;
}
-
return result;
}
@@ -445,8 +441,8 @@ int OpAvgPool2d<Dtype>::eval()
// calculate 1d height/width div_map (number of elements this pooling window covers)
// and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
- ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h);
- ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w);
+ ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h, padding_top, padding_bottom);
+ ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w, padding_left, padding_right);
Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 2174d62..05b1ca1 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -73,7 +73,7 @@ protected:
protected:
// return a 1D [N] tensor that describes a how many valid elements covered in the input space
- ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride);
+ ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
};
template <DType InDtype, DType WeightDtype>