aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/TestLayerVisitor.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/TestLayerVisitor.hpp')
-rw-r--r--src/armnn/test/TestLayerVisitor.hpp19
1 files changed, 15 insertions, 4 deletions
diff --git a/src/armnn/test/TestLayerVisitor.hpp b/src/armnn/test/TestLayerVisitor.hpp
index e43227f520..eaf1667800 100644
--- a/src/armnn/test/TestLayerVisitor.hpp
+++ b/src/armnn/test/TestLayerVisitor.hpp
@@ -4,13 +4,14 @@
//
#pragma once
-#include <armnn/LayerVisitorBase.hpp>
+#include <armnn/StrategyBase.hpp>
#include <armnn/Descriptors.hpp>
+#include <backendsCommon/TensorHandle.hpp>
namespace armnn
{
-// Abstract base class with do nothing implementations for all layer visit methods
-class TestLayerVisitor : public LayerVisitorBase<VisitorNoThrowPolicy>
+// Abstract base class with do nothing implementations for all layers
+class TestLayerVisitor : public StrategyBase<NoThrowStrategy>
{
protected:
virtual ~TestLayerVisitor() {}
@@ -19,7 +20,17 @@ protected:
void CheckLayerPointer(const IConnectableLayer* layer);
- void CheckConstTensors(const ConstTensor& expected, const ConstTensor& actual);
+ void CheckConstTensors(const ConstTensor& expected,
+ const ConstTensor& actual);
+ void CheckConstTensors(const ConstTensor& expected,
+ const ConstTensorHandle& actual);
+
+ void CheckConstTensorPtrs(const std::string& name,
+ const ConstTensor* expected,
+ const ConstTensor* actual);
+ void CheckConstTensorPtrs(const std::string& name,
+ const ConstTensor* expected,
+ const std::shared_ptr<ConstTensorHandle> actual);
void CheckOptionalConstTensors(const Optional<ConstTensor>& expected, const Optional<ConstTensor>& actual);