aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSaoirse Stewart <saoirse.stewart@arm.com>2019-02-25 09:22:58 +0000
committerSaoirse Stewart Arm <saoirse.stewart@arm.com>2019-02-25 12:35:16 +0000
commitf11bab5c4e5f9da85dfef079087cf10b6fabd475 (patch)
tree6f1c212191f268d925e72b088c355b880ec922ce
parent8f6d7a71844c9b71cda633525361bf0d554f01e8 (diff)
downloadarmnn-f11bab5c4e5f9da85dfef079087cf10b6fabd475.tar.gz
IVGCVSW-2757 Add check for wrong number of components supplied to const tensor
Change-Id: Ia9bc6c73ce246712c41496a1cfe0bb6a1d2eb8e9 Signed-off-by: Saoirse Stewart <saoirse.stewart@arm.com>
-rw-r--r--src/armnnDeserializer/Deserializer.cpp39
1 files changed, 35 insertions, 4 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 0bfbe7ea64..f5fc0473f5 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -132,9 +132,24 @@ void CheckConstTensorPtr(Deserializer::ConstTensorRawPtr rawPtr,
}
}
+void CheckConstTensorSize(const unsigned int constTensorSize,
+ const unsigned int tensorSize,
+ const CheckLocation& location)
+{
+ if (constTensorSize != tensorSize)
+ {
+ throw ParseException(boost::str(boost::format("%1% wrong number of components supplied to tensor. at:%2%") %
+ location.m_Function %
+ location.FileLine()));
+ }
+}
+
#define CHECK_TENSOR_PTR(TENSOR_PTR) \
CheckTensorPtr(TENSOR_PTR, CHECK_LOCATION())
+#define CHECK_CONST_TENSOR_SIZE(CONST_TENSOR_SIZE, TENSOR_SIZE) \
+ CheckConstTensorSize(CONST_TENSOR_SIZE, TENSOR_SIZE, CHECK_LOCATION())
+
#define CHECK_CONST_TENSOR_PTR(TENSOR_PTR) \
CheckConstTensorPtr(TENSOR_PTR, CHECK_LOCATION())
@@ -331,13 +346,29 @@ armnn::ConstTensor ToConstTensor(Deserializer::ConstTensorRawPtr constTensorPtr)
switch (constTensorPtr->data_type())
{
case ConstTensorData_ByteData:
- return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_ByteData()->data()->data());
+ {
+ auto byteData = constTensorPtr->data_as_ByteData()->data();
+ CHECK_CONST_TENSOR_SIZE(byteData->size(), tensorInfo.GetNumElements());
+ return armnn::ConstTensor(tensorInfo, byteData->data());
+ }
case ConstTensorData_ShortData:
- return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_ShortData()->data()->data());
+ {
+ auto shortData = constTensorPtr->data_as_ShortData()->data();
+ CHECK_CONST_TENSOR_SIZE(shortData->size(), tensorInfo.GetNumElements());
+ return armnn::ConstTensor(tensorInfo, shortData->data());
+ }
case ConstTensorData_IntData:
- return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_IntData()->data()->data());
+ {
+ auto intData = constTensorPtr->data_as_IntData()->data();
+ CHECK_CONST_TENSOR_SIZE(intData->size(), tensorInfo.GetNumElements());
+ return armnn::ConstTensor(tensorInfo, intData->data());
+ }
case ConstTensorData_LongData:
- return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_LongData()->data()->data());
+ {
+ auto longData = constTensorPtr->data_as_LongData()->data();
+ CHECK_CONST_TENSOR_SIZE(longData->size(), tensorInfo.GetNumElements());
+ return armnn::ConstTensor(tensorInfo, longData->data());
+ }
default:
{
CheckLocation location = CHECK_LOCATION();