aboutsummaryrefslogtreecommitdiff
path: root/tests/ModelAccuracyTool-Armnn
diff options
context:
space:
mode:
Diffstat (limited to 'tests/ModelAccuracyTool-Armnn')
-rw-r--r--tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp9
1 files changed, 6 insertions, 3 deletions
diff --git a/tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp b/tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp
index bb0d824e0e..85241e889c 100644
--- a/tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp
+++ b/tests/ModelAccuracyTool-Armnn/ModelAccuracyTool-Armnn.cpp
@@ -194,6 +194,9 @@ int main(int argc, char* argv[])
inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[3] : inputTensorShape[2];
const unsigned int inputTensorHeight =
inputTensorDataLayout == armnn::DataLayout::NCHW ? inputTensorShape[2] : inputTensorShape[1];
+ // Get output tensor info
+ const unsigned int outputNumElements = model.GetOutputSize();
+
const unsigned int batchSize = 1;
// Get normalisation parameters
SupportedFrontend modelFrontend;
@@ -232,7 +235,7 @@ int main(int argc, char* argv[])
normParams,
batchSize,
inputTensorDataLayout));
- outputDataContainers = {vector<int>(1001)};
+ outputDataContainers = { vector<int>(outputNumElements) };
break;
case armnn::DataType::QuantisedAsymm8:
inputDataContainers.push_back(
@@ -241,7 +244,7 @@ int main(int argc, char* argv[])
normParams,
batchSize,
inputTensorDataLayout));
- outputDataContainers = {vector<uint8_t>(1001)};
+ outputDataContainers = { vector<uint8_t>(outputNumElements) };
break;
case armnn::DataType::Float32:
default:
@@ -251,7 +254,7 @@ int main(int argc, char* argv[])
normParams,
batchSize,
inputTensorDataLayout));
- outputDataContainers = {vector<float>(1001)};
+ outputDataContainers = { vector<float>(outputNumElements) };
break;
}