aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/CL/mlgo/MLGOParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/CL/mlgo/MLGOParser.cpp')
-rw-r--r--src/runtime/CL/mlgo/MLGOParser.cpp188
1 files changed, 91 insertions, 97 deletions
diff --git a/src/runtime/CL/mlgo/MLGOParser.cpp b/src/runtime/CL/mlgo/MLGOParser.cpp
index 625739e450..893daf2ed9 100644
--- a/src/runtime/CL/mlgo/MLGOParser.cpp
+++ b/src/runtime/CL/mlgo/MLGOParser.cpp
@@ -22,19 +22,21 @@
* SOFTWARE.
*/
#include "src/runtime/CL/mlgo/MLGOParser.h"
+
#include "arm_compute/core/Log.h"
+
#include "src/runtime/CL/mlgo/Utils.h"
#include <sstream>
#define CHECK(parser_expr, valid_var) \
(parser_expr); \
- if(!valid_var) \
+ if (!valid_var) \
return;
#define CHECK_DEFAULT(parser_expr, valid_var, default_val) \
(parser_expr); \
- if(!valid_var) \
+ if (!valid_var) \
return default_val;
#ifdef ARM_COMPUTE_LOGGING_ENABLED
@@ -53,8 +55,7 @@
valid_var = false; \
return default_val;
-#define LOG_TOKEN_POS(tokens, pos_var) \
- const auto pos_var = tokens.current_pos();
+#define LOG_TOKEN_POS(tokens, pos_var) const auto pos_var = tokens.current_pos();
#else // ARM_COMPUTE_LOGGING_ENABLED
@@ -73,19 +74,12 @@ namespace
{
void ltrim(std::string &str)
{
- str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](char ch)
- {
- return !std::isspace(ch);
- }));
+ str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](char ch) { return !std::isspace(ch); }));
}
void rtrim(std::string &str)
{
- str.erase(std::find_if(str.rbegin(), str.rend(), [](char ch)
- {
- return !std::isspace(ch);
- }).base(),
- str.end());
+ str.erase(std::find_if(str.rbegin(), str.rend(), [](char ch) { return !std::isspace(ch); }).base(), str.end());
}
void trim(std::string &str)
@@ -109,7 +103,7 @@ enum class ComparatorType
};
TokenStream::TokenStream(std::istream &s, const std::string &delims)
- : _delims{ delims }, _istream{ s }, _tokens{}, _lookahead_pos{}
+ : _delims{delims}, _istream{s}, _tokens{}, _lookahead_pos{}
{
read();
}
@@ -125,7 +119,7 @@ Token TokenStream::take()
ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty");
Token t = _tokens.front();
_tokens.pop_front();
- if(_tokens.empty())
+ if (_tokens.empty())
{
read();
}
@@ -136,7 +130,7 @@ Token TokenStream::peek(size_t i)
ARM_COMPUTE_ERROR_ON_MSG(_tokens.empty(), "TokenStream can never be empty");
ARM_COMPUTE_ERROR_ON_MSG(i >= max_look_ahead, "TokenStream: Exceeding max look ahead");
// NOTE: If i exceeds the stream (_istream.eof()), read() automatically appends a End token at the end
- while(_istream && _tokens.size() <= i)
+ while (_istream && _tokens.size() <= i)
{
read();
}
@@ -146,7 +140,7 @@ Token TokenStream::peek(size_t i)
void advance(CharPosition &pos, char ch)
{
- if(ch == '\n')
+ if (ch == '\n')
{
pos.ln += 1;
pos.col = 0;
@@ -167,17 +161,16 @@ void TokenStream::read()
do
{
// Reached eof
- if(!_istream.get(ch))
+ if (!_istream.get(ch))
{
- if(!reached_end())
+ if (!reached_end())
{
_tokens.emplace_back(TokenType::End, "", _lookahead_pos);
}
return;
}
advance(_lookahead_pos, ch);
- }
- while(std::isspace(ch) || is_delim(ch));
+ } while (std::isspace(ch) || is_delim(ch));
// Read chars until we hit a delim or eof
auto orig_pos = _lookahead_pos;
auto tok = recognize_tok(ch);
@@ -190,41 +183,41 @@ void TokenStream::read()
Token TokenStream::recognize_tok(char ch)
{
- if(ch == '[')
+ if (ch == '[')
{
- return Token{ TokenType::L_List, "", _lookahead_pos };
+ return Token{TokenType::L_List, "", _lookahead_pos};
}
- else if(ch == ']')
+ else if (ch == ']')
{
- return Token{ TokenType::R_List, "", _lookahead_pos };
+ return Token{TokenType::R_List, "", _lookahead_pos};
}
- else if(ch == '.')
+ else if (ch == '.')
{
- return float_after_dp_st(std::string{ ch });
+ return float_after_dp_st(std::string{ch});
}
- else if(std::isdigit(ch))
+ else if (std::isdigit(ch))
{
- return num_st(std::string{ ch });
+ return num_st(std::string{ch});
}
else
{
- return text_st(std::string{ ch });
+ return text_st(std::string{ch});
}
}
Token TokenStream::num_st(std::string value)
{
char ch{};
- while(_istream.get(ch))
+ while (_istream.get(ch))
{
advance(_lookahead_pos, ch);
- if(ch == '.')
+ if (ch == '.')
{
return float_after_dp_st(value + ch);
}
- else if(!std::isdigit(ch))
+ else if (!std::isdigit(ch))
{
- if(!is_delim(ch) && !std::isspace(ch))
+ if (!is_delim(ch) && !std::isspace(ch))
{
rewind(_lookahead_pos);
_istream.unget();
@@ -233,18 +226,18 @@ Token TokenStream::num_st(std::string value)
}
value += ch;
}
- return Token{ TokenType::Int, value, _lookahead_pos };
+ return Token{TokenType::Int, value, _lookahead_pos};
}
Token TokenStream::float_after_dp_st(std::string value)
{
char ch{};
- while(_istream.get(ch))
+ while (_istream.get(ch))
{
advance(_lookahead_pos, ch);
- if(!std::isdigit(ch))
+ if (!std::isdigit(ch))
{
- if(!is_delim(ch) && !std::isspace(ch))
+ if (!is_delim(ch) && !std::isspace(ch))
{
rewind(_lookahead_pos);
_istream.unget();
@@ -253,20 +246,20 @@ Token TokenStream::float_after_dp_st(std::string value)
}
value += ch;
}
- return Token{ TokenType::Float, value, _lookahead_pos };
+ return Token{TokenType::Float, value, _lookahead_pos};
}
Token TokenStream::text_st(std::string value)
{
char ch{};
- while(_istream.get(ch))
+ while (_istream.get(ch))
{
advance(_lookahead_pos, ch);
- if(is_delim(ch))
+ if (is_delim(ch))
{
break;
}
- if(ch == '[' || ch == ']')
+ if (ch == '[' || ch == ']')
{
rewind(_lookahead_pos);
_istream.unget();
@@ -274,7 +267,7 @@ Token TokenStream::text_st(std::string value)
}
value += ch;
}
- return Token{ TokenType::Text, value, _lookahead_pos };
+ return Token{TokenType::Text, value, _lookahead_pos};
}
bool TokenStream::reached_end() const
@@ -291,7 +284,7 @@ void end(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::End)
+ if (tok.type != TokenType::End)
{
FAIL_WITH_MSG(valid, pos, "Unexpected token at the end of stream");
}
@@ -301,7 +294,7 @@ bool bool_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::Int)
+ if (tok.type != TokenType::Int)
{
FAIL_WITH_MSG_DEFAULT(valid, false, pos, "Expect bool or int token");
}
@@ -314,7 +307,7 @@ int int_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::Int)
+ if (tok.type != TokenType::Int)
{
FAIL_WITH_MSG_DEFAULT(valid, -1, pos, "Expect int token");
}
@@ -327,7 +320,7 @@ unsigned int uint_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
int val = CHECK_DEFAULT(int_val(in, valid), valid, 0);
- if(val < 0)
+ if (val < 0)
{
FAIL_WITH_MSG_DEFAULT(valid, 0, pos, "Expect unsigned int token");
}
@@ -338,7 +331,7 @@ float float_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::Float)
+ if (tok.type != TokenType::Float)
{
FAIL_WITH_MSG_DEFAULT(valid, 0.f, pos, "Expect float token");
}
@@ -351,7 +344,7 @@ std::string text_val(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
auto tok = in.take();
- if(tok.type != TokenType::Text || tok.value.empty())
+ if (tok.type != TokenType::Text || tok.value.empty())
{
FAIL_WITH_MSG_DEFAULT(valid, "", pos, "Expect a non-empty text token");
}
@@ -361,9 +354,9 @@ std::string text_val(TokenStream &in, bool &valid)
bool accept_text(TokenStream &in, const std::string &c_str, bool take = true)
{
auto tok = in.peek();
- if(tok.type == TokenType::Text && tok.value == c_str)
+ if (tok.type == TokenType::Text && tok.value == c_str)
{
- if(take)
+ if (take)
{
in.take();
}
@@ -375,7 +368,7 @@ bool accept_text(TokenStream &in, const std::string &c_str, bool take = true)
void expect_text(TokenStream &in, const std::string &str, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(!accept_text(in, str))
+ if (!accept_text(in, str))
{
FAIL_WITH_MSG(valid, pos, std::string("Expect text token: ") + str);
}
@@ -384,7 +377,7 @@ void expect_text(TokenStream &in, const std::string &str, bool &valid)
bool accept_l_list(TokenStream &in)
{
auto tok = in.peek();
- if(tok.type == TokenType::L_List)
+ if (tok.type == TokenType::L_List)
{
in.take();
return true;
@@ -395,7 +388,7 @@ bool accept_l_list(TokenStream &in)
void expect_l_list(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(!accept_l_list(in))
+ if (!accept_l_list(in))
{
FAIL_WITH_MSG(valid, pos, "Expect '['");
}
@@ -404,7 +397,7 @@ void expect_l_list(TokenStream &in, bool &valid)
bool accept_r_list(TokenStream &in)
{
auto tok = in.peek();
- if(tok.type == TokenType::R_List)
+ if (tok.type == TokenType::R_List)
{
in.take();
return true;
@@ -415,7 +408,7 @@ bool accept_r_list(TokenStream &in)
void expect_r_list(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(!accept_r_list(in))
+ if (!accept_r_list(in))
{
FAIL_WITH_MSG(valid, pos, "Expect ']'");
}
@@ -424,23 +417,23 @@ void expect_r_list(TokenStream &in, bool &valid)
ConditionalOp conditional_op(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "<="))
+ if (accept_text(in, "<="))
{
return ConditionalOp::LE;
}
- else if(accept_text(in, ">="))
+ else if (accept_text(in, ">="))
{
return ConditionalOp::GE;
}
- else if(accept_text(in, "=="))
+ else if (accept_text(in, "=="))
{
return ConditionalOp::EQ;
}
- else if(accept_text(in, "<"))
+ else if (accept_text(in, "<"))
{
return ConditionalOp::LT;
}
- else if(accept_text(in, ">"))
+ else if (accept_text(in, ">"))
{
return ConditionalOp::GT;
}
@@ -464,11 +457,11 @@ void ip_type(TokenStream &in, bool &valid)
{
CHECK(expect_text(in, "ip-type", valid), valid);
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "gpu"))
+ if (accept_text(in, "gpu"))
{
;
}
- else if(accept_text(in, "cpu"))
+ else if (accept_text(in, "cpu"))
{
;
}
@@ -489,15 +482,15 @@ void header(TokenStream &in, bool &valid)
DataType data_type(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "f16"))
+ if (accept_text(in, "f16"))
{
return DataType::F16;
}
- else if(accept_text(in, "f32"))
+ else if (accept_text(in, "f32"))
{
return DataType::F32;
}
- else if(accept_text(in, "qasymm8"))
+ else if (accept_text(in, "qasymm8"))
{
return DataType::QASYMM8;
}
@@ -510,15 +503,15 @@ DataType data_type(TokenStream &in, bool &valid)
ComparatorType comparator_type(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "var"))
+ if (accept_text(in, "var"))
{
return ComparatorType::Var;
}
- else if(accept_text(in, "num"))
+ else if (accept_text(in, "num"))
{
return ComparatorType::Num;
}
- else if(accept_text(in, "enum"))
+ else if (accept_text(in, "enum"))
{
return ComparatorType::Enum;
}
@@ -531,19 +524,19 @@ ComparatorType comparator_type(TokenStream &in, bool &valid)
HeuristicType heuristic_type(TokenStream &in, bool &valid, bool take = true)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "gemm-type", take))
+ if (accept_text(in, "gemm-type", take))
{
return HeuristicType::GEMM_Type;
}
- else if(accept_text(in, "gemm-config-native", take))
+ else if (accept_text(in, "gemm-config-native", take))
{
return HeuristicType::GEMM_Config_Native;
}
- else if(accept_text(in, "gemm-config-reshaped-only-rhs", take))
+ else if (accept_text(in, "gemm-config-reshaped-only-rhs", take))
{
return HeuristicType::GEMM_Config_Reshaped_Only_RHS;
}
- else if(accept_text(in, "gemm-config-reshaped", take))
+ else if (accept_text(in, "gemm-config-reshaped", take))
{
return HeuristicType::GEMM_Config_Reshaped;
}
@@ -557,7 +550,7 @@ void expect_heuristic_type(TokenStream &in, HeuristicType expected_ht, bool &val
{
LOG_TOKEN_POS(in, pos);
auto ht = CHECK(heuristic_type(in, valid, false), valid);
- if(ht != expected_ht)
+ if (ht != expected_ht)
{
FAIL_WITH_MSG(valid, pos, "Unexpected heuristic type");
}
@@ -567,15 +560,15 @@ void expect_heuristic_type(TokenStream &in, HeuristicType expected_ht, bool &val
GEMMType gemm_type(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "native"))
+ if (accept_text(in, "native"))
{
return GEMMType::NATIVE;
}
- else if(accept_text(in, "reshaped-only-rhs"))
+ else if (accept_text(in, "reshaped-only-rhs"))
{
return GEMMType::RESHAPED_ONLY_RHS;
}
- else if(accept_text(in, "reshaped"))
+ else if (accept_text(in, "reshaped"))
{
return GEMMType::RESHAPED;
}
@@ -593,7 +586,7 @@ GEMMConfigNative gemm_config_native(TokenStream &in, bool &valid)
const auto n0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
const auto k0 = CHECK_DEFAULT(uint_val(in, valid), valid, invalid_val);
CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val);
- return GEMMConfigNative{ m0, n0, k0 };
+ return GEMMConfigNative{m0, n0, k0};
}
GEMMConfigReshapedOnlyRHS gemm_config_reshaped_only_rhs(TokenStream &in, bool &valid)
@@ -608,7 +601,7 @@ GEMMConfigReshapedOnlyRHS gemm_config_reshaped_only_rhs(TokenStream &in, bool &v
const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val);
- return GEMMConfigReshapedOnlyRHS{ m0, n0, k0, h0, ir, tr, ex };
+ return GEMMConfigReshapedOnlyRHS{m0, n0, k0, h0, ir, tr, ex};
}
GEMMConfigReshaped gemm_config_reshaped(TokenStream &in, bool &valid)
@@ -625,17 +618,17 @@ GEMMConfigReshaped gemm_config_reshaped(TokenStream &in, bool &valid)
const auto tr = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
const auto ex = CHECK_DEFAULT(bool_val(in, valid), valid, invalid_val);
CHECK_DEFAULT(expect_r_list(in, valid), valid, invalid_val);
- return GEMMConfigReshaped{ m0, n0, k0, v0, h0, il, ir, tr, ex };
+ return GEMMConfigReshaped{m0, n0, k0, v0, h0, il, ir, tr, ex};
}
void gpu_priority(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "best-performance"))
+ if (accept_text(in, "best-performance"))
{
;
}
- else if(accept_text(in, "best-memory-usage"))
+ else if (accept_text(in, "best-memory-usage"))
{
;
}
@@ -648,11 +641,11 @@ void gpu_priority(TokenStream &in, bool &valid)
void gpu_behavior(TokenStream &in, bool &valid)
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "static"))
+ if (accept_text(in, "static"))
{
;
}
- else if(accept_text(in, "dynamic"))
+ else if (accept_text(in, "dynamic"))
{
;
}
@@ -665,7 +658,7 @@ void gpu_behavior(TokenStream &in, bool &valid)
void free_vars(TokenStream &in, bool &valid)
{
CHECK(expect_l_list(in, valid), valid);
- while(!accept_r_list(in))
+ while (!accept_r_list(in))
{
CHECK(text_val(in, valid), valid);
}
@@ -688,7 +681,7 @@ void heuristics_table_entry(TokenStream &in, MLGOHeuristics &h, bool &valid)
void heuristics_table(TokenStream &in, MLGOHeuristics &h, bool &valid)
{
CHECK(expect_text(in, "<heuristics-table>", valid), valid);
- while(!accept_text(in, "</heuristics-table>"))
+ while (!accept_text(in, "</heuristics-table>"))
{
CHECK(heuristics_table_entry(in, h, valid), valid);
}
@@ -705,11 +698,12 @@ Condition condition(TokenStream &in, bool &valid)
const auto c_o = CHECK_DEFAULT(conditional_op(in, valid), valid, invalid_val);
const auto r_t = CHECK_DEFAULT(comparator_type(in, valid), valid, invalid_val);
const auto r_v = CHECK_DEFAULT(float_val(in, valid), valid, invalid_val);
- if(l_t != ComparatorType::Var || r_t != ComparatorType::Num)
+ if (l_t != ComparatorType::Var || r_t != ComparatorType::Num)
{
- FAIL_WITH_MSG_DEFAULT(valid, invalid_val, pos, "Only accept LHS type to be Var (string) and RHS type to be Num (float)");
+ FAIL_WITH_MSG_DEFAULT(valid, invalid_val, pos,
+ "Only accept LHS type to be Var (string) and RHS type to be Num (float)");
}
- return Condition{ l_v, c_o, r_v };
+ return Condition{l_v, c_o, r_v};
}
void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid)
@@ -717,13 +711,13 @@ void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid)
CHECK(expect_text(in, "<heuristic", valid), valid);
const auto tree_id = CHECK(uint_val(in, valid), valid);
CHECK(expect_text(in, ">", valid), valid);
- HeuristicTree *t = nullptr;
- std::tie(valid, t) = CHECK(h.get_heuristic_tree(tree_id), valid);
+ HeuristicTree *t = nullptr;
+ std::tie(valid, t) = CHECK(h.get_heuristic_tree(tree_id), valid);
const HeuristicType t_heuristic_type = std::get<0>(t->index());
- while(!accept_text(in, "</heuristic>"))
+ while (!accept_text(in, "</heuristic>"))
{
LOG_TOKEN_POS(in, pos);
- if(accept_text(in, "b"))
+ if (accept_text(in, "b"))
{
// Branch node
const auto id = CHECK(uint_val(in, valid), valid);
@@ -732,7 +726,7 @@ void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid)
const auto f_id = CHECK(uint_val(in, valid), valid);
valid = CHECK(t->add_branch(id, cond, t_id, f_id), valid);
}
- else if(accept_text(in, "l"))
+ else if (accept_text(in, "l"))
{
// Leaf node
const auto id = CHECK(uint_val(in, valid), valid);
@@ -740,7 +734,7 @@ void heuristic_tree(TokenStream &in, MLGOHeuristics &h, bool &valid)
// heuristic table). For now it remains as a step for validation.
LOG_TOKEN_POS(in, pos);
CHECK(expect_heuristic_type(in, t_heuristic_type, valid), valid);
- switch(t_heuristic_type)
+ switch (t_heuristic_type)
{
case HeuristicType::GEMM_Type:
{
@@ -786,7 +780,7 @@ MLGOHeuristics mlgo(TokenStream &in, bool &valid)
MLGOHeuristics h;
CHECK_DEFAULT(header(in, valid), valid, h);
CHECK_DEFAULT(heuristics_table(in, h, valid), valid, h);
- while(accept_text(in, "<heuristic", false))
+ while (accept_text(in, "<heuristic", false))
{
CHECK_DEFAULT(heuristic_tree(in, h, valid), valid, h);
}
@@ -809,4 +803,4 @@ std::pair<bool, MLGOHeuristics> parse_mlgo(std::istream &in)
#undef CHECK
#undef CHECK_DEFAULT
#undef FAIL_WITH_MSG
-#undef FAIL_WITH_MSG_DEFAULT \ No newline at end of file
+#undef FAIL_WITH_MSG_DEFAULT