aboutsummaryrefslogtreecommitdiff
path: root/pseudocode/library/arithmetic_helpers.tosac
blob: 18d5c64d7f1fa6242ec157e820aa30391b93a38a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
//
// This confidential and proprietary software may be used only as
// authorised by a licensing agreement from ARM Limited
// (C) COPYRIGHT 2020-2024 ARM Limited
// ALL RIGHTS RESERVED
// The entire notice above must be reproduced on all authorised
// copies and copies may only be made to the extent permitted
// by a licensing agreement from ARM Limited.

// Signed (or floating-point) addition.
// Floating-point operands are added directly. Integer operands are first
// sign-extended to 64 bits so that overflow of in_t can be detected; the sum
// must lie within the signed range of in_t (REQUIRE) before it is narrowed.
in_t apply_add_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        int64_t sum = sign_extend<int64_t>(a) + sign_extend<int64_t>(b);
        REQUIRE(sum >= minimum_s<in_t>() && sum <= maximum_s<in_t>());
        return static_cast<in_t>(sum);
    }
    return a + b;
}

// Unsigned (or floating-point) addition.
// Floating-point operands are added directly. Integer operands are
// zero-extended to 64 bits and added; the sum must lie within the unsigned
// range of in_t (REQUIRE), after which it is truncated back to in_t.
in_t apply_add_u<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        uint64_t sum = zero_extend<uint64_t>(a) + zero_extend<uint64_t>(b);
        REQUIRE(sum >= minimum_u<in_t>() && sum <= maximum_u<in_t>());
        return truncate<in_t>(sum);
    }
    return a + b;
}

// Arithmetic right shift of a by b bit positions.
// Both operands are sign-extended to 32 bits and the shift is done on the
// signed 32-bit value before converting back to in_t.
// NOTE(review): the shift amount is not range-checked here; callers presumably
// guarantee 0 <= b <= 31 — confirm at each call site.
in_t apply_arith_rshift<in_t>(in_t a, in_t b) {
    int32_t value = sign_extend<int32_t>(a);
    int32_t amount = sign_extend<int32_t>(b);
    return static_cast<in_t>(value >> amount);
}

// Signed integer division a / b.
// Operands are sign-extended to 64 bits before dividing; integer division of
// int64_t values truncates toward zero (C++ semantics).
// Requirements:
//   - the divisor must be non-zero;
//   - the quotient must fit in the signed range of in_t. This rejects the
//     single overflowing case, minimum_s<in_t>() / -1.
in_t apply_intdiv_s<in_t>(in_t a, in_t b) {
    // Division by zero has no defined result, so reject it before dividing.
    REQUIRE(sign_extend<int64_t>(b) != 0);
    int64_t c = sign_extend<int64_t>(a) / sign_extend<int64_t>(b);
    REQUIRE(c >= minimum_s<in_t>() && c <= maximum_s<in_t>());
    return static_cast<in_t>(c);
}

// Returns input rounded up (toward +infinity) to the nearest integer value.
// Declared here without a body.
in_t apply_ceil<in_t>(in_t input);

// Returns e (Euler's number) raised to the power input.
// Declared here without a body.
in_t apply_exp<in_t>(in_t input);

// Returns input rounded down (toward -infinity) to the nearest integer value.
// Declared here without a body.
in_t apply_floor<in_t>(in_t input);

// Returns the natural logarithm of input.
// Intended only for input > 0: apply_log handles zero and negative inputs
// itself before delegating here.
// Declared here without a body.
in_t apply_log_positive_input<in_t>(in_t input);

// Natural logarithm, with special inputs handled before delegation:
//   NaN       -> NaN
//   +0 or -0  -> -infinity   (both compare equal to 0)
//   negative  -> NaN
// Any remaining (strictly positive) input is passed to apply_log_positive_input.
in_t apply_log<in_t>(in_t input) {
    if (isNaN(input)) {
        // NaN compares false against 0, so without this check it would reach
        // apply_log_positive_input, which is only meant for positive input.
        return NaN;
    }
    if (input == 0) {
        return -INFINITY;
    }
    if (input < 0) {
        return NaN;
    }
    return apply_log_positive_input<in_t>(input);
}

// Logical (zero-filling) right shift of a by b bit positions.
// Both operands are zero-extended to 32 bits, so the shift happens in 32-bit
// unsigned arithmetic; the intermediate uses that same 32-bit unsigned type
// rather than a wider one.
// NOTE(review): the shift amount is not range-checked here; callers presumably
// guarantee 0 <= b <= 31 — confirm at each call site.
in_t apply_logical_rshift<in_t>(in_t a, in_t b) {
    uint32_t c = zero_extend<uint32_t>(a) >> zero_extend<uint32_t>(b);
    return static_cast<in_t>(c);
}

// Maximum of a and b under signed (or floating-point) ordering.
// Integer: operands are compared after sign-extension to 64 bits; ties return a.
// Floating point: if either operand is NaN the result is NaN; otherwise the
// larger operand is returned, with ties returning a.
in_t apply_max_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        return (sign_extend<int64_t>(a) >= sign_extend<int64_t>(b)) ? a : b;
    }
    if (isNaN(a) || isNaN(b)) {
        return NaN;
    }
    return (a >= b) ? a : b;
}

// Maximum of a and b under unsigned ordering; ties return a.
// Both operands are zero-extended to the same unsigned 64-bit type, so the
// comparison never mixes signed and unsigned operands.
in_t apply_max_u<in_t>(in_t a, in_t b) {
    if (zero_extend<uint64_t>(a) >= zero_extend<uint64_t>(b)) return a; else return b;
}

