author    Fredrik Svedberg <fredrik.svedberg@arm.com>    2020-09-04 09:46:17 +0200
committer patrik.gustavsson <patrik.gustavsson@arm.com>  2020-09-07 06:23:54 +0000
commit    835d8e10f33f411664cebe65d3f6a872f6cc849a
tree      60c0cfd477cac7b604ade275ec82f54cbb14e9e8 /ethosu
parent    e5cf95b8c3de4e1e4cbc7046cafd4d84c7492596
download  ethos-u-vela-835d8e10f33f411664cebe65d3f6a872f6cc849a.tar.gz
[MLBEDSW-2928] Add batching to softmax
Added batching to softmax by reshaping the input.

Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Change-Id: I0b516f9bf2410fb86372b229beba4a7280c498cc
Diffstat (limited to 'ethosu')
-rw-r--r--  ethosu/vela/softmax.py             | 14
-rw-r--r--  ethosu/vela/supported_operators.py |  4
-rw-r--r--  ethosu/vela/tensor.py              |  2
3 files changed, 12 insertions(+), 8 deletions(-)
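
Note: the core of the change is to fold the batch dimension into the height
dimension before the softmax graph is built, so the rest of the lowering only
ever sees N=1. A minimal NumPy sketch of the equivalent reshape (illustrative
only; the patch itself goes through create_reshape_tensor on vela Tensors,
not raw arrays):

    import numpy as np

    def fold_batch(ifm):
        # [N, H, W, C] -> [1, N*H, W, C]; softmax reduces over C,
        # so folding N into H leaves the result unchanged.
        n, h, w, c = ifm.shape
        if n > 1:
            ifm = ifm.reshape(1, n * h, w, c)
        return ifm

    x = np.random.rand(4, 2, 3, 8).astype(np.float32)
    assert fold_batch(x).shape == (1, 8, 3, 8)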
diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 2834f8c2..9e8b846d 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -201,6 +201,14 @@ class SoftMax:
ifm = self.op.inputs[0]
ofm = self.op.outputs[0]
+ # Reshape ifm/ofm (if needed)
+ full_shape = ifm.get_full_shape()
+ if full_shape[0] > 1:
+ full_shape[1] *= full_shape[0]
+ full_shape[0] = 1
+ ifm = create_reshape_tensor(ifm, full_shape)
+ ofm = create_reshape_tensor(ofm, full_shape, False)
+
if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
return self.get_graph_8bit(ifm, ofm)
elif ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
@@ -211,8 +219,6 @@ class SoftMax:
def get_graph_8bit(self, ifm, ofm):
exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
- ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
- ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
no_scale_quant = ifm.quantization.clone()
no_scale_quant.scale_f32 = None
no_scale_quant.zero_point = 0
@@ -242,7 +248,7 @@ class SoftMax:
# PASS 1 - Sub+LUT(exp)
sub_op = Operation("SubAct", self.op.name + "_sub1")
sub_op.add_input_tensor(ifm)
- sub_op.add_input_tensor(ifm_max)
+ sub_op.add_input_tensor(create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1]))
sub_op.set_activation_lut(
create_const_tensor(
sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
@@ -463,8 +469,6 @@ class SoftMax:
return shr30_op
def get_graph_int16(self, ifm, ofm):
- ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
- ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
no_scale_quant = ifm.quantization.clone()
no_scale_quant.scale_f32 = None
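
Note: with the batch folded into H, the maximum computed ahead of PASS 1 has
one value per (H, W) position, so ifm_max is reshaped to [1, H, W, 1] to
broadcast against the [1, H, W, C] input of the SubAct op. A rough NumPy
equivalent of the Sub+LUT(exp) stage (hypothetical helper name; the real pass
emits a SubAct op with an int32 activation LUT rather than calling exp
directly):

    import numpy as np

    def sub_max_exp(ifm, beta=1.0):
        # ifm: [1, H, W, C]; keepdims gives the max shape [1, H, W, 1],
        # mirroring the reshape of ifm_max in the hunk above.
        ifm_max = ifm.max(axis=-1, keepdims=True)
        return np.exp(beta * (ifm - ifm_max))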
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index e0ee6163..86cc3c07 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -420,8 +420,8 @@ class SupportedOperators:
if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
return False
- # check batch size
- if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1:
+ # check shape
+ if len(ifm_tensor.shape) > 4 or ifm_tensor.shape != ofm_tensor.shape:
return False
return True
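
Note: the old constraint rejected any softmax with a batch size other than 1;
since batching is now handled by the reshape in softmax.py, the check is
relaxed to only require rank <= 4 and matching ifm/ofm shapes. A standalone
sketch of the new predicate (plain lists instead of vela Tensors):

    def softmax_shape_ok(ifm_shape, ofm_shape):
        # Supported for any batch size, as long as the tensor is at
        # most 4-D and the input and output shapes agree.
        return len(ifm_shape) <= 4 and ifm_shape == ofm_shape

    assert softmax_shape_ok([4, 8, 8, 16], [4, 8, 8, 16])
    assert not softmax_shape_ok([1, 2, 3, 4, 5], [1, 2, 3, 4, 5])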
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 83dc61a3..49521e7a 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -751,7 +751,7 @@ class Tensor:
elif d == 2:
return [self.shape[0], 1, 1, self.shape[1]]
else:
- return self.shape
+ return self.shape.copy()
def __str__(self):
return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)
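
Note: the tensor.py change is what makes the in-place edit of full_shape in
softmax.py safe. For 4-D tensors, get_full_shape used to return self.shape by
reference, so the new code (full_shape[1] *= full_shape[0]; full_shape[0] = 1)
would have silently corrupted the tensor's own shape. A minimal illustration
of the aliasing bug, using plain lists:

    shape = [2, 4, 4, 8]          # stands in for Tensor.shape

    full = shape                  # buggy: returned by reference
    full[1] *= full[0]
    full[0] = 1
    assert shape == [1, 8, 4, 8]  # the tensor's shape was mutated too!

    shape = [2, 4, 4, 8]
    full = shape.copy()           # fixed: caller may mutate freely
    full[1] *= full[0]
    full[0] = 1
    assert shape == [2, 4, 4, 8]  # original shape preserved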