diff options
Diffstat (limited to 'tests/validation/reference/PoolingLayer.cpp')
-rw-r--r-- | tests/validation/reference/PoolingLayer.cpp | 78 |
1 files changed, 42 insertions, 36 deletions
diff --git a/tests/validation/reference/PoolingLayer.cpp b/tests/validation/reference/PoolingLayer.cpp index 778e28d7c1..c110a67842 100644 --- a/tests/validation/reference/PoolingLayer.cpp +++ b/tests/validation/reference/PoolingLayer.cpp @@ -43,9 +43,10 @@ SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const Pooling ARM_COMPUTE_ERROR_ON(info.is_global_pooling && (src.shape().x() != src.shape().y())); // Create reference SimpleTensor<T> dst{ compute_pool_shape(TensorInfo(src.shape(), 1, src.data_type()), info), src.data_type(), 1 }; + auto pooled_shape = compute_pool_shape(TensorInfo(src.shape(), 1, src.data_type()), info); if(indices) { - *indices = SimpleTensor<uint32_t> { compute_pool_shape(TensorInfo(src.shape(), 1, src.data_type()), info), DataType::U32, 1 }; + *indices = SimpleTensor<uint32_t> { pooled_shape, DataType::U32, 1 }; } const int pool_size_x = info.is_global_pooling ? src.shape().x() : info.pool_size.width; const int pool_size_y = info.is_global_pooling ? src.shape().y() : info.pool_size.height; @@ -58,56 +59,62 @@ SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const Pooling int pad_bottom = info.pad_stride_info.pad_bottom(); bool exclude_padding = info.exclude_padding; - const auto w_src = static_cast<int>(src.shape()[0]); - const auto h_src = static_cast<int>(src.shape()[1]); - const int upper_dims = src.shape().total_size() / (w_src * h_src); + const auto w_src = static_cast<int>(src.shape()[0]); + const auto h_src = static_cast<int>(src.shape()[1]); + const auto z_src = static_cast<int>(src.shape()[2]); + const auto b_src = static_cast<int>(src.shape()[3]); + + const int upper_dims = src.shape().total_size() / (w_src * h_src); + + const auto w_dst = static_cast<int>(dst.shape()[0]); + const auto h_dst = static_cast<int>(dst.shape()[1]); + const auto z_dst = static_cast<int>(dst.shape()[2]); - const auto w_dst = static_cast<int>(dst.shape()[0]); - const auto h_dst = static_cast<int>(dst.shape()[1]); TensorShape shape_nhwc(src.shape()); permute(shape_nhwc, PermutationVector(2U, 0U, 1U)); - if(type == PoolingType::MAX) { - for(int r = 0; r < upper_dims; ++r) + for(int b = 0; b < b_src; ++b) { - for(int h = 0; h < h_dst; ++h) + for(int r = 0; r < z_src; ++r) { - for(int w = 0; w < w_dst; ++w) + for(int h = 0; h < h_dst; ++h) { - int wstart = w * pool_stride_x - pad_left; - int hstart = h * pool_stride_y - pad_top; - int wend = std::min(wstart + pool_size_x, w_src); - int hend = std::min(hstart + pool_size_y, h_src); - wstart = std::max(wstart, 0); - hstart = std::max(hstart, 0); - - auto max_val = std::numeric_limits<ACC_T>::lowest(); - int max_index{ 0 }; - for(int y = hstart; y < hend; ++y) + for(int w = 0; w < w_dst; ++w) { - for(int x = wstart; x < wend; ++x) + int wstart = w * pool_stride_x - pad_left; + int hstart = h * pool_stride_y - pad_top; + int wend = std::min(wstart + pool_size_x, w_src); + int hend = std::min(hstart + pool_size_y, h_src); + wstart = std::max(wstart, 0); + hstart = std::max(hstart, 0); + auto max_val = std::numeric_limits<ACC_T>::lowest(); + int max_index{ 0 }; + for(int y = hstart; y < hend; ++y) { - const auto val = static_cast<ACC_T>(src[r * h_src * w_src + y * w_src + x]); - if(val > max_val) + for(int x = wstart; x < wend; ++x) { - max_val = val; - if(data_layout == DataLayout::NCHW) + const auto val = static_cast<ACC_T>(src[b * z_src * h_src * w_src + r * h_src * w_src + y * w_src + x]); + if(val > max_val) { - max_index = coord2index(src.shape(), Coordinates(x, y, r)); - } - else - { - max_index = coord2index(shape_nhwc, Coordinates(r, x, y)); + max_val = val; + if(data_layout == DataLayout::NCHW) + { + max_index = coord2index(src.shape(), Coordinates(x, y, r, 0)); + } + else + { + max_index = coord2index(shape_nhwc, Coordinates(r, x, y, 0)); + } } } } - } - dst[r * h_dst * w_dst + h * w_dst + w] = static_cast<T>(max_val); - if(indices) - { - (*indices)[r * h_dst * w_dst + h * w_dst + w] = max_index; + dst[b * z_dst * h_dst * w_dst + r * h_dst * w_dst + h * w_dst + w] = static_cast<T>(max_val); + if(indices) + { + (*indices)[b * z_dst * h_dst * w_dst + r * h_dst * w_dst + h * w_dst + w] = max_index; + } } } } @@ -164,7 +171,6 @@ SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const Pooling } } } - return dst; } |