From 835d8e10f33f411664cebe65d3f6a872f6cc849a Mon Sep 17 00:00:00 2001
From: Fredrik Svedberg
Date: Fri, 4 Sep 2020 09:46:17 +0200
Subject: [MLBEDSW-2928] Add batching to softmax

Added batching to softmax by reshaping the input.

Signed-off-by: Fredrik Svedberg
Change-Id: I0b516f9bf2410fb86372b229beba4a7280c498cc
---
 ethosu/vela/softmax.py             | 14 +++++++++-----
 ethosu/vela/supported_operators.py |  4 ++--
 ethosu/vela/tensor.py              |  2 +-
 3 files changed, 12 insertions(+), 8 deletions(-)

diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py
index 2834f8c2..9e8b846d 100644
--- a/ethosu/vela/softmax.py
+++ b/ethosu/vela/softmax.py
@@ -201,6 +201,14 @@ class SoftMax:
         ifm = self.op.inputs[0]
         ofm = self.op.outputs[0]
 
+        # Reshape ifm/ofm (if needed)
+        full_shape = ifm.get_full_shape()
+        if full_shape[0] > 1:
+            full_shape[1] *= full_shape[0]
+            full_shape[0] = 1
+        ifm = create_reshape_tensor(ifm, full_shape)
+        ofm = create_reshape_tensor(ofm, full_shape, False)
+
         if ifm.dtype in (DataType.uint8, DataType.int8) and ofm.dtype == ifm.dtype:
             return self.get_graph_8bit(ifm, ofm)
         elif ifm.dtype == DataType.int16 and ofm.dtype == DataType.int16:
@@ -211,8 +219,6 @@ class SoftMax:
 
     def get_graph_8bit(self, ifm, ofm):
         exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32)
-        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
-        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
         no_scale_quant = ifm.quantization.clone()
         no_scale_quant.scale_f32 = None
         no_scale_quant.zero_point = 0
@@ -242,7 +248,7 @@ class SoftMax:
         # PASS 1 - Sub+LUT(exp)
         sub_op = Operation("SubAct", self.op.name + "_sub1")
         sub_op.add_input_tensor(ifm)
-        sub_op.add_input_tensor(ifm_max)
+        sub_op.add_input_tensor(create_reshape_tensor(ifm_max, [1, ifm.shape[1], ifm.shape[2], 1]))
         sub_op.set_activation_lut(
             create_const_tensor(
                 sub_op.name + "_lut", [1, 1, 1, 256], DataType.int32, exp_lut, np.int32, TensorPurpose.LUT
@@ -463,8 +469,6 @@ class SoftMax:
         return shr30_op
 
     def get_graph_int16(self, ifm, ofm):
-        ifm = create_reshape_tensor(ifm, ifm.get_full_shape())
-        ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False)
         no_scale_quant = ifm.quantization.clone()
         no_scale_quant.scale_f32 = None
 
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index e0ee6163..86cc3c07 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -420,8 +420,8 @@ class SupportedOperators:
         if ifm_tensor.dtype not in (DataType.uint8, DataType.int8, DataType.int16):
             return False
 
-        # check batch size
-        if len(ifm_tensor.shape) in (2, 4) and ifm_tensor.shape[0] != 1:
+        # check shape
+        if len(ifm_tensor.shape) > 4 or ifm_tensor.shape != ofm_tensor.shape:
             return False
 
         return True
diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py
index 83dc61a3..49521e7a 100644
--- a/ethosu/vela/tensor.py
+++ b/ethosu/vela/tensor.py
@@ -751,7 +751,7 @@ class Tensor:
         elif d == 2:
             return [self.shape[0], 1, 1, self.shape[1]]
         else:
-            return self.shape
+            return self.shape.copy()
 
     def __str__(self):
         return "<nng.Tensor '%s' shape=%s dtype=%s>" % (self.name, self.shape, self.dtype)
--
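Note: folding the batch dimension into height is safe here because softmax
only reduces along the channel axis, so viewing [N, H, W, C] as [1, N*H, W, C]
cannot mix values between batches. A minimal standalone NumPy sketch of the
idea (fold_batch and softmax below are hypothetical illustration helpers,
not Vela APIs):

    import numpy as np

    def fold_batch(shape):
        # Same arithmetic as the patch: merge N into H when N > 1.
        full_shape = list(shape)
        if full_shape[0] > 1:
            full_shape[1] *= full_shape[0]
            full_shape[0] = 1
        return full_shape

    def softmax(x):
        # Numerically stable softmax over the channel (last) axis.
        e = np.exp(x - x.max(axis=-1, keepdims=True))
        return e / e.sum(axis=-1, keepdims=True)

    ifm = np.random.rand(4, 2, 3, 8).astype(np.float32)  # batched input, N=4
    folded = ifm.reshape(fold_batch(ifm.shape))          # shape [1, 8, 3, 8]
    # Per-channel results are identical before and after folding.
    assert np.allclose(softmax(folded).reshape(ifm.shape), softmax(ifm))

Folding into height rather than width keeps the reduced axis untouched,
which is all the lowered softmax graph depends on.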