diff options
author | Diego Russo <diego.russo@arm.com> | 2020-04-23 18:14:37 +0100 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2020-06-18 17:53:52 +0100 |
commit | d0eee26bc17ecd237c1b1e86cda78f5f310af391 (patch) | |
tree | 8b4b78d1cc0f01d3686be5459353bdf1b4ea73e8 /ethosu/vela/tflite_reader.py | |
parent | e4e58e15d9916fdcef33f5c43c2f60ef124da6a6 (diff) | |
download | ethos-u-vela-d0eee26bc17ecd237c1b1e86cda78f5f310af391.tar.gz |
Add test for len1_array_to_scalar function
Moved len1_array_to_scalar from a nested function to a staticmethod
of TFLiteSubgraph.
Change-Id: I182f0b70f03070855c1a4478d26644892c1ebb15
Signed-off-by: Diego Russo <diego.russo@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 4f9bd7d0..7e158aac 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -91,28 +91,15 @@ class TFLiteSubgraph: shape = list(np_shape) if type(np_shape) is np.ndarray else [] name = decode_str(tens_data.Name()) dtype = datatype_map[tens_data.Type()] - tens = Tensor(shape, dtype, name) - quant = tens_data.Quantization() - def len1_array_to_scalar(arr): - # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in - # the input buffer. This is represented in Vela by using None. - # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays - # are converted to scalars - if isinstance(arr, int) and arr == 0: - return None - if len(arr) == 1: - return arr[0] - return arr - tens.quantization = QuantizationParameters() if quant is not None: - tens.quantization.min = len1_array_to_scalar(quant.MinAsNumpy()) - tens.quantization.max = len1_array_to_scalar(quant.MaxAsNumpy()) - tens.quantization.scale_f32 = len1_array_to_scalar(quant.ScaleAsNumpy()) - tens.quantization.zero_point = len1_array_to_scalar(quant.ZeroPointAsNumpy()) + tens.quantization.min = self.len1_array_to_scalar(quant.MinAsNumpy()) + tens.quantization.max = self.len1_array_to_scalar(quant.MaxAsNumpy()) + tens.quantization.scale_f32 = self.len1_array_to_scalar(quant.ScaleAsNumpy()) + tens.quantization.zero_point = self.len1_array_to_scalar(quant.ZeroPointAsNumpy()) if dtype == DataType.uint8: tens.quantization.quant_min = 0 @@ -199,6 +186,18 @@ class TFLiteSubgraph: op.outputs[0] = intermediate_tens act_op.inputs = [intermediate_tens] + @staticmethod + def len1_array_to_scalar(arr): + # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in + # the input buffer. This is represented in Vela by using None. + # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays + # are converted to scalars + if isinstance(arr, int) and arr == 0: + return None + if len(arr) == 1: + return arr[0] + return arr + class TFLiteGraph: def __init__( |