aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2021-03-08 19:28:24 +0000
committerTeresaARM <teresa.charlinreyes@arm.com>2021-04-12 13:30:45 +0000
commit32fe97ec627a70b6453375fcfc6665c0e1ad2024 (patch)
treeb39dc22143c9d4b30dddc44ba9b6335c985f154e
parentc5e0bb093d24392ddd12c4bd6b6fe6d0d0de850d (diff)
downloadandroid-nn-driver-master.tar.gz
IVGCVSW-5763 Remove datalayout from dumps, as it is not known.HEADmaster
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: Ia2bdae7a38252414d295d10a0a2cbb9ae7b083d4
-rw-r--r--Utils.cpp79
1 files changed, 26 insertions, 53 deletions
diff --git a/Utils.cpp b/Utils.cpp
index 930c2b2..1884281 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -350,23 +350,9 @@ template <typename ElementType, typename PrintableType = ElementType>
void DumpTensorElement(const armnn::ConstTensor& tensor, unsigned int elementIndex, std::ofstream& fileStream)
{
const ElementType* elements = reinterpret_cast<const ElementType*>(tensor.GetMemoryArea());
- fileStream << static_cast<PrintableType>(elements[elementIndex]) << ",";
+ fileStream << static_cast<PrintableType>(elements[elementIndex]) << " ";
}
-constexpr const char* MemoryLayoutString(const armnn::ConstTensor& tensor)
-{
- const char* str = "";
-
- switch (tensor.GetNumDimensions())
- {
- case 4: { str = "(BHWC) "; break; }
- case 3: { str = "(HWC) "; break; }
- case 2: { str = "(HW) "; break; }
- default: { str = ""; break; }
- }
-
- return str;
-}
} // namespace
void DumpTensor(const std::string& dumpDir,
@@ -430,55 +416,42 @@ void DumpTensor(const std::string& dumpDir,
if (dumpElementFunction != nullptr)
{
const unsigned int numDimensions = tensor.GetNumDimensions();
-
- const unsigned int batch = (numDimensions == 4) ? tensor.GetShape()[numDimensions - 4] : 1;
-
- const unsigned int height = (numDimensions >= 3)
- ? tensor.GetShape()[numDimensions - 3]
- : (numDimensions >= 2) ? tensor.GetShape()[numDimensions - 2] : 1;
-
- const unsigned int width = (numDimensions >= 3)
- ? tensor.GetShape()[numDimensions - 2]
- : (numDimensions >= 1) ? tensor.GetShape()[numDimensions - 1] : 0;
-
- const unsigned int channels = (numDimensions >= 3) ? tensor.GetShape()[numDimensions - 1] : 1;
+ const armnn::TensorShape shape = tensor.GetShape();
fileStream << "# Number of elements " << tensor.GetNumElements() << std::endl;
- fileStream << "# Dimensions " << MemoryLayoutString(tensor);
- fileStream << "[" << tensor.GetShape()[0];
- for (unsigned int d = 1; d < numDimensions; d++)
+ fileStream << "# Shape [" << shape[0];
+ for (unsigned int d = 1; d < numDimensions; ++d)
{
- fileStream << "," << tensor.GetShape()[d];
+ fileStream << "," << shape[d];
}
fileStream << "]" << std::endl;
+ fileStream << "Each line contains the data of each of the elements of dimension0. In NCHW and NHWC, each line"
+ " will be a batch" << std::endl << std::endl;
- for (unsigned int e = 0, b = 0; b < batch; ++b)
+ // Split will create a new line after all elements of the first dimension
+ // (in a 4, 3, 2, 3 tensor, there will be 4 lines of 18 elements)
+ unsigned int split = 1;
+ if (numDimensions == 1)
+ {
+ split = shape[0];
+ }
+ else
{
- if (numDimensions >= 4)
+ for (unsigned int i = 1; i < numDimensions; ++i)
{
- fileStream << "# Batch " << b << std::endl;
+ split *= shape[i];
}
- for (unsigned int c = 0; c < channels; c++)
+ }
+
+ // Print all elements in the tensor
+ for (unsigned int elementIndex = 0; elementIndex < tensor.GetNumElements(); ++elementIndex)
+ {
+ (*dumpElementFunction)(tensor, elementIndex, fileStream);
+
+ if ( (elementIndex + 1) % split == 0 )
{
- if (numDimensions >= 3)
- {
- fileStream << "# Channel " << c << std::endl;
- }
- for (unsigned int h = 0; h < height; h++)
- {
- for (unsigned int w = 0; w < width; w++, e += channels)
- {
- (*dumpElementFunction)(tensor, e, fileStream);
- }
- fileStream << std::endl;
- }
- e -= channels - 1;
- if (c < channels)
- {
- e -= ((height * width) - 1) * channels;
- }
+ fileStream << std::endl;
}
- fileStream << std::endl;
}
fileStream << std::endl;
}