aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/test/CreateWorkload.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/test/CreateWorkload.hpp')
-rw-r--r--src/armnn/test/CreateWorkload.hpp11
1 files changed, 9 insertions, 2 deletions
diff --git a/src/armnn/test/CreateWorkload.hpp b/src/armnn/test/CreateWorkload.hpp
index f3cf544fa3..51820a425f 100644
--- a/src/armnn/test/CreateWorkload.hpp
+++ b/src/armnn/test/CreateWorkload.hpp
@@ -517,9 +517,16 @@ std::unique_ptr<NormalizationWorkload> CreateNormalizationWorkloadTest(armnn::IW
Layer* const input = graph.AddLayer<InputLayer>(0, "input");
Layer* const output = graph.AddLayer<OutputLayer>(0, "output");
+ TensorShape inputShape = (dataLayout == DataLayout::NCHW) ?
+ TensorShape{ 3, 5, 5, 1 } : TensorShape{ 3, 1, 5, 5 };
+ TensorShape outputShape = (dataLayout == DataLayout::NCHW) ?
+ TensorShape{ 3, 5, 5, 1 } : TensorShape{ 3, 1, 5, 5 };
+
// Connects up.
- Connect(input, layer, TensorInfo({3, 5, 5, 1}, DataType));
- Connect(layer, output, TensorInfo({3, 5, 5, 1}, DataType));
+ armnn::TensorInfo inputTensorInfo(inputShape, DataType);
+ armnn::TensorInfo outputTensorInfo(outputShape, DataType);
+ Connect(input, layer, inputTensorInfo);
+ Connect(layer, output, outputTensorInfo);
CreateTensorHandles(graph, factory);
// Makes the workload and checks it.