aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2022-07-15 10:22:49 +0100
committerSadik Armagan <sadik.armagan@arm.com>2022-07-15 13:02:18 +0000
commit09742380f1d986e3f7d3a7d130f37f3ab85729fd (patch)
tree24a3b084853ed954dc4098782a2ee6a8bf4fc95c
parent47263a241d50bf5cbd6819d25113dff29e4c03aa (diff)
downloadarmnn-09742380f1d986e3f7d3a7d130f37f3ab85729fd.tar.gz
IVGCVSW-7107 'Error while running Arm NN Sl with -d option'
* Templated the DumpTensor() function based on tensor type Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: I08775e480f89010be61daf0a09a2ab0274e05978
-rw-r--r--shim/sl/canonical/CanonicalUtils.cpp35
-rw-r--r--shim/sl/canonical/CanonicalUtils.hpp3
2 files changed, 25 insertions, 13 deletions
diff --git a/shim/sl/canonical/CanonicalUtils.cpp b/shim/sl/canonical/CanonicalUtils.cpp
index 713629f554..059b5ca4a3 100644
--- a/shim/sl/canonical/CanonicalUtils.cpp
+++ b/shim/sl/canonical/CanonicalUtils.cpp
@@ -198,25 +198,26 @@ std::string GetOperandSummary(const Operand& operand)
return ss.str();
}
-using DumpElementFunction = void (*)(const armnn::ConstTensor& tensor,
+template <typename TensorType>
+using DumpElementFunction = void (*)(const TensorType& tensor,
unsigned int elementIndex,
std::ofstream& fileStream);
namespace
{
-template <typename ElementType, typename PrintableType = ElementType>
-void DumpTensorElement(const armnn::ConstTensor& tensor, unsigned int elementIndex, std::ofstream& fileStream)
+template <typename TensorType, typename ElementType, typename PrintableType = ElementType>
+void DumpTensorElement(const TensorType& tensor, unsigned int elementIndex, std::ofstream& fileStream)
{
const ElementType* elements = reinterpret_cast<const ElementType*>(tensor.GetMemoryArea());
fileStream << static_cast<PrintableType>(elements[elementIndex]) << " ";
}
} // namespace
-
+template <typename TensorType>
void DumpTensor(const std::string& dumpDir,
const std::string& requestName,
const std::string& tensorName,
- const armnn::ConstTensor& tensor)
+ const TensorType& tensor)
{
// The dump directory must exist in advance.
fs::path dumpPath = dumpDir;
@@ -231,38 +232,38 @@ void DumpTensor(const std::string& dumpDir,
return;
}
- DumpElementFunction dumpElementFunction = nullptr;
+ DumpElementFunction<TensorType> dumpElementFunction = nullptr;
switch (tensor.GetDataType())
{
case armnn::DataType::Float32:
{
- dumpElementFunction = &DumpTensorElement<float>;
+ dumpElementFunction = &DumpTensorElement<TensorType, float>;
break;
}
case armnn::DataType::QAsymmU8:
{
- dumpElementFunction = &DumpTensorElement<uint8_t, uint32_t>;
+ dumpElementFunction = &DumpTensorElement<TensorType, uint8_t, uint32_t>;
break;
}
case armnn::DataType::Signed32:
{
- dumpElementFunction = &DumpTensorElement<int32_t>;
+ dumpElementFunction = &DumpTensorElement<TensorType, int32_t>;
break;
}
case armnn::DataType::Float16:
{
- dumpElementFunction = &DumpTensorElement<armnn::Half>;
+ dumpElementFunction = &DumpTensorElement<TensorType, armnn::Half>;
break;
}
case armnn::DataType::QAsymmS8:
{
- dumpElementFunction = &DumpTensorElement<int8_t, int32_t>;
+ dumpElementFunction = &DumpTensorElement<TensorType, int8_t, int32_t>;
break;
}
case armnn::DataType::Boolean:
{
- dumpElementFunction = &DumpTensorElement<bool>;
+ dumpElementFunction = &DumpTensorElement<TensorType, bool>;
break;
}
default:
@@ -336,6 +337,16 @@ void DumpTensor(const std::string& dumpDir,
}
}
+template void DumpTensor<armnn::ConstTensor>(const std::string& dumpDir,
+ const std::string& requestName,
+ const std::string& tensorName,
+ const armnn::ConstTensor& tensor);
+
+template void DumpTensor<armnn::Tensor>(const std::string& dumpDir,
+ const std::string& requestName,
+ const std::string& tensorName,
+ const armnn::Tensor& tensor);
+
void DumpJsonProfilingIfRequired(bool gpuProfilingEnabled,
const std::string& dumpDir,
armnn::NetworkId networkId,
diff --git a/shim/sl/canonical/CanonicalUtils.hpp b/shim/sl/canonical/CanonicalUtils.hpp
index a509684153..b94fd5e126 100644
--- a/shim/sl/canonical/CanonicalUtils.hpp
+++ b/shim/sl/canonical/CanonicalUtils.hpp
@@ -55,10 +55,11 @@ bool isQuantizedOperand(const OperandType& operandType);
std::string GetModelSummary(const Model& model);
+template <typename TensorType>
void DumpTensor(const std::string& dumpDir,
const std::string& requestName,
const std::string& tensorName,
- const armnn::ConstTensor& tensor);
+ const TensorType& tensor);
void DumpJsonProfilingIfRequired(bool gpuProfilingEnabled,
const std::string& dumpDir,