diff options
Diffstat (limited to 'src/backends/gpuFsa/test/GpuFsaLayerSupportTests.cpp')
-rw-r--r-- | src/backends/gpuFsa/test/GpuFsaLayerSupportTests.cpp | 30 |
1 files changed, 30 insertions, 0 deletions
diff --git a/src/backends/gpuFsa/test/GpuFsaLayerSupportTests.cpp b/src/backends/gpuFsa/test/GpuFsaLayerSupportTests.cpp index 9d4b3b9367..fee0d07820 100644 --- a/src/backends/gpuFsa/test/GpuFsaLayerSupportTests.cpp +++ b/src/backends/gpuFsa/test/GpuFsaLayerSupportTests.cpp @@ -101,4 +101,34 @@ TEST_CASE("IsLayerSupportedGpuFsaElementWiseBinarySub") CHECK(supported); } +TEST_CASE("IsLayerSupportedGpuFsaPooling2d") +{ + TensorInfo inputInfo({ 1, 3, 4, 1 }, DataType::Float32); + TensorInfo outputInfo({ 1, 2, 2, 1 }, DataType::Float32); + + Pooling2dDescriptor desc{}; + desc.m_PoolType = PoolingAlgorithm::Max; + desc.m_PadLeft = 0; + desc.m_PadRight = 0; + desc.m_PadTop = 0; + desc.m_PadBottom = 0; + desc.m_PoolWidth = 2; + desc.m_PoolHeight = 2; + desc.m_StrideX = 1; + desc.m_StrideY = 1; + desc.m_OutputShapeRounding = OutputShapeRounding::Floor; + desc.m_PaddingMethod = PaddingMethod::Exclude; + desc.m_DataLayout = DataLayout::NHWC; + + GpuFsaLayerSupport supportChecker; + std::string reasonIfNotSupported; + auto supported = supportChecker.IsLayerSupported(LayerType::Pooling2d, + {inputInfo, outputInfo}, + desc, + EmptyOptional(), + EmptyOptional(), + reasonIfNotSupported); + CHECK(supported); +} + }
\ No newline at end of file |