diff options
Diffstat (limited to 'src/armnn/test/TestLayerVisitor.hpp')
-rw-r--r-- | src/armnn/test/TestLayerVisitor.hpp | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/src/armnn/test/TestLayerVisitor.hpp b/src/armnn/test/TestLayerVisitor.hpp index e43227f520..eaf1667800 100644 --- a/src/armnn/test/TestLayerVisitor.hpp +++ b/src/armnn/test/TestLayerVisitor.hpp @@ -4,13 +4,14 @@ // #pragma once -#include <armnn/LayerVisitorBase.hpp> +#include <armnn/StrategyBase.hpp> #include <armnn/Descriptors.hpp> +#include <backendsCommon/TensorHandle.hpp> namespace armnn { -// Abstract base class with do nothing implementations for all layer visit methods -class TestLayerVisitor : public LayerVisitorBase<VisitorNoThrowPolicy> +// Abstract base class with do nothing implementations for all layers +class TestLayerVisitor : public StrategyBase<NoThrowStrategy> { protected: virtual ~TestLayerVisitor() {} @@ -19,7 +20,17 @@ protected: void CheckLayerPointer(const IConnectableLayer* layer); - void CheckConstTensors(const ConstTensor& expected, const ConstTensor& actual); + void CheckConstTensors(const ConstTensor& expected, + const ConstTensor& actual); + void CheckConstTensors(const ConstTensor& expected, + const ConstTensorHandle& actual); + + void CheckConstTensorPtrs(const std::string& name, + const ConstTensor* expected, + const ConstTensor* actual); + void CheckConstTensorPtrs(const std::string& name, + const ConstTensor* expected, + const std::shared_ptr<ConstTensorHandle> actual); void CheckOptionalConstTensors(const Optional<ConstTensor>& expected, const Optional<ConstTensor>& actual); |