aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/softmax.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/softmax.py')
-rw-r--r--ethosu/vela/softmax.py49
1 files changed, 42 insertions, 7 deletions
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 711c1e04..9565bc5c 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -300,11 +300,20 @@ class SoftMax:
# PASS 5 - Sub
headroom_offset = create_const_tensor(
- "headroom_offset_const", [1, 1, 1, 1], DataType.int32, [12 + 31 - 8], np.int32, quantization=no_scale_quant,
+ "headroom_offset_const",
+ [1, 1, 1, 1],
+ DataType.int32,
+ [12 + 31 - 8],
+ np.int32,
+ quantization=no_scale_quant,
)
right_shift = add_op_get_ofm(
create_sub(
- f"{self.op.name}_sub{pass_number}", headroom_offset, headroom_plus_one, no_scale_quant, activation,
+ f"{self.op.name}_sub{pass_number}",
+ headroom_offset,
+ headroom_plus_one,
+ no_scale_quant,
+ activation,
)
)
@@ -329,7 +338,13 @@ class SoftMax:
# PASS 9 - SHL
shifted_sum_minus_one = add_op_get_ofm(
- create_shl(f"{self.op.name}_shl{pass_number}", shifted_sum_minus_one, one, no_scale_quant, activation,)
+ create_shl(
+ f"{self.op.name}_shl{pass_number}",
+ shifted_sum_minus_one,
+ one,
+ no_scale_quant,
+ activation,
+ )
)
# PASS 10 - Add
@@ -353,7 +368,11 @@ class SoftMax:
)
rescaled = add_op_get_ofm(
create_mul(
- f"{self.op.name}_mul{pass_number}", half_denominator, neg_32_over_17, two_scale_quant, activation2,
+ f"{self.op.name}_mul{pass_number}",
+ half_denominator,
+ neg_32_over_17,
+ two_scale_quant,
+ activation2,
)
)
@@ -362,7 +381,13 @@ class SoftMax:
"48_over_17_const", [1, 1, 1, 1], DataType.int32, [1515870810], np.int32, quantization=no_scale_quant
)
rescale_w_offset = add_op_get_ofm(
- create_add(f"{self.op.name}_add{pass_number}", rescaled, const_48_over_17, one_scale_quant, activation,)
+ create_add(
+ f"{self.op.name}_add{pass_number}",
+ rescaled,
+ const_48_over_17,
+ one_scale_quant,
+ activation,
+ )
)
# PASS 13 - 27
@@ -376,12 +401,22 @@ class SoftMax:
for _ in range(3):
# PASS 13, 18, 23 - MUL
half_denominator_times_x = add_op_get_ofm(
- create_mul(f"{self.op.name}_mul{pass_number}", nr_x, half_denominator, two_scale_quant, activation2,)
+ create_mul(
+ f"{self.op.name}_mul{pass_number}",
+ nr_x,
+ half_denominator,
+ two_scale_quant,
+ activation2,
+ )
)
# PASS 14, 19, 24 - SUB
one_minus_half_denominator_times_x = add_op_get_ofm(
create_sub(
- f"{self.op.name}_sub{pass_number}", F2_one, half_denominator_times_x, one_scale_quant, activation,
+ f"{self.op.name}_sub{pass_number}",
+ F2_one,
+ half_denominator_times_x,
+ one_scale_quant,
+ activation,
)
)
# PASS 15, 20, 25 - MUL