diff options
Diffstat (limited to 'test/Dilation.hpp')
-rw-r--r-- | test/Dilation.hpp | 71 |
1 files changed, 32 insertions, 39 deletions
diff --git a/test/Dilation.hpp b/test/Dilation.hpp index d0189c96..dbd24933 100644 --- a/test/Dilation.hpp +++ b/test/Dilation.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // @@ -7,17 +7,12 @@ #include "DriverTestHelpers.hpp" -#include <armnn/LayerVisitorBase.hpp> +#include <armnn/StrategyBase.hpp> #include <armnn/utility/IgnoreUnused.hpp> -#include <boost/test/unit_test.hpp> - #include <numeric> -BOOST_AUTO_TEST_SUITE(DilationTests) - using namespace armnn; -using namespace boost; using namespace driverTestHelpers; struct DilationTestOptions @@ -35,7 +30,7 @@ struct DilationTestOptions bool m_HasDilation; }; -class DilationTestVisitor : public LayerVisitorBase<VisitorThrowingPolicy> +class DilationTestVisitor : public StrategyBase<ThrowingStrategy> { public: DilationTestVisitor() : @@ -47,32 +42,32 @@ public: m_ExpectedDilationY{expectedDilationY} {} - void VisitConvolution2dLayer(const IConnectableLayer *layer, - const Convolution2dDescriptor& descriptor, - const ConstTensor& weights, - const Optional<ConstTensor>& biases, - const char *name = nullptr) override - { - IgnoreUnused(layer); - IgnoreUnused(weights); - IgnoreUnused(biases); - IgnoreUnused(name); - - CheckDilationParams(descriptor); - } - - void VisitDepthwiseConvolution2dLayer(const IConnectableLayer *layer, - const DepthwiseConvolution2dDescriptor& descriptor, - const ConstTensor& weights, - const Optional<ConstTensor>& biases, - const char *name = nullptr) override + void ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id = 0) override { - IgnoreUnused(layer); - IgnoreUnused(weights); - IgnoreUnused(biases); - IgnoreUnused(name); - - CheckDilationParams(descriptor); + armnn::IgnoreUnused(layer, constants, id, name); + switch (layer->GetType()) + { + case armnn::LayerType::Constant: + break; + case armnn::LayerType::Convolution2d: + { + CheckDilationParams(static_cast<const armnn::Convolution2dDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::DepthwiseConvolution2d: + { + CheckDilationParams(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor)); + break; + } + default: + { + m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); + } + } } private: @@ -82,8 +77,8 @@ private: template<typename ConvolutionDescriptor> void CheckDilationParams(const ConvolutionDescriptor& descriptor) { - BOOST_CHECK_EQUAL(descriptor.m_DilationX, m_ExpectedDilationX); - BOOST_CHECK_EQUAL(descriptor.m_DilationY, m_ExpectedDilationY); + CHECK_EQ(descriptor.m_DilationX, m_ExpectedDilationX); + CHECK_EQ(descriptor.m_DilationY, m_ExpectedDilationY); } }; @@ -169,11 +164,9 @@ void DilationTestImpl(const DilationTestOptions& options) data.m_OutputSlotForOperand = std::vector<IOutputSlot*>(model.operands.size(), nullptr); bool ok = HalPolicy::ConvertOperation(model.operations[0], model, data); - BOOST_CHECK(ok); + DOCTEST_CHECK(ok); // check if dilation params are as expected DilationTestVisitor visitor = options.m_HasDilation ? DilationTestVisitor(2, 2) : DilationTestVisitor(); - data.m_Network->Accept(visitor); + data.m_Network->ExecuteStrategy(visitor); } - -BOOST_AUTO_TEST_SUITE_END() |