aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2021-09-10 15:27:19 +0100
committerJim Flynn <jim.flynn@arm.com>2021-09-15 17:12:49 +0000
commit7ba84d6881685d6ebfedc597a9af98b16fa42d51 (patch)
treeac4fc72d3c761076531d7970a9c525b1ef1f58c6
parent32b787054793718925226e4ecf406ff5a28b66ab (diff)
downloadarmnn-7ba84d6881685d6ebfedc597a9af98b16fa42d51.tar.gz
GitHub #577 slice layer does not handle a size of -1
* Added support for size of -1 A size of -1 is treated as size = dimension - begin Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I4e381a3794852ec45be029028e2d29bc87791635
-rw-r--r--src/armnnTfLiteParser/TfLiteParser.cpp29
-rw-r--r--src/armnnTfLiteParser/test/Slice.cpp3
2 files changed, 28 insertions, 4 deletions
diff --git a/src/armnnTfLiteParser/TfLiteParser.cpp b/src/armnnTfLiteParser/TfLiteParser.cpp
index 8c85d3099a..0f0e67c539 100644
--- a/src/armnnTfLiteParser/TfLiteParser.cpp
+++ b/src/armnnTfLiteParser/TfLiteParser.cpp
@@ -1609,13 +1609,38 @@ void TfLiteParserImpl::ParseSlice(size_t subgraphIndex, size_t operatorIndex)
armnn::TensorInfo sizeTensorInfo = ToTensorInfo(inputs[2]);
BufferRawPtr sizeBufferPtr = GetBuffer(m_Model, inputs[2]->buffer);
+ std::vector<int> signedSize(sizeTensorInfo.GetNumElements());
+ ::memcpy(signedSize.data(), sizeBufferPtr->data.data(), sizeTensorInfo.GetNumBytes());
std::vector<unsigned int> size(sizeTensorInfo.GetNumElements());
- ::memcpy(size.data(), sizeBufferPtr->data.data(), sizeTensorInfo.GetNumBytes());
+ TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
+
+ for (unsigned int i = 0; i < signedSize.size(); ++i)
+ {
+ int signedValue = signedSize[i];
+
+ if (signedValue < -1 || signedValue > static_cast<int>(inputTensorInfo.GetShape()[i] - begin[i]))
+ {
+ throw ParseException(fmt::format("Invalid value for size {} size must be in range "
+ "[-1, inputDimSize - begin] [-1, {}] inclusive {}",
+ signedValue,
+ inputTensorInfo.GetShape()[i] - begin[i],
+ CHECK_LOCATION().AsString()));
+ }
+
+ if (signedValue == -1)
+ {
+ size[i] = inputTensorInfo.GetShape()[i] - begin[i];
+ }
+ else
+ {
+ size[i] = static_cast<unsigned int>(signedValue);
+ }
+ }
+
desc = SliceDescriptor(begin, size);
auto layerName = fmt::format("Slice:{}:{}", subgraphIndex, operatorIndex);
- TensorInfo inputTensorInfo = ToTensorInfo(inputs[0]);
TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
diff --git a/src/armnnTfLiteParser/test/Slice.cpp b/src/armnnTfLiteParser/test/Slice.cpp
index 2a28c6ef09..a2a791feef 100644
--- a/src/armnnTfLiteParser/test/Slice.cpp
+++ b/src/armnnTfLiteParser/test/Slice.cpp
@@ -176,7 +176,7 @@ struct DynamicSliceFixtureD213 : SliceFixture
DynamicSliceFixtureD213() : SliceFixture("[ 3, 2, 3 ]",
"[ ]",
"[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]",
- "[ 2, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0 ]") {}
+ "[ 255, 255, 255, 255, 1, 0, 0, 0, 255, 255, 255, 255 ]") {}
};
TEST_CASE_FIXTURE(DynamicSliceFixtureD213, "DynamicSliceD213")
@@ -187,5 +187,4 @@ TEST_CASE_FIXTURE(DynamicSliceFixtureD213, "DynamicSliceD213")
{{"outputTensor", { 3, 3, 3, 5, 5, 5 }}},
true);
}
-
} \ No newline at end of file