diff options
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 44 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 2 |
2 files changed, 21 insertions, 25 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 7942a24..02fdf01 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -324,37 +324,33 @@ int OpAvgPool2d<Dtype>::checkTensorAttributes()
     return 0;
 }
 
+// Builds the average-pooling divisor for every output position along one axis: the number of
+// real (non-padding) input elements covered by that position's pooling window.
+// pad_left is the leading padding (top for the height axis); pad_right is currently unused.
 template <DType Dtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride)
+ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
 {
     ETensor1<int32_t> result(out_size);
 
-    int32_t total_pad = (out_size - 1) * stride + kernel_size - in_size;
-    total_pad = total_pad < 0 ? 0 : total_pad;
-
-    int32_t pad_left = total_pad >> 1;
-    int32_t pad_right = total_pad - pad_left;
-
     result.setConstant(kernel_size);
 
-    // the index left to 'left_index' and index right to 'right_index' indicates
-    // the input window of this output covers a pad bit
-    int32_t left_index = pad_left / stride;
-    int32_t right_index = pad_right / stride;
-
-    // minus the number of pad bit this index cover
-    while (left_index >= 0)
-    {
-        result(left_index) -= (pad_left - left_index * stride);
-        left_index--;
+    // Leading side: the window of output `index` starts at padded offset index * stride, so it
+    // covers pad_left - (index * stride) padding elements for as long as that value is positive.
+    // The bound is inclusive so a remainder of pad_left % stride is also removed (0 if exact).
+    for(int index = 0; (index <= pad_left / stride) && (index < out_size); index++) {
+        int32_t adjust = pad_left - (index * stride);
+        result(index) -= adjust;
     }
 
-    while (right_index >= 0)
-    {
-        result(out_size - 1 - right_index) -= (pad_right - right_index * stride);
-        right_index--;
+    // Trailing side: the window of output `index` ends at padded offset (index * stride) + kernel_size
+    // and real data ends at pad_left + in_size; any overshoot is padding. Start at the first window
+    // that overshoots: index 0 if even that one does (avoids truncating a negative quotient toward
+    // zero), otherwise one past the last window that still fits; continue to the last output.
+    int index = (pad_left + in_size < kernel_size) ? 0 : ((pad_left + in_size - kernel_size) / stride) + 1;
+    for (; index < out_size; index++) {
+        int32_t adjust = ((index * stride) + kernel_size) - (pad_left + in_size);
+        result(index) -= adjust;
     }
-
     return result;
 }
 
@@ -445,8 +441,8 @@ int OpAvgPool2d<Dtype>::eval()
 
     // calculate 1d height/width div_map (number of elements this pooling window covers)
     // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
-    ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h);
-    ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w);
+    ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h, padding_top, padding_bottom);
+    ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w, padding_left, padding_right);
     Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
     Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };
 
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 2174d62..05b1ca1 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -73,7 +73,7 @@ protected:
 
 protected:
     // return a 1D [N] tensor that describes a how many valid elements covered in the input space
-    ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride);
+    ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
 };
 
 template <DType InDtype, DType WeightDtype>