aboutsummaryrefslogtreecommitdiff
path: root/tests/datasets/DepthwiseConvolutionLayerDataset.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/datasets/DepthwiseConvolutionLayerDataset.h')
-rw-r--r--tests/datasets/DepthwiseConvolutionLayerDataset.h23
1 files changed, 15 insertions, 8 deletions
diff --git a/tests/datasets/DepthwiseConvolutionLayerDataset.h b/tests/datasets/DepthwiseConvolutionLayerDataset.h
index 8a1f3b6f39..4c78eb87ea 100644
--- a/tests/datasets/DepthwiseConvolutionLayerDataset.h
+++ b/tests/datasets/DepthwiseConvolutionLayerDataset.h
@@ -38,16 +38,18 @@ namespace datasets
class DepthwiseConvolutionLayerDataset
{
public:
- using type = std::tuple<TensorShape, Size2D, PadStrideInfo>;
+ using type = std::tuple<TensorShape, Size2D, PadStrideInfo, Size2D>;
struct iterator
{
iterator(std::vector<TensorShape>::const_iterator src_it,
std::vector<Size2D>::const_iterator weights_it,
- std::vector<PadStrideInfo>::const_iterator infos_it)
+ std::vector<PadStrideInfo>::const_iterator infos_it,
+ std::vector<Size2D>::const_iterator dilation_it)
: _src_it{ std::move(src_it) },
_weights_it{ std::move(weights_it) },
- _infos_it{ std::move(infos_it) }
+ _infos_it{ std::move(infos_it) },
+ _dilation_it{ std::move(dilation_it) }
{
}
@@ -56,13 +58,14 @@ public:
std::stringstream description;
description << "In=" << *_src_it << ":";
description << "Weights=" << *_weights_it << ":";
- description << "Info=" << *_infos_it;
+ description << "Info=" << *_infos_it << ":";
+ description << "Dilation=" << *_dilation_it;
return description.str();
}
DepthwiseConvolutionLayerDataset::type operator*() const
{
- return std::make_tuple(*_src_it, *_weights_it, *_infos_it);
+ return std::make_tuple(*_src_it, *_weights_it, *_infos_it, *_dilation_it);
}
iterator &operator++()
@@ -70,6 +73,7 @@ public:
++_src_it;
++_weights_it;
++_infos_it;
+ ++_dilation_it;
return *this;
}
@@ -78,23 +82,25 @@ public:
std::vector<TensorShape>::const_iterator _src_it;
std::vector<Size2D>::const_iterator _weights_it;
std::vector<PadStrideInfo>::const_iterator _infos_it;
+ std::vector<Size2D>::const_iterator _dilation_it;
};
iterator begin() const
{
- return iterator(_src_shapes.begin(), _weight_shapes.begin(), _infos.begin());
+ return iterator(_src_shapes.begin(), _weight_shapes.begin(), _infos.begin(), _dilations.begin());
}
int size() const
{
- return std::min(_src_shapes.size(), std::min(_weight_shapes.size(), _infos.size()));
+ return std::min(_src_shapes.size(), std::min(_weight_shapes.size(), std::min(_infos.size(), _dilations.size())));
}
- void add_config(TensorShape src, Size2D weights, PadStrideInfo info)
+ void add_config(TensorShape src, Size2D weights, PadStrideInfo info, Size2D dilation = Size2D(1U, 1U))
{
_src_shapes.emplace_back(std::move(src));
_weight_shapes.emplace_back(std::move(weights));
_infos.emplace_back(std::move(info));
+ _dilations.emplace_back(std::move(dilation));
}
protected:
@@ -105,6 +111,7 @@ private:
std::vector<TensorShape> _src_shapes{};
std::vector<Size2D> _weight_shapes{};
std::vector<PadStrideInfo> _infos{};
+ std::vector<Size2D> _dilations{};
};
/** Dataset containing small, generic depthwise convolution shapes. */