aboutsummaryrefslogtreecommitdiff
path: root/src/armnnDeserializer/Deserializer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnDeserializer/Deserializer.cpp')
-rw-r--r--src/armnnDeserializer/Deserializer.cpp39
1 files changed, 35 insertions, 4 deletions
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 0bfbe7ea64..f5fc0473f5 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -132,9 +132,24 @@ void CheckConstTensorPtr(Deserializer::ConstTensorRawPtr rawPtr,
}
}
+void CheckConstTensorSize(const unsigned int constTensorSize,
+ const unsigned int tensorSize,
+ const CheckLocation& location)
+{
+ if (constTensorSize != tensorSize)
+ {
+ throw ParseException(boost::str(boost::format("%1% wrong number of components supplied to tensor. at:%2%") %
+ location.m_Function %
+ location.FileLine()));
+ }
+}
+
#define CHECK_TENSOR_PTR(TENSOR_PTR) \
CheckTensorPtr(TENSOR_PTR, CHECK_LOCATION())
+#define CHECK_CONST_TENSOR_SIZE(CONST_TENSOR_SIZE, TENSOR_SIZE) \
+ CheckConstTensorSize(CONST_TENSOR_SIZE, TENSOR_SIZE, CHECK_LOCATION())
+
#define CHECK_CONST_TENSOR_PTR(TENSOR_PTR) \
CheckConstTensorPtr(TENSOR_PTR, CHECK_LOCATION())
@@ -331,13 +346,29 @@ armnn::ConstTensor ToConstTensor(Deserializer::ConstTensorRawPtr constTensorPtr)
switch (constTensorPtr->data_type())
{
case ConstTensorData_ByteData:
- return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_ByteData()->data()->data());
+ {
+ auto byteData = constTensorPtr->data_as_ByteData()->data();
+ CHECK_CONST_TENSOR_SIZE(byteData->size(), tensorInfo.GetNumElements());
+ return armnn::ConstTensor(tensorInfo, byteData->data());
+ }
case ConstTensorData_ShortData:
- return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_ShortData()->data()->data());
+ {
+ auto shortData = constTensorPtr->data_as_ShortData()->data();
+ CHECK_CONST_TENSOR_SIZE(shortData->size(), tensorInfo.GetNumElements());
+ return armnn::ConstTensor(tensorInfo, shortData->data());
+ }
case ConstTensorData_IntData:
- return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_IntData()->data()->data());
+ {
+ auto intData = constTensorPtr->data_as_IntData()->data();
+ CHECK_CONST_TENSOR_SIZE(intData->size(), tensorInfo.GetNumElements());
+ return armnn::ConstTensor(tensorInfo, intData->data());
+ }
case ConstTensorData_LongData:
- return armnn::ConstTensor(tensorInfo, constTensorPtr->data_as_LongData()->data()->data());
+ {
+ auto longData = constTensorPtr->data_as_LongData()->data();
+ CHECK_CONST_TENSOR_SIZE(longData->size(), tensorInfo.GetNumElements());
+ return armnn::ConstTensor(tensorInfo, longData->data());
+ }
default:
{
CheckLocation location = CHECK_LOCATION();