aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/ConstTensorLayerVisitor.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/ConstTensorLayerVisitor.hpp')
-rw-r--r--src/armnn/test/ConstTensorLayerVisitor.hpp117
1 files changed, 22 insertions, 95 deletions
diff --git a/src/armnn/test/ConstTensorLayerVisitor.hpp b/src/armnn/test/ConstTensorLayerVisitor.hpp
index 3b0f723542..513a471465 100644
--- a/src/armnn/test/ConstTensorLayerVisitor.hpp
+++ b/src/armnn/test/ConstTensorLayerVisitor.hpp
@@ -16,58 +16,34 @@ class TestConvolution2dLayerVisitor : public TestLayerVisitor
public:
explicit TestConvolution2dLayerVisitor(const Convolution2dDescriptor& convolution2dDescriptor,
const ConstTensor& weights,
+ const Optional<ConstTensor>& biases,
const char* name = nullptr) : TestLayerVisitor(name),
m_Descriptor(convolution2dDescriptor),
- m_Weights(weights) {};
+ m_Weights(weights),
+ m_Biases(biases) {};
virtual ~TestConvolution2dLayerVisitor() {};
void VisitConvolution2dLayer(const IConnectableLayer* layer,
const Convolution2dDescriptor& convolution2dDescriptor,
const ConstTensor& weights,
+ const Optional<ConstTensor>& biases,
const char* name = nullptr) override
{
CheckLayerPointer(layer);
CheckLayerName(name);
CheckDescriptor(convolution2dDescriptor);
CheckConstTensors(m_Weights, weights);
+ CheckOptionalConstTensors(m_Biases, biases);
}
protected:
void CheckDescriptor(const Convolution2dDescriptor& convolution2dDescriptor);
private:
- armnn::Convolution2dDescriptor m_Descriptor;
- armnn::ConstTensor m_Weights;
-};
-
-class TestConvolution2dWithBiasLayerVisitor : public TestConvolution2dLayerVisitor
-{
-public:
- explicit TestConvolution2dWithBiasLayerVisitor(const Convolution2dDescriptor& convolution2dDescriptor,
- const ConstTensor& weights,
- const ConstTensor& biases,
- const char* name = nullptr) :
- TestConvolution2dLayerVisitor(
- convolution2dDescriptor, weights, name),
- m_Biases(biases) {};
-
- // needed to suppress crappy error message about base class function i.e. version
- // without the biases argument being hidden
- using TestConvolution2dLayerVisitor::VisitConvolution2dLayer;
-
- void VisitConvolution2dLayer(const IConnectableLayer* layer,
- const Convolution2dDescriptor& convolution2dDescriptor,
- const ConstTensor& weights,
- const ConstTensor& biases,
- const char* name = nullptr) override
- {
- TestConvolution2dLayerVisitor::VisitConvolution2dLayer(layer, convolution2dDescriptor, weights, name);
- CheckConstTensors(m_Biases, biases);
- }
-
-private:
- armnn::ConstTensor m_Biases;
+ Convolution2dDescriptor m_Descriptor;
+ ConstTensor m_Weights;
+ Optional<ConstTensor> m_Biases;
};
class TestDepthwiseConvolution2dLayerVisitor : public TestLayerVisitor
@@ -75,60 +51,34 @@ class TestDepthwiseConvolution2dLayerVisitor : public TestLayerVisitor
public:
explicit TestDepthwiseConvolution2dLayerVisitor(const DepthwiseConvolution2dDescriptor& descriptor,
const ConstTensor& weights,
+ const Optional<ConstTensor>& biases,
const char* name = nullptr) : TestLayerVisitor(name),
m_Descriptor(descriptor),
- m_Weights(weights) {};
+ m_Weights(weights),
+ m_Biases(biases) {};
virtual ~TestDepthwiseConvolution2dLayerVisitor() {};
void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer,
const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
const ConstTensor& weights,
+ const Optional<ConstTensor>& biases,
const char* name = nullptr) override
{
CheckLayerPointer(layer);
CheckLayerName(name);
CheckDescriptor(convolution2dDescriptor);
CheckConstTensors(m_Weights, weights);
+ CheckOptionalConstTensors(m_Biases, biases);
}
protected:
void CheckDescriptor(const DepthwiseConvolution2dDescriptor& convolution2dDescriptor);
private:
- armnn::DepthwiseConvolution2dDescriptor m_Descriptor;
- armnn::ConstTensor m_Weights;
-};
-
-class TestDepthwiseConvolution2dWithBiasLayerVisitor : public TestDepthwiseConvolution2dLayerVisitor
-{
-public:
- explicit TestDepthwiseConvolution2dWithBiasLayerVisitor(const DepthwiseConvolution2dDescriptor& descriptor,
- const ConstTensor& weights,
- const ConstTensor& biases,
- const char* name = nullptr) :
- TestDepthwiseConvolution2dLayerVisitor(descriptor, weights, name),
- m_Biases(biases) {};
-
- ~TestDepthwiseConvolution2dWithBiasLayerVisitor() {};
-
- // needed to suppress crappy error message about base class function i.e. version
- // without the biases argument being hidden
- using TestDepthwiseConvolution2dLayerVisitor::VisitDepthwiseConvolution2dLayer;
-
- void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer,
- const DepthwiseConvolution2dDescriptor& convolution2dDescriptor,
- const ConstTensor& weights,
- const ConstTensor& biases,
- const char* name = nullptr) override
- {
- TestDepthwiseConvolution2dLayerVisitor::VisitDepthwiseConvolution2dLayer(
- layer, convolution2dDescriptor, weights, name);
- CheckConstTensors(m_Biases, biases);
- }
-
-private:
- armnn::ConstTensor m_Biases;
+ DepthwiseConvolution2dDescriptor m_Descriptor;
+ ConstTensor m_Weights;
+ Optional<ConstTensor> m_Biases;
};
class TestFullyConnectedLayerVistor : public TestLayerVisitor
@@ -136,21 +86,25 @@ class TestFullyConnectedLayerVistor : public TestLayerVisitor
public:
explicit TestFullyConnectedLayerVistor(const FullyConnectedDescriptor& descriptor,
const ConstTensor& weights,
+ const Optional<ConstTensor> biases,
const char* name = nullptr) : TestLayerVisitor(name),
m_Descriptor(descriptor),
- m_Weights(weights) {};
+ m_Weights(weights),
+ m_Biases(biases) {};
virtual ~TestFullyConnectedLayerVistor() {};
void VisitFullyConnectedLayer(const IConnectableLayer* layer,
const FullyConnectedDescriptor& fullyConnectedDescriptor,
const ConstTensor& weights,
+ const Optional<ConstTensor>& biases,
const char* name = nullptr) override
{
CheckLayerPointer(layer);
CheckLayerName(name);
CheckDescriptor(fullyConnectedDescriptor);
CheckConstTensors(m_Weights, weights);
+ CheckOptionalConstTensors(m_Biases, biases);
}
protected:
@@ -158,34 +112,7 @@ protected:
private:
FullyConnectedDescriptor m_Descriptor;
ConstTensor m_Weights;
-};
-
-class TestFullyConnectedLayerWithBiasesVisitor : public TestFullyConnectedLayerVistor
-{
-public:
- explicit TestFullyConnectedLayerWithBiasesVisitor(const FullyConnectedDescriptor& descriptor,
- const ConstTensor& weights,
- const ConstTensor& biases,
- const char* name = nullptr) :
- TestFullyConnectedLayerVistor(descriptor, weights, name),
- m_Biases(biases) {};
-
- // needed to suppress crappy error message about base class function i.e. version
- // without the biases argument being hidden
- using TestFullyConnectedLayerVistor::VisitFullyConnectedLayer;
-
- void VisitFullyConnectedLayer(const IConnectableLayer* layer,
- const FullyConnectedDescriptor& fullyConnectedDescriptor,
- const ConstTensor& weights,
- const ConstTensor& biases,
- const char* name = nullptr) override
- {
- TestFullyConnectedLayerVistor::VisitFullyConnectedLayer(layer, fullyConnectedDescriptor, weights, name);
- CheckConstTensors(m_Biases, biases);
- }
-
-private:
- ConstTensor m_Biases;
+ Optional<ConstTensor> m_Biases;
};
class TestBatchNormalizationLayerVisitor : public TestLayerVisitor