aboutsummaryrefslogtreecommitdiff
path: root/tests/InferenceTest.hpp
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-01-24 17:05:36 +0000
committerAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-01-30 11:25:56 +0000
commit7cf0eaa26c1fb29ca9df97e4734ec7c1e10f81c4 (patch)
tree407b519ede76b235c54907fe80411970741e8a00 /tests/InferenceTest.hpp
parent28d3d63cc0a33f8396b32fa8347c03912c065911 (diff)
downloadarmnn-7cf0eaa26c1fb29ca9df97e4734ec7c1e10f81c4.tar.gz
IVGCVSW-2564 Add support for multiple input and output bindings in InferenceModel
Change-Id: I64d724367d42dca4b768b6c6e42acda714985950
Diffstat (limited to 'tests/InferenceTest.hpp')
-rw-r--r--tests/InferenceTest.hpp31
1 files changed, 20 insertions, 11 deletions
diff --git a/tests/InferenceTest.hpp b/tests/InferenceTest.hpp
index 3ea70962d2..3c22df9a5e 100644
--- a/tests/InferenceTest.hpp
+++ b/tests/InferenceTest.hpp
@@ -100,31 +100,40 @@ template <typename TModel>
class InferenceModelTestCase : public IInferenceTestCase
{
public:
+ using TContainer = std::vector<typename TModel::DataType>;
+
InferenceModelTestCase(TModel& model,
- unsigned int testCaseId,
- std::vector<typename TModel::DataType> modelInput,
- unsigned int outputSize)
+ unsigned int testCaseId,
+ const std::vector<TContainer>& inputs,
+ const std::vector<unsigned int>& outputSizes)
: m_Model(model)
, m_TestCaseId(testCaseId)
- , m_Input(std::move(modelInput))
+ , m_Inputs(std::move(inputs))
{
- m_Output.resize(outputSize);
+ // Initialize output vector
+ const size_t numOutputs = outputSizes.size();
+ m_Outputs.resize(numOutputs);
+
+ for (size_t i = 0; i < numOutputs; i++)
+ {
+ m_Outputs[i].resize(outputSizes[i]);
+ }
}
virtual void Run() override
{
- m_Model.Run(m_Input, m_Output);
+ m_Model.Run(m_Inputs, m_Outputs);
}
protected:
unsigned int GetTestCaseId() const { return m_TestCaseId; }
- const std::vector<typename TModel::DataType>& GetOutput() const { return m_Output; }
+ const std::vector<TContainer>& GetOutputs() const { return m_Outputs; }
private:
- TModel& m_Model;
- unsigned int m_TestCaseId;
- std::vector<typename TModel::DataType> m_Input;
- std::vector<typename TModel::DataType> m_Output;
+ TModel& m_Model;
+ unsigned int m_TestCaseId;
+ std::vector<TContainer> m_Inputs;
+ std::vector<TContainer> m_Outputs;
};
template <typename TDataType>