aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/TensorVisitors.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/TensorVisitors.h')
-rw-r--r--tests/validation/TensorVisitors.h23
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/validation/TensorVisitors.h b/tests/validation/TensorVisitors.h
index a274140734..1b58a16e4c 100644
--- a/tests/validation/TensorVisitors.h
+++ b/tests/validation/TensorVisitors.h
@@ -316,6 +316,29 @@ private:
PoolingLayerInfo _pool_info;
int _fixed_point_position;
};
+
+// ROI Pooling layer
+struct roi_pooling_layer_visitor : public boost::static_visitor<>
+{
+public:
+ explicit roi_pooling_layer_visitor(const TensorVariant &in, const std::vector<ROI> &rois, ROIPoolingLayerInfo pool_info)
+ : _in(in), _rois(rois), _pool_info(pool_info)
+ {
+ }
+
+ template <typename T>
+ void operator()(Tensor<T> &out) const
+ {
+ const Tensor<T> &in = boost::get<Tensor<T>>(_in);
+ tensor_operations::roi_pooling_layer(in, out, _rois, _pool_info);
+ }
+
+private:
+ const TensorVariant &_in;
+ const std::vector<ROI> &_rois;
+ ROIPoolingLayerInfo _pool_info;
+};
+
// Softmax Layer visitor
struct softmax_layer_visitor : public boost::static_visitor<>
{