From f11bab5c4e5f9da85dfef079087cf10b6fabd475 Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Mon, 25 Feb 2019 09:22:58 +0000 Subject: IVGCVSW-2757 Add check for wrong number of components supplied to const tensor Change-Id: Ia9bc6c73ce246712c41496a1cfe0bb6a1d2eb8e9 Signed-off-by: Saoirse Stewart --- src/armnnDeserializer/Deserializer.cpp | 39 ++++++++++++++++++++++++++++++---- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp index 0bfbe7ea64..f5fc0473f5 100644 --- a/src/armnnDeserializer/Deserializer.cpp +++ b/src/armnnDeserializer/Deserializer.cpp @@ -132,9 +132,24 @@ void CheckConstTensorPtr(Deserializer::ConstTensorRawPtr rawPtr, } } +void CheckConstTensorSize(const unsigned int constTensorSize, + const unsigned int tensorSize, + const CheckLocation& location) +{ + if (constTensorSize != tensorSize) + { + throw ParseException(boost::str(boost::format("%1% wrong number of components supplied to tensor. at:%2%") % + location.m_Function % + location.FileLine())); + } +} + #define CHECK_TENSOR_PTR(TENSOR_PTR) \ CheckTensorPtr(TENSOR_PTR, CHECK_LOCATION()) +#define CHECK_CONST_TENSOR_SIZE(CONST_TENSOR_SIZE, TENSOR_SIZE) \ + CheckConstTensorSize(CONST_TENSOR_SIZE, TENSOR_SIZE, CHECK_LOCATION()) + #define CHECK_CONST_TENSOR_PTR(TENSOR_PTR) \ CheckConstTensorPtr(TENSOR_PTR, CHECK_LOCATION()) @@ -331,13 +346,29 @@ armnn::ConstTensor ToConstTensor(Deserializer::ConstTensorRawPtr constTensorPtr) switch (constTensorPtr->data_type()) { case ConstTensorData_ByteData: - return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_ByteData()->data()->data()); + { + auto byteData = constTensorPtr->data_as_ByteData()->data(); + CHECK_CONST_TENSOR_SIZE(byteData->size(), tensorInfo.GetNumElements()); + return armnn::ConstTensor(tensorInfo, byteData->data()); + } case ConstTensorData_ShortData: - return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_ShortData()->data()->data()); + { + auto shortData = constTensorPtr->data_as_ShortData()->data(); + CHECK_CONST_TENSOR_SIZE(shortData->size(), tensorInfo.GetNumElements()); + return armnn::ConstTensor(tensorInfo, shortData->data()); + } case ConstTensorData_IntData: - return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_IntData()->data()->data()); + { + auto intData = constTensorPtr->data_as_IntData()->data(); + CHECK_CONST_TENSOR_SIZE(intData->size(), tensorInfo.GetNumElements()); + return armnn::ConstTensor(tensorInfo, intData->data()); + } case ConstTensorData_LongData: - return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_LongData()->data()->data()); + { + auto longData = constTensorPtr->data_as_LongData()->data(); + CHECK_CONST_TENSOR_SIZE(longData->size(), tensorInfo.GetNumElements()); + return armnn::ConstTensor(tensorInfo, longData->data()); + } default: { CheckLocation location = CHECK_LOCATION(); -- cgit v1.2.1