diff options
Diffstat (limited to 'tests/validation/TensorVisitors.h')
-rw-r--r-- | tests/validation/TensorVisitors.h | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/validation/TensorVisitors.h b/tests/validation/TensorVisitors.h index a274140734..1b58a16e4c 100644 --- a/tests/validation/TensorVisitors.h +++ b/tests/validation/TensorVisitors.h @@ -316,6 +316,29 @@ private: PoolingLayerInfo _pool_info; int _fixed_point_position; }; + +// ROI Pooling layer +struct roi_pooling_layer_visitor : public boost::static_visitor<> +{ +public: + explicit roi_pooling_layer_visitor(const TensorVariant &in, const std::vector<ROI> &rois, ROIPoolingLayerInfo pool_info) + : _in(in), _rois(rois), _pool_info(pool_info) + { + } + + template <typename T> + void operator()(Tensor<T> &out) const + { + const Tensor<T> &in = boost::get<Tensor<T>>(_in); + tensor_operations::roi_pooling_layer(in, out, _rois, _pool_info); + } + +private: + const TensorVariant &_in; + const std::vector<ROI> &_rois; + ROIPoolingLayerInfo _pool_info; +}; + // Softmax Layer visitor struct softmax_layer_visitor : public boost::static_visitor<> { |