From a1a28e0386d313b015519746e0f15e7bbbdf5ff9 Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Wed, 10 Nov 2021 19:43:51 +0000 Subject: IVGCVSW-6569 Fix SingleAxisPriorityList * Fix overlap detection with strategy validator Signed-off-by: Finn Williams Change-Id: If9d9d9586864cef7d109aad24bdb0f682fefb1bd --- .../strategies/SingleAxisPriorityList.cpp | 8 ++-- .../strategies/StrategyValidator.cpp | 49 +++++++++------------- .../test/ValidatorStrategyTests.cpp | 16 +++---- 3 files changed, 31 insertions(+), 42 deletions(-) diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/strategies/SingleAxisPriorityList.cpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/strategies/SingleAxisPriorityList.cpp index 738b7137a7..002cd80bb0 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/strategies/SingleAxisPriorityList.cpp +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/strategies/SingleAxisPriorityList.cpp @@ -8,7 +8,7 @@ #include #include - +#include namespace armnn { @@ -155,10 +155,8 @@ void SingleAxisPriorityList::PlaceBlocks(const std::list& priorityLis // The indexes don't match we need at least two words // Zero the bits to the right of curBlock->m_EndOfLife - remainder = (curBlock->m_EndOfLife - lastWordIndex * wordSize); - - size_t lastWord = (1ul << remainder) - 1; - lastWord = lastWord << (wordSize - remainder); + remainder = (curBlock->m_EndOfLife + 1 - lastWordIndex * wordSize); + size_t lastWord = std::numeric_limits::max() << (wordSize - remainder); if(firstWordIndex + 1 == lastWordIndex) { diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/strategies/StrategyValidator.cpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/strategies/StrategyValidator.cpp index 48cdfb040c..173ced4cac 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/strategies/StrategyValidator.cpp +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/strategies/StrategyValidator.cpp @@ -68,45 +68,28 @@ std::vector StrategyValidator::Optimize(std::vector& memBlocks { for (unsigned int i = 0; i < bin.m_MemBlocks.size(); ++i) { - auto assignedBlock = bin.m_MemBlocks[i]; - auto xStart = assignedBlock.m_Offset; - auto xEnd = assignedBlock.m_Offset + assignedBlock.m_MemSize; + auto Block1 = bin.m_MemBlocks[i]; + auto B1Left = Block1.m_Offset; + auto B1Right = Block1.m_Offset + Block1.m_MemSize; - auto yStart = assignedBlock.m_StartOfLife; - auto yEnd = assignedBlock.m_EndOfLife; - auto assignedIndex = assignedBlock.m_Index; + auto B1Top = Block1.m_StartOfLife; + auto B1Bottom = Block1.m_EndOfLife; // Only compare with blocks after the current one as previous have already been checked for (unsigned int j = i + 1; j < bin.m_MemBlocks.size(); ++j) { - auto otherAssignedBlock = bin.m_MemBlocks[j]; - auto xStartAssigned = otherAssignedBlock.m_Offset; - auto xEndAssigned = otherAssignedBlock.m_Offset + otherAssignedBlock.m_MemSize; - - auto yStartAssigned = otherAssignedBlock.m_StartOfLife; - auto yEndAssigned = otherAssignedBlock.m_EndOfLife; - auto otherIndex = otherAssignedBlock.m_Index; - - // If overlapping on both X and Y then invalid - // Inside left of rectangle & Inside right of rectangle - if ((((xStart >= xStartAssigned) && (xEnd <= xEndAssigned)) && - // Inside bottom of rectangle & Inside top of rectangle - ((yStart >= yStartAssigned) && (yEnd <= yEndAssigned))) && - // Cant overlap with itself - (assignedIndex != otherIndex)) - { - // Condition #3: two Memblocks overlap on both the X and Y axis - throw MemoryValidationException("Condition #3: two Memblocks overlap on both the X and Y axis"); - } + auto Block2 = bin.m_MemBlocks[j]; + auto B2Left = Block2.m_Offset; + auto B2Right = Block2.m_Offset + Block2.m_MemSize; + + auto B2Top = Block2.m_StartOfLife; + auto B2Bottom = Block2.m_EndOfLife; switch (m_Strategy->GetMemBlockStrategyType()) { case (MemBlockStrategyType::SingleAxisPacking): { - // Inside bottom of rectangle & Inside top of rectangle - if (((yStart >= yStartAssigned) && (yEnd <= yEndAssigned)) && - // Cant overlap with itself - (assignedIndex != otherIndex)) + if (B1Top <= B2Bottom && B1Bottom >= B2Top) { throw MemoryValidationException("Condition #3: " "invalid as two Memblocks overlap on the Y axis for SingleAxisPacking"); @@ -116,6 +99,14 @@ std::vector StrategyValidator::Optimize(std::vector& memBlocks } case (MemBlockStrategyType::MultiAxisPacking): { + // If overlapping on both X and Y then invalid + if (B1Left <= B2Right && B1Right >= B2Left && + B1Top <= B2Bottom && B1Bottom >= B2Top) + { + // Condition #3: two Memblocks overlap on both the X and Y axis + throw MemoryValidationException("Condition #3: " + "two Memblocks overlap on both the X and Y axis"); + } break; } default: diff --git a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/ValidatorStrategyTests.cpp b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/ValidatorStrategyTests.cpp index bc04105f4b..2438019878 100644 --- a/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/ValidatorStrategyTests.cpp +++ b/src/backends/backendsCommon/memoryOptimizerStrategyLibrary/test/ValidatorStrategyTests.cpp @@ -59,10 +59,10 @@ TEST_CASE("MemoryOptimizerStrategyValidatorTestOverlapX") { // create a few memory blocks MemBlock memBlock0(0, 5, 20, 0, 0); - MemBlock memBlock1(5, 10, 10, 0, 1); - MemBlock memBlock2(10, 15, 15, 0, 2); - MemBlock memBlock3(15, 20, 20, 0, 3); - MemBlock memBlock4(20, 25, 5, 0, 4); + MemBlock memBlock1(6, 10, 10, 0, 1); + MemBlock memBlock2(11, 15, 15, 0, 2); + MemBlock memBlock3(16, 20, 20, 0, 3); + MemBlock memBlock4(21, 25, 5, 0, 4); std::vector memBlocks; memBlocks.reserve(5); @@ -127,10 +127,10 @@ TEST_CASE("MemoryOptimizerStrategyValidatorTestOverlapY") { // create a few memory blocks MemBlock memBlock0(0, 2, 20, 0, 0); - MemBlock memBlock1(0, 3, 10, 20, 1); - MemBlock memBlock2(0, 5, 15, 30, 2); - MemBlock memBlock3(0, 6, 20, 50, 3); - MemBlock memBlock4(0, 8, 5, 70, 4); + MemBlock memBlock1(0, 3, 10, 21, 1); + MemBlock memBlock2(0, 5, 15, 37, 2); + MemBlock memBlock3(0, 6, 20, 58, 3); + MemBlock memBlock4(0, 8, 5, 79, 4); std::vector memBlocks; memBlocks.reserve(5); -- cgit v1.2.1