aboutsummaryrefslogtreecommitdiff
path: root/src/backends/test/LayerTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/test/LayerTests.cpp')
-rwxr-xr-xsrc/backends/test/LayerTests.cpp152
1 files changed, 148 insertions, 4 deletions
diff --git a/src/backends/test/LayerTests.cpp b/src/backends/test/LayerTests.cpp
index c28a1d46ad..1faacacb5c 100755
--- a/src/backends/test/LayerTests.cpp
+++ b/src/backends/test/LayerTests.cpp
@@ -5338,14 +5338,158 @@ LayerTestResult<uint8_t, 4> ResizeBilinearMagUint8Test(armnn::IWorkloadFactory&
LayerTestResult<float, 4> BatchNormTest(armnn::IWorkloadFactory& workloadFactory)
{
- auto ret = BatchNormTestImpl<float>(workloadFactory, 0.f, 0);
- return ret;
+ // BatchSize: 1
+ // Channels: 2
+ // Height: 3
+ // Width: 2
+
+ const armnn::TensorShape inputOutputShape{ 1, 2, 3, 2 };
+ std::vector<float> inputValues
+ {
+ // Batch 0, Channel 0, Height (3) x Width (2)
+ 1.f, 4.f,
+ 4.f, 2.f,
+ 1.f, 6.f,
+
+ // Batch 0, Channel 1, Height (3) x Width (2)
+ 1.f, 1.f,
+ 4.f, 1.f,
+ -2.f, 4.f
+ };
+ std::vector<float> expectedOutputValues
+ {
+ // Batch 0, Channel 0, Height (3) x Width (2)
+ 1.f, 4.f,
+ 4.f, 2.f,
+ 1.f, 6.f,
+
+ // Batch 0, Channel 1, Height (3) x Width (2)
+ 3.f, 3.f,
+ 4.f, 3.f,
+ 2.f, 4.f
+ };
+
+ return BatchNormTestImpl<float>(workloadFactory, inputOutputShape, inputValues, expectedOutputValues,
+ 0.f, 0, armnn::DataLayout::NCHW);
+}
+
+LayerTestResult<float, 4> BatchNormNhwcTest(armnn::IWorkloadFactory& workloadFactory)
+{
+ // BatchSize: 1
+ // Height: 3
+ // Width: 2
+ // Channels: 2
+
+ const armnn::TensorShape inputOutputShape{ 1, 3, 2, 2 };
+ std::vector<float> inputValues
+ {
+ // Batch 0, Height 0, Width (2) x Channel (2)
+ 1.f, 1.f,
+ 4.f, 1.f,
+
+ // Batch 0, Height 1, Width (2) x Channel (2)
+ 4.f, 4.f,
+ 2.f, 1.f,
+
+ // Batch 0, Height 2, Width (2) x Channel (2)
+ 1.f, -2.f,
+ 6.f, 4.f
+ };
+ std::vector<float> expectedOutputValues
+ {
+ // Batch 0, Height 0, Width (2) x Channel (2)
+ 1.f, 3.f,
+ 4.f, 3.f,
+
+ // Batch 0, Height 1, Width (2) x Channel (2)
+ 4.f, 4.f,
+ 2.f, 3.f,
+
+ // Batch 0, Height 2, Width (2) x Channel (2)
+ 1.f, 2.f,
+ 6.f, 4.f
+ };
+
+ return BatchNormTestImpl<float>(workloadFactory, inputOutputShape, inputValues, expectedOutputValues,
+ 0.f, 0, armnn::DataLayout::NHWC);
}
LayerTestResult<uint8_t, 4> BatchNormUint8Test(armnn::IWorkloadFactory& workloadFactory)
{
- auto ret = BatchNormTestImpl<uint8_t>(workloadFactory, 1.f/20.f, 50);
- return ret;
+ // BatchSize: 1
+ // Channels: 2
+ // Height: 3
+ // Width: 2
+
+ const armnn::TensorShape inputOutputShape{ 1, 2, 3, 2 };
+ std::vector<float> inputValues
+ {
+ // Batch 0, Channel 0, Height (3) x Width (2)
+ 1.f, 4.f,
+ 4.f, 2.f,
+ 1.f, 6.f,
+
+ // Batch 0, Channel 1, Height (3) x Width (2)
+ 1.f, 1.f,
+ 4.f, 1.f,
+ -2.f, 4.f
+ };
+ std::vector<float> expectedOutputValues
+ {
+ // Batch 0, Channel 0, Height (3) x Width (2)
+ 1.f, 4.f,
+ 4.f, 2.f,
+ 1.f, 6.f,
+
+ // Batch 0, Channel 1, Height (3) x Width (2)
+ 3.f, 3.f,
+ 4.f, 3.f,
+ 2.f, 4.f
+ };
+
+ return BatchNormTestImpl<uint8_t>(workloadFactory, inputOutputShape, inputValues, expectedOutputValues,
+ 1.f/20.f, 50, armnn::DataLayout::NCHW);
+}
+
+LayerTestResult<uint8_t, 4> BatchNormUint8NhwcTest(armnn::IWorkloadFactory& workloadFactory)
+{
+ // BatchSize: 1
+ // Height: 3
+ // Width: 2
+ // Channels: 2
+
+ const armnn::TensorShape inputOutputShape{ 1, 3, 2, 2 };
+ std::vector<float> inputValues
+ {
+ // Batch 0, Height 0, Width (2) x Channel (2)
+ 1.f, 1.f,
+ 4.f, 1.f,
+
+ // Batch 0, Height 1, Width (2) x Channel (2)
+ 4.f, 4.f,
+ 2.f, 1.f,
+
+ // Batch 0, Height 2, Width (2) x Channel (2)
+ 1.f, -2.f,
+ 6.f, 4.f
+ };
+ std::vector<float> expectedOutputValues
+ {
+ // Batch 0, Height 0, Width (2) x Channel (2)
+ 1.f, 3.f,
+ 4.f, 3.f,
+
+ // Batch 0, Height 1, Width (2) x Channel (2)
+ 4.f, 4.f,
+ 2.f, 3.f,
+
+ // Batch 0, Height 2, Width (2) x Channel (2)
+ 1.f, 2.f,
+ 6.f, 4.f
+ };
+
+ return BatchNormTestImpl<uint8_t>(workloadFactory, inputOutputShape, inputValues, expectedOutputValues,
+ 1.f/20.f, 50, armnn::DataLayout::NHWC);
}
LayerTestResult<uint8_t, 4> ConstantUint8Test(armnn::IWorkloadFactory& workloadFactory)