Diffstat (limited to 'shim/sl/canonical/CanonicalUtils.cpp')
-rw-r--r--  shim/sl/canonical/CanonicalUtils.cpp | 35 +++++++++++++++++++++++------------
1 file changed, 23 insertions(+), 12 deletions(-)
diff --git a/shim/sl/canonical/CanonicalUtils.cpp b/shim/sl/canonical/CanonicalUtils.cpp
index 713629f554..059b5ca4a3 100644
--- a/shim/sl/canonical/CanonicalUtils.cpp
+++ b/shim/sl/canonical/CanonicalUtils.cpp
@@ -198,25 +198,26 @@ std::string GetOperandSummary(const Operand& operand)
return ss.str();
}
-using DumpElementFunction = void (*)(const armnn::ConstTensor& tensor,
+template <typename TensorType>
+using DumpElementFunction = void (*)(const TensorType& tensor,
unsigned int elementIndex,
std::ofstream& fileStream);
namespace
{
-template <typename ElementType, typename PrintableType = ElementType>
-void DumpTensorElement(const armnn::ConstTensor& tensor, unsigned int elementIndex, std::ofstream& fileStream)
+template <typename TensorType, typename ElementType, typename PrintableType = ElementType>
+void DumpTensorElement(const TensorType& tensor, unsigned int elementIndex, std::ofstream& fileStream)
{
const ElementType* elements = reinterpret_cast<const ElementType*>(tensor.GetMemoryArea());
fileStream << static_cast<PrintableType>(elements[elementIndex]) << " ";
}
} // namespace
-
+template <typename TensorType>
void DumpTensor(const std::string& dumpDir,
const std::string& requestName,
const std::string& tensorName,
- const armnn::ConstTensor& tensor)
+ const TensorType& tensor)
{
// The dump directory must exist in advance.
fs::path dumpPath = dumpDir;
@@ -231,38 +232,38 @@ void DumpTensor(const std::string& dumpDir,
return;
}
- DumpElementFunction dumpElementFunction = nullptr;
+ DumpElementFunction<TensorType> dumpElementFunction = nullptr;
switch (tensor.GetDataType())
{
case armnn::DataType::Float32:
{
- dumpElementFunction = &DumpTensorElement<float>;
+ dumpElementFunction = &DumpTensorElement<TensorType, float>;
break;
}
case armnn::DataType::QAsymmU8:
{
- dumpElementFunction = &DumpTensorElement<uint8_t, uint32_t>;
+ dumpElementFunction = &DumpTensorElement<TensorType, uint8_t, uint32_t>;
break;
}
case armnn::DataType::Signed32:
{
- dumpElementFunction = &DumpTensorElement<int32_t>;
+ dumpElementFunction = &DumpTensorElement<TensorType, int32_t>;
break;
}
case armnn::DataType::Float16:
{
- dumpElementFunction = &DumpTensorElement<armnn::Half>;
+ dumpElementFunction = &DumpTensorElement<TensorType, armnn::Half>;
break;
}
case armnn::DataType::QAsymmS8:
{
- dumpElementFunction = &DumpTensorElement<int8_t, int32_t>;
+ dumpElementFunction = &DumpTensorElement<TensorType, int8_t, int32_t>;
break;
}
case armnn::DataType::Boolean:
{
- dumpElementFunction = &DumpTensorElement<bool>;
+ dumpElementFunction = &DumpTensorElement<TensorType, bool>;
break;
}
default:
@@ -336,6 +337,16 @@ void DumpTensor(const std::string& dumpDir,
}
}
+template void DumpTensor<armnn::ConstTensor>(const std::string& dumpDir,
+ const std::string& requestName,
+ const std::string& tensorName,
+ const armnn::ConstTensor& tensor);
+
+template void DumpTensor<armnn::Tensor>(const std::string& dumpDir,
+ const std::string& requestName,
+ const std::string& tensorName,
+ const armnn::Tensor& tensor);
+
void DumpJsonProfilingIfRequired(bool gpuProfilingEnabled,
const std::string& dumpDir,
armnn::NetworkId networkId,
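
For illustration only (not part of the patch): a minimal sketch of how the now-templated DumpTensor can be called with either armnn::Tensor or armnn::ConstTensor, relying on the two explicit instantiations added above. The armnn_driver namespace, the "CanonicalUtils.hpp" include path and the dump directory used below are assumptions; as the function's comment notes, the dump directory must already exist.

#include "CanonicalUtils.hpp"   // assumed include path for the shim's DumpTensor declaration

#include <armnn/Tensor.hpp>

#include <array>
#include <string>

int main()
{
    const std::string dumpDir = "/data/local/tmp"; // must exist in advance
    std::array<float, 4> data = { 0.f, 1.f, 2.f, 3.f };

    armnn::TensorInfo info(armnn::TensorShape({ 2, 2 }), armnn::DataType::Float32);

    // Mutable tensor -> resolves to the DumpTensor<armnn::Tensor> instantiation.
    armnn::Tensor outputTensor(info, data.data());
    armnn_driver::DumpTensor(dumpDir, "request0", "output0", outputTensor);

    // Constant view of the same memory -> resolves to DumpTensor<armnn::ConstTensor>.
    info.SetConstant(true); // recent ArmNN expects a ConstTensor's info to be flagged constant
    armnn::ConstTensor inputTensor(info, data.data());
    armnn_driver::DumpTensor(dumpDir, "request0", "input0", inputTensor);

    return 0;
}

Keeping the template definition in the .cpp file and adding the two explicit instantiations means callers link against these two concrete versions without the whole function body having to move into the header.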