diff options
Diffstat (limited to 'src/armnn/test/ConstTensorLayerVisitor.hpp')
-rw-r--r-- | src/armnn/test/ConstTensorLayerVisitor.hpp | 117 |
1 files changed, 22 insertions, 95 deletions
diff --git a/src/armnn/test/ConstTensorLayerVisitor.hpp b/src/armnn/test/ConstTensorLayerVisitor.hpp index 3b0f723542..513a471465 100644 --- a/src/armnn/test/ConstTensorLayerVisitor.hpp +++ b/src/armnn/test/ConstTensorLayerVisitor.hpp @@ -16,58 +16,34 @@ class TestConvolution2dLayerVisitor : public TestLayerVisitor public: explicit TestConvolution2dLayerVisitor(const Convolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, + const Optional<ConstTensor>& biases, const char* name = nullptr) : TestLayerVisitor(name), m_Descriptor(convolution2dDescriptor), - m_Weights(weights) {}; + m_Weights(weights), + m_Biases(biases) {}; virtual ~TestConvolution2dLayerVisitor() {}; void VisitConvolution2dLayer(const IConnectableLayer* layer, const Convolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, + const Optional<ConstTensor>& biases, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); CheckDescriptor(convolution2dDescriptor); CheckConstTensors(m_Weights, weights); + CheckOptionalConstTensors(m_Biases, biases); } protected: void CheckDescriptor(const Convolution2dDescriptor& convolution2dDescriptor); private: - armnn::Convolution2dDescriptor m_Descriptor; - armnn::ConstTensor m_Weights; -}; - -class TestConvolution2dWithBiasLayerVisitor : public TestConvolution2dLayerVisitor -{ -public: - explicit TestConvolution2dWithBiasLayerVisitor(const Convolution2dDescriptor& convolution2dDescriptor, - const ConstTensor& weights, - const ConstTensor& biases, - const char* name = nullptr) : - TestConvolution2dLayerVisitor( - convolution2dDescriptor, weights, name), - m_Biases(biases) {}; - - // needed to suppress crappy error message about base class function i.e. version - // without the biases argument being hidden - using TestConvolution2dLayerVisitor::VisitConvolution2dLayer; - - void VisitConvolution2dLayer(const IConnectableLayer* layer, - const Convolution2dDescriptor& convolution2dDescriptor, - const ConstTensor& weights, - const ConstTensor& biases, - const char* name = nullptr) override - { - TestConvolution2dLayerVisitor::VisitConvolution2dLayer(layer, convolution2dDescriptor, weights, name); - CheckConstTensors(m_Biases, biases); - } - -private: - armnn::ConstTensor m_Biases; + Convolution2dDescriptor m_Descriptor; + ConstTensor m_Weights; + Optional<ConstTensor> m_Biases; }; class TestDepthwiseConvolution2dLayerVisitor : public TestLayerVisitor @@ -75,60 +51,34 @@ class TestDepthwiseConvolution2dLayerVisitor : public TestLayerVisitor public: explicit TestDepthwiseConvolution2dLayerVisitor(const DepthwiseConvolution2dDescriptor& descriptor, const ConstTensor& weights, + const Optional<ConstTensor>& biases, const char* name = nullptr) : TestLayerVisitor(name), m_Descriptor(descriptor), - m_Weights(weights) {}; + m_Weights(weights), + m_Biases(biases) {}; virtual ~TestDepthwiseConvolution2dLayerVisitor() {}; void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer, const DepthwiseConvolution2dDescriptor& convolution2dDescriptor, const ConstTensor& weights, + const Optional<ConstTensor>& biases, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); CheckDescriptor(convolution2dDescriptor); CheckConstTensors(m_Weights, weights); + CheckOptionalConstTensors(m_Biases, biases); } protected: void CheckDescriptor(const DepthwiseConvolution2dDescriptor& convolution2dDescriptor); private: - armnn::DepthwiseConvolution2dDescriptor m_Descriptor; - armnn::ConstTensor m_Weights; -}; - -class TestDepthwiseConvolution2dWithBiasLayerVisitor : public TestDepthwiseConvolution2dLayerVisitor -{ -public: - explicit TestDepthwiseConvolution2dWithBiasLayerVisitor(const DepthwiseConvolution2dDescriptor& descriptor, - const ConstTensor& weights, - const ConstTensor& biases, - const char* name = nullptr) : - TestDepthwiseConvolution2dLayerVisitor(descriptor, weights, name), - m_Biases(biases) {}; - - ~TestDepthwiseConvolution2dWithBiasLayerVisitor() {}; - - // needed to suppress crappy error message about base class function i.e. version - // without the biases argument being hidden - using TestDepthwiseConvolution2dLayerVisitor::VisitDepthwiseConvolution2dLayer; - - void VisitDepthwiseConvolution2dLayer(const IConnectableLayer* layer, - const DepthwiseConvolution2dDescriptor& convolution2dDescriptor, - const ConstTensor& weights, - const ConstTensor& biases, - const char* name = nullptr) override - { - TestDepthwiseConvolution2dLayerVisitor::VisitDepthwiseConvolution2dLayer( - layer, convolution2dDescriptor, weights, name); - CheckConstTensors(m_Biases, biases); - } - -private: - armnn::ConstTensor m_Biases; + DepthwiseConvolution2dDescriptor m_Descriptor; + ConstTensor m_Weights; + Optional<ConstTensor> m_Biases; }; class TestFullyConnectedLayerVistor : public TestLayerVisitor @@ -136,21 +86,25 @@ class TestFullyConnectedLayerVistor : public TestLayerVisitor public: explicit TestFullyConnectedLayerVistor(const FullyConnectedDescriptor& descriptor, const ConstTensor& weights, + const Optional<ConstTensor> biases, const char* name = nullptr) : TestLayerVisitor(name), m_Descriptor(descriptor), - m_Weights(weights) {}; + m_Weights(weights), + m_Biases(biases) {}; virtual ~TestFullyConnectedLayerVistor() {}; void VisitFullyConnectedLayer(const IConnectableLayer* layer, const FullyConnectedDescriptor& fullyConnectedDescriptor, const ConstTensor& weights, + const Optional<ConstTensor>& biases, const char* name = nullptr) override { CheckLayerPointer(layer); CheckLayerName(name); CheckDescriptor(fullyConnectedDescriptor); CheckConstTensors(m_Weights, weights); + CheckOptionalConstTensors(m_Biases, biases); } protected: @@ -158,34 +112,7 @@ protected: private: FullyConnectedDescriptor m_Descriptor; ConstTensor m_Weights; -}; - -class TestFullyConnectedLayerWithBiasesVisitor : public TestFullyConnectedLayerVistor -{ -public: - explicit TestFullyConnectedLayerWithBiasesVisitor(const FullyConnectedDescriptor& descriptor, - const ConstTensor& weights, - const ConstTensor& biases, - const char* name = nullptr) : - TestFullyConnectedLayerVistor(descriptor, weights, name), - m_Biases(biases) {}; - - // needed to suppress crappy error message about base class function i.e. version - // without the biases argument being hidden - using TestFullyConnectedLayerVistor::VisitFullyConnectedLayer; - - void VisitFullyConnectedLayer(const IConnectableLayer* layer, - const FullyConnectedDescriptor& fullyConnectedDescriptor, - const ConstTensor& weights, - const ConstTensor& biases, - const char* name = nullptr) override - { - TestFullyConnectedLayerVistor::VisitFullyConnectedLayer(layer, fullyConnectedDescriptor, weights, name); - CheckConstTensors(m_Biases, biases); - } - -private: - ConstTensor m_Biases; + Optional<ConstTensor> m_Biases; }; class TestBatchNormalizationLayerVisitor : public TestLayerVisitor |