aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2021-11-09 15:43:37 +0000
committerDavid Monahan <David.Monahan@arm.com>2021-11-09 18:26:07 +0000
commitbc8cacae736c569fd3caf343e7ad7a9e5534bf27 (patch)
treedda366cdd464cc35bb3048721597be571365285a
parentee6818be7815e10be4535645f0472ae5ad116309 (diff)
downloadandroid-nn-driver-branches/android-nn-driver_21_11.tar.gz
Fixed Driver Crash in DumpTensorv21.11branches/android-nn-driver_21_11
* When handling Tensors DumpTensor was automatically trying to turn them into ConstTensors but then threw an exceptions when IsConstant returned false Signed-off-by: Mike Kelly <mike.kelly@arm.com> Change-Id: I8681bb3dd41cfe19c60fbd1cc9394c8b6cca551e
-rw-r--r--Utils.cpp35
-rw-r--r--Utils.hpp3
2 files changed, 26 insertions, 12 deletions
diff --git a/Utils.cpp b/Utils.cpp
index f910cd49..884bed00 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -335,14 +335,15 @@ std::string GetOperandSummary(const V1_3::Operand& operand)
#endif
-using DumpElementFunction = void (*)(const armnn::ConstTensor& tensor,
+template <typename TensorType>
+using DumpElementFunction = void (*)(const TensorType& tensor,
unsigned int elementIndex,
std::ofstream& fileStream);
namespace
{
-template <typename ElementType, typename PrintableType = ElementType>
-void DumpTensorElement(const armnn::ConstTensor& tensor, unsigned int elementIndex, std::ofstream& fileStream)
+template <typename TensorType, typename ElementType, typename PrintableType = ElementType>
+void DumpTensorElement(const TensorType& tensor, unsigned int elementIndex, std::ofstream& fileStream)
{
const ElementType* elements = reinterpret_cast<const ElementType*>(tensor.GetMemoryArea());
fileStream << static_cast<PrintableType>(elements[elementIndex]) << " ";
@@ -350,10 +351,11 @@ void DumpTensorElement(const armnn::ConstTensor& tensor, unsigned int elementInd
} // namespace
+template <typename TensorType>
void DumpTensor(const std::string& dumpDir,
const std::string& requestName,
const std::string& tensorName,
- const armnn::ConstTensor& tensor)
+ const TensorType& tensor)
{
// The dump directory must exist in advance.
fs::path dumpPath = dumpDir;
@@ -368,38 +370,38 @@ void DumpTensor(const std::string& dumpDir,
return;
}
- DumpElementFunction dumpElementFunction = nullptr;
+ DumpElementFunction<TensorType> dumpElementFunction = nullptr;
switch (tensor.GetDataType())
{
case armnn::DataType::Float32:
{
- dumpElementFunction = &DumpTensorElement<float>;
+ dumpElementFunction = &DumpTensorElement<TensorType, float>;
break;
}
case armnn::DataType::QAsymmU8:
{
- dumpElementFunction = &DumpTensorElement<uint8_t, uint32_t>;
+ dumpElementFunction = &DumpTensorElement<TensorType, uint8_t, uint32_t>;
break;
}
case armnn::DataType::Signed32:
{
- dumpElementFunction = &DumpTensorElement<int32_t>;
+ dumpElementFunction = &DumpTensorElement<TensorType, int32_t>;
break;
}
case armnn::DataType::Float16:
{
- dumpElementFunction = &DumpTensorElement<armnn::Half>;
+ dumpElementFunction = &DumpTensorElement<TensorType, armnn::Half>;
break;
}
case armnn::DataType::QAsymmS8:
{
- dumpElementFunction = &DumpTensorElement<int8_t, int32_t>;
+ dumpElementFunction = &DumpTensorElement<TensorType, int8_t, int32_t>;
break;
}
case armnn::DataType::Boolean:
{
- dumpElementFunction = &DumpTensorElement<bool>;
+ dumpElementFunction = &DumpTensorElement<TensorType, bool>;
break;
}
default:
@@ -473,6 +475,17 @@ void DumpTensor(const std::string& dumpDir,
}
}
+
+template void DumpTensor<armnn::ConstTensor>(const std::string& dumpDir,
+ const std::string& requestName,
+ const std::string& tensorName,
+ const armnn::ConstTensor& tensor);
+
+template void DumpTensor<armnn::Tensor>(const std::string& dumpDir,
+ const std::string& requestName,
+ const std::string& tensorName,
+ const armnn::Tensor& tensor);
+
void DumpJsonProfilingIfRequired(bool gpuProfilingEnabled,
const std::string& dumpDir,
armnn::NetworkId networkId,
diff --git a/Utils.hpp b/Utils.hpp
index 9bd28ba6..6e733a26 100644
--- a/Utils.hpp
+++ b/Utils.hpp
@@ -126,10 +126,11 @@ std::string GetModelSummary(const HalModel& model)
return result.str();
}
+template <typename TensorType>
void DumpTensor(const std::string& dumpDir,
const std::string& requestName,
const std::string& tensorName,
- const armnn::ConstTensor& tensor);
+ const TensorType& tensor);
void DumpJsonProfilingIfRequired(bool gpuProfilingEnabled,
const std::string& dumpDir,