aboutsummaryrefslogtreecommitdiff
path: root/samples/ObjectDetection/test/PipelineTest.cpp
blob: bc5824e48335d4d8ab609cf2fa14d3a6a3bda831 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <stdexcept>
#include <string>

#include <catch.hpp>
#include <opencv2/opencv.hpp>
#include "ObjectDetectionPipeline.hpp"
#include "Types.hpp"

/**
 * Builds the path of a test resource file inside TEST_RESOURCE_DIR.
 *
 * @param filename  File name relative to the test resource directory.
 * @return          TEST_RESOURCE_DIR joined with @p filename; a '/' separator is
 *                  inserted only when the directory does not already end with one.
 * @throws std::runtime_error if TEST_RESOURCE_DIR expands to an empty string.
 */
static std::string GetResourceFilePath(const std::string& filename)
{
    // TEST_RESOURCE_DIR is not defined in this file; presumably it is supplied by the build system.
    const std::string testResources = TEST_RESOURCE_DIR;
    if (testResources.empty())
    {
        // Throw a std::exception-derived type rather than a bare string literal so that
        // generic handlers (catch (const std::exception&)) and test reporters can see the message.
        throw std::runtime_error("Invalid test resources directory provided");
    }

    const bool hasTrailingSeparator = (testResources.back() == '/');
    return hasTrailingSeparator ? testResources + filename
                                : testResources + "/" + filename;
}

TEST_CASE("Test Network Execution SSD_MOBILE")
{
    // Fail fast with a clear message if no resource directory was provided.
    std::string testResources = TEST_RESOURCE_DIR;
    REQUIRE(testResources != "");

    // Configure the pipeline: SSD MobileNet TFLite model with CpuAcc listed before CpuRef
    // (presumably a preference order; backend selection happens inside od::CreatePipeline).
    common::PipelineOptions options;
    options.m_ModelFilePath = GetResourceFilePath("detect.tflite");
    options.m_ModelName = "SSD_MOBILE";
    options.m_backends = {"CpuAcc", "CpuRef"};

    od::IPipelinePtr objectDetectionPipeline = od::CreatePipeline(options);

    common::InferenceResults<float> results;
    cv::Mat processed;
    cv::Mat inputFrame = cv::imread(GetResourceFilePath("basketball1.png"), cv::IMREAD_COLOR);
    // cv::imread reports a missing/unreadable file by returning an empty Mat rather than
    // throwing; stop here instead of failing later with an unrelated cvtColor error.
    REQUIRE_FALSE(inputFrame.empty());
    // imread yields BGR channel order; convert to RGB before handing the frame to the pipeline.
    cv::cvtColor(inputFrame, inputFrame, cv::COLOR_BGR2RGB);

    objectDetectionPipeline->PreProcessing(inputFrame, processed);

    // Pre-processing is expected to produce a 300x300, 3-channel, 8-bit image.
    CHECK(processed.type() == CV_8UC3);
    CHECK(processed.cols == 300);
    CHECK(processed.rows == 300);

    objectDetectionPipeline->Inference(processed, results);

    // Record that the callback actually ran so the test cannot pass vacuously.
    // NOTE(review): assumes PostProcessing invokes the callback synchronously before
    // returning (it is captured by reference) - confirm against od::IPipeline.
    bool callbackInvoked = false;
    objectDetectionPipeline->PostProcessing(results,
                                            [&callbackInvoked](od::DetectedObjects detects) -> void {
                                                callbackInvoked = true;
                                                CHECK(detects.size() == 2);
                                                // CHECK does not abort on failure, so guard the
                                                // element access to avoid out-of-range indexing
                                                // when no objects were detected.
                                                if (detects.size() > 0)
                                                {
                                                    CHECK(detects[0].GetLabel() == "0");
                                                }
                                            });
    CHECK(callbackInvoked);
}