diff options
Diffstat (limited to 'test/Dilation.hpp')
-rw-r--r-- | test/Dilation.hpp | 54 |
1 files changed, 26 insertions, 28 deletions
diff --git a/test/Dilation.hpp b/test/Dilation.hpp index a05dba4f..c8adbe81 100644 --- a/test/Dilation.hpp +++ b/test/Dilation.hpp @@ -7,7 +7,7 @@ #include "DriverTestHelpers.hpp" -#include <armnn/LayerVisitorBase.hpp> +#include <armnn/StrategyBase.hpp> #include <armnn/utility/IgnoreUnused.hpp> #include <doctest/doctest.h> @@ -32,7 +32,7 @@ struct DilationTestOptions bool m_HasDilation; }; -class DilationTestVisitor : public LayerVisitorBase<VisitorThrowingPolicy> +class DilationTestVisitor : public StrategyBase<ThrowingStrategy> { public: DilationTestVisitor() : @@ -44,32 +44,30 @@ public: m_ExpectedDilationY{expectedDilationY} {} - void VisitConvolution2dLayer(const IConnectableLayer *layer, - const Convolution2dDescriptor& descriptor, - const ConstTensor& weights, - const Optional<ConstTensor>& biases, - const char *name = nullptr) override + void ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id = 0) override { - IgnoreUnused(layer); - IgnoreUnused(weights); - IgnoreUnused(biases); - IgnoreUnused(name); - - CheckDilationParams(descriptor); - } - - void VisitDepthwiseConvolution2dLayer(const IConnectableLayer *layer, - const DepthwiseConvolution2dDescriptor& descriptor, - const ConstTensor& weights, - const Optional<ConstTensor>& biases, - const char *name = nullptr) override - { - IgnoreUnused(layer); - IgnoreUnused(weights); - IgnoreUnused(biases); - IgnoreUnused(name); - - CheckDilationParams(descriptor); + armnn::IgnoreUnused(layer, constants, id, name); + switch (layer->GetType()) + { + case armnn::LayerType::Convolution2d: + { + CheckDilationParams(static_cast<const armnn::Convolution2dDescriptor&>(descriptor)); + break; + } + case armnn::LayerType::DepthwiseConvolution2d: + { + CheckDilationParams(static_cast<const armnn::DepthwiseConvolution2dDescriptor&>(descriptor)); + break; + } + default: + { + m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType())); + } + } } private: @@ -170,5 +168,5 @@ void DilationTestImpl(const DilationTestOptions& options) // check if dilation params are as expected DilationTestVisitor visitor = options.m_HasDilation ? DilationTestVisitor(2, 2) : DilationTestVisitor(); - data.m_Network->Accept(visitor); + data.m_Network->ExecuteStrategy(visitor); } |