aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/CreateWorkload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r--src/armnn/test/CreateWorkload.hpp32
1 files changed, 16 insertions, 16 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index b075744434..b07197797c 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -4,9 +4,7 @@
//
#pragma once
-#include <boost/test/unit_test.hpp>
-
-#include <boost/cast.hpp>
+#include "TestUtils.hpp"
#include <backendsCommon/WorkloadData.hpp>
#include <backendsCommon/WorkloadFactory.hpp>
@@ -17,8 +15,10 @@
#include <Network.hpp>
#include <ResolveType.hpp>
-#include <utility>
+#include <boost/test/unit_test.hpp>
+#include <boost/cast.hpp>
+#include <utility>
using namespace armnn;
@@ -40,13 +40,6 @@ std::unique_ptr<Workload> MakeAndCheckWorkload(Layer& layer, Graph& graph, const
return std::unique_ptr<Workload>(static_cast<Workload*>(workload.release()));
}
-// Connects two layers.
-void Connect(Layer* from, Layer* to, const TensorInfo& tensorInfo, unsigned int fromIndex = 0, unsigned int toIndex = 0)
-{
- from->GetOutputSlot(fromIndex).Connect(to->GetInputSlot(toIndex));
- from->GetOutputHandler(fromIndex).SetTensorInfo(tensorInfo);
-}
-
// Helper function to create tensor handlers for workloads, assuming they all use the same factory.
void CreateTensorHandles(armnn::Graph& graph, armnn::IWorkloadFactory& factory)
{
@@ -1280,23 +1273,30 @@ std::unique_ptr<ConstantWorkload> CreateConstantWorkloadTest(armnn::IWorkloadFac
return workloadConstant;
}
-template <typename PreluWorkload, armnn::DataType DataType>
+template <typename PreluWorkload>
std::unique_ptr<PreluWorkload> CreatePreluWorkloadTest(armnn::IWorkloadFactory& factory,
armnn::Graph& graph,
- const armnn::TensorShape& outputShape)
+ const armnn::TensorShape& inputShape,
+ const armnn::TensorShape& alphaShape,
+ const armnn::TensorShape& outputShape,
+ armnn::DataType dataType)
{
// Creates the PReLU layer
Layer* const layer = graph.AddLayer<PreluLayer>("prelu");
+ BOOST_CHECK(layer != nullptr);
// Creates extra layers
Layer* const input = graph.AddLayer<InputLayer> (0, "input");
Layer* const alpha = graph.AddLayer<InputLayer> (1, "alpha");
Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+ BOOST_CHECK(input != nullptr);
+ BOOST_CHECK(alpha != nullptr);
+ BOOST_CHECK(output != nullptr);
// Connects up
- armnn::TensorInfo inputTensorInfo ({ 1, 4, 1, 2 }, DataType);
- armnn::TensorInfo alphaTensorInfo ({ 5, 4, 3, 1 }, DataType);
- armnn::TensorInfo outputTensorInfo(outputShape, DataType);
+ armnn::TensorInfo inputTensorInfo (inputShape, dataType);
+ armnn::TensorInfo alphaTensorInfo (alphaShape, dataType);
+ armnn::TensorInfo outputTensorInfo(outputShape, dataType);
Connect(input, layer, inputTensorInfo, 0, 0);
Connect(alpha, layer, alphaTensorInfo, 0, 1);
Connect(layer, output, outputTensorInfo, 0, 0);