// Minimum of a and b under signed (or floating-point) ordering.
// Integer: operands are compared after sign-extension to 64 bits; ties return b.
// Floating point: if either operand is NaN the result is NaN; otherwise the
// smaller operand is returned, with ties returning b.
in_t apply_min_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        return (sign_extend<int64_t>(a) < sign_extend<int64_t>(b)) ? a : b;
    }
    if (isNaN(a) || isNaN(b)) {
        return NaN;
    }
    return (a < b) ? a : b;
}

// Minimum of a and b under unsigned ordering; ties return b.
// Both operands are zero-extended to an unsigned 64-bit type, matching the
// unsigned arithmetic helpers (apply_add_u, apply_sub_u).
in_t apply_min_u<in_t>(in_t a, in_t b) {
    if (zero_extend<uint64_t>(a) < zero_extend<uint64_t>(b)) return a; else return b;
}

// Clamp value to the closed range [min_val, max_val] using signed (or
// floating-point) ordering. The bounds must satisfy min_val <= max_val
// (REQUIRE); integer bounds are compared after sign-extension to 64 bits.
// A NaN value propagates to the result through apply_max_s / apply_min_s.
in_t apply_clip_s<in_t>(in_t value, in_t min_val, in_t max_val) {
    if (!is_floating_point<in_t>()) {
        REQUIRE(sign_extend<int64_t>(min_val) <= sign_extend<int64_t>(max_val));
    } else {
        REQUIRE(min_val <= max_val);
    }
    in_t raised = apply_max_s<in_t>(value, min_val);
    return apply_min_s<in_t>(raised, max_val);
}

// Clamp value to the closed range [min_val, max_val] using unsigned ordering.
// The bounds must satisfy min_val <= max_val when compared as unsigned values;
// they are zero-extended to an unsigned 64-bit type for that check.
in_t apply_clip_u<in_t>(in_t value, in_t min_val, in_t max_val) {
    REQUIRE(zero_extend<uint64_t>(min_val) <= zero_extend<uint64_t>(max_val));
    value = apply_max_u<in_t>(value, min_val);
    value = apply_min_u<in_t>(value, max_val);
    return value;
}

// Signed (or floating-point) multiplication.
// Floating-point operands are multiplied directly. Integer operands are
// sign-extended to 64 bits, multiplied, and the product is converted back to
// in_t with static_cast. Unlike apply_add_s / apply_sub_s there is no range
// REQUIRE, so an out-of-range integer product is not rejected here.
// NOTE(review): presumably the narrowing cast is meant to keep the low bits of
// the product — confirm against the operators that call this helper.
in_t apply_mul_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        int64_t product = sign_extend<int64_t>(a) * sign_extend<int64_t>(b);
        return static_cast<in_t>(product);
    }
    return a * b;
}

// Returns a raised to the power b; `**` denotes exponentiation in this
// pseudocode.
// NOTE(review): results for special inputs (e.g. a < 0 with non-integer b, or
// a == 0 with b <= 0) are not specified here — confirm against the POW
// operator definition.
in_t apply_pow<in_t>(in_t a, in_t b) {
    in_t result = a ** b;
    return result;
}

// Returns the square root of input.
// Declared here without a body.
in_t apply_sqrt<in_t>(in_t input);

// Signed (or floating-point) subtraction a - b.
// Floating-point operands are subtracted directly. Integer operands are
// sign-extended to 64 bits so that overflow of in_t can be detected; the
// difference must lie within the signed range of in_t (REQUIRE) before it is
// narrowed.
in_t apply_sub_s<in_t>(in_t a, in_t b) {
    if (!is_floating_point<in_t>()) {
        int64_t diff = sign_extend<int64_t>(a) - sign_extend<int64_t>(b);
        REQUIRE(diff >= minimum_s<in_t>() && diff <= maximum_s<in_t>());
        return static_cast<in_t>(diff);
    }
    return a - b;
}

// Unsigned subtraction a - b.
// The difference is computed in unsigned 64-bit arithmetic. If b > a it wraps
// around to a very large value, which (for in_t narrower than 64 bits) fails
// the upper-bound check, so the REQUIRE effectively demands a >= b. The result
// is then truncated back to in_t.
in_t apply_sub_u<in_t>(in_t a, in_t b) {
    uint64_t lhs = zero_extend<uint64_t>(a);
    uint64_t rhs = zero_extend<uint64_t>(b);
    uint64_t diff = lhs - rhs;
    REQUIRE(diff >= minimum_u<in_t>() && diff <= maximum_u<in_t>());
    return truncate<in_t>(diff);
}

// Returns the number of leading zero bits in the 32-bit two's-complement
// representation of a: 32 when a is 0, and 0 when a is negative (sign bit set).
int32_t count_leading_zeros(int32_t a) {
    int32_t acc = 32;
    if (a != 0) {
        // Start the mask at the most significant bit. The 1 is unsigned so that
        // shifting it into bit 31 does not overflow a signed int.
        uint32_t mask = static_cast<uint32_t>(1) << (32 - 1); // width of int32_t - 1
        uint32_t bits = static_cast<uint32_t>(a);
        acc = 0;
        // Terminates because a != 0 guarantees at least one set bit.
        while ((mask & bits) == 0) {
            mask = mask >> 1;
            acc = acc + 1;
        }
    }
    return acc;
}