aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/func_config.cc
blob: 6bd809e964afc07e00e24773af820ba47469ce28 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108

// Copyright (c) 2020-2022, ARM Limited.
//
//    Licensed under the Apache License, Version 2.0 (the "License");
//    you may not use this file except in compliance with the License.
//    You may obtain a copy of the License at
//
//         http://www.apache.org/licenses/LICENSE-2.0
//
//    Unless required by applicable law or agreed to in writing, software
//    distributed under the License is distributed on an "AS IS" BASIS,
//    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//    See the License for the specific language governing permissions and
//    limitations under the License.

#include <cxxopts.hpp>
#include <ctype.h>
#include <signal.h>
#include <stdarg.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>

#include "func_config.h"
#include "func_debug.h"

// Read the command line arguments
int func_model_parse_cmd_line(
    func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version)
{
    try
    {
        cxxopts::Options options("tosa_reference_model", "The TOSA reference model");

        // clang-format off
        options.add_options()
        ("operator_fbs", "Flat buffer schema file", cxxopts::value<std::string>(func_config.operator_fbs), "<schema>")
        ("test_desc", "Json test descriptor", cxxopts::value<std::string>(func_config.test_desc), "<descriptor>")
        ("flatbuffer_dir", "Flatbuffer directory to load. If not specified, it will be overwritten by dirname(test_desc)",
            cxxopts::value<std::string>(func_config.flatbuffer_dir))
        ("output_dir", "Output directory to write. If not specified, it will be overwritten by dirname(test_desc)",
            cxxopts::value<std::string>(func_config.output_dir))
        ("tosa_file", "Flatbuffer file. Support .json or .tosa. Specifying this will overwrite the one initialized by --test_desc.",
            cxxopts::value<std::string>(func_config.tosa_file))
        ("ifm_name", "Input tensor name. Comma(,) separated. Specifying this will overwrite the one initialized by --test_desc.",
            cxxopts::value<std::string>(func_config.ifm_name))
        ("ifm_file", "Input tensor numpy Comma(,) separated. file to initialize with placeholder. Specifying this will overwrite the one initialized by --test_desc.",
            cxxopts::value<std::string>(func_config.ifm_file))
        ("ofm_name", "Output tensor name. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.",
            cxxopts::value<std::string>(func_config.ofm_name))
        ("ofm_file", "Output tensor numpy file to be generated. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.",
            cxxopts::value<std::string>(func_config.ofm_file))
        ("eval", "Evaluate the network (0/1)", cxxopts::value<uint32_t>(func_config.eval))
        ("fp_format", "Floating-point number dump format string (printf-style format, e.g. 0.5)",
            cxxopts::value<std::string>(func_config.fp_format))
        ("validate_only", "Validate the network, but do not read inputs or evaluate (0/1)",
            cxxopts::value<uint32_t>(func_config.validate_only))
        ("output_tensors", "Output tensors to a file (0/1)", cxxopts::value<uint32_t>(func_config.output_tensors))
        ("tosa_profile", "Set TOSA profile (0 = Base Inference, 1 = Main Inference, 2 = Main Training)",
            cxxopts::value<uint32_t>(func_config.tosa_profile))
        ("dump_intermediates", "Dump intermediate tensors (0/1)", cxxopts::value<uint32_t>(func_config.dump_intermediates))
        ("v,version", "print model version")
        ("i,input_tensor_file", "specify input tensor files", cxxopts::value<std::vector<std::string>>())
        ("l,loglevel", func_debug.get_debug_verbosity_help_string(), cxxopts::value<std::string>())
        ("o,logfile", "output log file", cxxopts::value<std::string>())
        ("d,debugmask", func_debug.get_debug_mask_help_string(), cxxopts::value<std::vector<std::string>>())
        ("h,help", "print help");
        // clang-format on

        auto result = options.parse(argc, argv);
        if (result.count("help")) {
            std::cout << options.help() << std::endl;
            return 1;
        }
        if (result.count("debugmask")) {
            auto& v = result["debugmask"].as<std::vector<std::string>>();
            for (const std::string& s : v)
                func_debug.set_mask(s);
        }
        if (result.count("loglevel")) {
            const std::string& levelstr = result["loglevel"].as<std::string>();
            func_debug.set_verbosity(levelstr);
        }
        if (result.count("logfile")) {
            func_debug.set_file(result["logfile"].as<std::string>());
        }
        if (result.count("input_tensor_file")) {
            func_config.ifm_name = result["input_tensor_file"].as<std::string>();
        }
        if (result.count("version")) {
            std::cout << "Model version " << version << std::endl;
        }
    }
    catch(const std::exception& e)
    {
        std::cerr << e.what() << '\n';
        return 1;
    }

    return 0;
}

void func_model_print_help()
{

}