aboutsummaryrefslogtreecommitdiff
path: root/src/backends/tosaCommon/TosaLayerSupportRules.hpp
blob: 2a2b08da99a7479a29d749fdf1ec32bebcdf80f5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

// List of Layer Support Rules common to TOSA backends only, for use with CheckSupportRule()

/// Rule that passes when the attribute type reported by a TOSA operator
/// equals at least one entry of a caller-supplied collection.
struct TosaOperatorAttributeOfAny : public Rule
{
    /// @param op  Operator whose GetAttributeType() result is tested. It is only
    ///            dereferenced while iterating, i.e. never for an empty collection.
    /// @param c   Collection of acceptable values; each element is converted to
    ///            Attribute before being compared.
    template<typename Container>
    explicit TosaOperatorAttributeOfAny(TosaSerializationOperator* op, const Container& c)
    {
        bool matched = false;
        for (Attribute candidate : c)
        {
            if (candidate == op->GetAttributeType())
            {
                // One hit is enough; stop scanning the remaining candidates.
                matched = true;
                break;
            }
        }
        // m_Res is not declared here (presumably inherited from Rule); it holds the outcome.
        m_Res = matched;
    }
};

/// Rule that passes when the data type reported by a TOSA tensor equals
/// at least one entry of a caller-supplied collection.
struct TosaTypeAnyOf : public Rule
{
    /// @param tensor  Tensor whose GetDtype() result is tested. It is only
    ///                dereferenced while iterating, i.e. never for an empty collection.
    /// @param c       Collection of acceptable values; each element is converted to
    ///                DType before being compared.
    template<typename Container>
    TosaTypeAnyOf(TosaSerializationTensor* tensor, const Container& c)
    {
        bool matched = false;
        for (DType candidate : c)
        {
            if (candidate == tensor->GetDtype())
            {
                // One hit is enough; stop scanning the remaining candidates.
                matched = true;
                break;
            }
        }
        // m_Res is not declared here (presumably inherited from Rule); it holds the outcome.
        m_Res = matched;
    }
};

/// Rule that passes when a TOSA tensor's shape has at least one dimension and
/// no more than MaxNumOfTensorDimensions dimensions.
struct TosaTensorNumDimensionsWithinBounds : public Rule
{
    /// @param tensor  Tensor whose GetShape() result is inspected; must not be null.
    explicit TosaTensorNumDimensionsWithinBounds(TosaSerializationTensor* tensor)
    {
        const auto& shape = tensor->GetShape();

        // Both bounds must hold. Joining them with '||' made the rule a tautology:
        // any shape exceeding the maximum is necessarily non-empty, so the check
        // could never fail.
        m_Res = (shape.size() <= MaxNumOfTensorDimensions) && (!shape.empty());
    }
};