aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/Conv2dTestImpl.hpp
diff options
context:
space:
mode:
authorBruno Goncalves <bruno.slackware@gmail.com>2019-04-26 21:03:24 -0300
committerMatteo Martincigh <matteo.martincigh@arm.com>2019-05-14 07:36:33 +0000
commit22972f04d6aa5c4d1269ed3be8bc5fb7b508dcda (patch)
treebe61df59d9eb215bf1396a5a03797942066c6ba6 /src/backends/backendsCommon/test/Conv2dTestImpl.hpp
parentacad04e3cb6abcdd9a3fcf4584db1cbedb52cb47 (diff)
downloadarmnn-22972f04d6aa5c4d1269ed3be8bc5fb7b508dcda.tar.gz
MLCE-101 Add dilation support for DepthWiseConv workload
Adds unit tests for dilated depthwise conv Change-Id: Iad0a1b33d07fb0ef8f9f6edf0fd0f83a5800a36d Signed-off-by: Bruno Goncalves <bruno.slackware@gmail.com> Signed-off-by: Matthew Bentham <matthew.bentham@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/test/Conv2dTestImpl.hpp')
-rwxr-xr-xsrc/backends/backendsCommon/test/Conv2dTestImpl.hpp6
1 files changed, 5 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/test/Conv2dTestImpl.hpp b/src/backends/backendsCommon/test/Conv2dTestImpl.hpp
index c2e539b20f..bb5656bd01 100755
--- a/src/backends/backendsCommon/test/Conv2dTestImpl.hpp
+++ b/src/backends/backendsCommon/test/Conv2dTestImpl.hpp
@@ -827,7 +827,9 @@ LayerTestResult<T, 4> DepthwiseConvolution2dNhwcTestImpl(
uint32_t padRight = 0,
uint32_t padBottom = 0,
uint32_t strideX = 1,
- uint32_t strideY = 1)
+ uint32_t strideY = 1,
+ uint32_t dilationX = 1,
+ uint32_t dilationY = 1)
{
unsigned int inputNum = boost::numeric_cast<unsigned int>(input.shape()[0]);
unsigned int inputChannels = boost::numeric_cast<unsigned int>(input.shape()[3]);
@@ -894,6 +896,8 @@ LayerTestResult<T, 4> DepthwiseConvolution2dNhwcTestImpl(
data.m_Parameters.m_PadTop = padTop;
data.m_Parameters.m_PadBottom = padBottom;
data.m_Parameters.m_DataLayout = armnn::DataLayout::NHWC;
+ data.m_Parameters.m_DilationX = dilationX;
+ data.m_Parameters.m_DilationY = dilationY;
armnn::WorkloadInfo info;
AddInputToWorkload(data, info, inputTensorInfo, inputHandle.get());