aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_reader.py
diff options
context:
space:
mode:
authorDiego Russo <diego.russo@arm.com>2020-04-23 18:14:37 +0100
committerTim Hall <tim.hall@arm.com>2020-06-18 17:53:52 +0100
commitd0eee26bc17ecd237c1b1e86cda78f5f310af391 (patch)
tree8b4b78d1cc0f01d3686be5459353bdf1b4ea73e8 /ethosu/vela/tflite_reader.py
parente4e58e15d9916fdcef33f5c43c2f60ef124da6a6 (diff)
downloadethos-u-vela-d0eee26bc17ecd237c1b1e86cda78f5f310af391.tar.gz
Add test for len1_array_to_scalar function
Moved len1_array_to_scalar from a nested function to a staticmethod of TFLiteSubgraph. Change-Id: I182f0b70f03070855c1a4478d26644892c1ebb15 Signed-off-by: Diego Russo <diego.russo@arm.com>
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r--ethosu/vela/tflite_reader.py33
1 files changed, 16 insertions, 17 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py
index 4f9bd7d0..7e158aac 100644
--- a/ethosu/vela/tflite_reader.py
+++ b/ethosu/vela/tflite_reader.py
@@ -91,28 +91,15 @@ class TFLiteSubgraph:
shape = list(np_shape) if type(np_shape) is np.ndarray else []
name = decode_str(tens_data.Name())
dtype = datatype_map[tens_data.Type()]
-
tens = Tensor(shape, dtype, name)
-
quant = tens_data.Quantization()
- def len1_array_to_scalar(arr):
- # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in
- # the input buffer. This is represented in Vela by using None.
- # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays
- # are converted to scalars
- if isinstance(arr, int) and arr == 0:
- return None
- if len(arr) == 1:
- return arr[0]
- return arr
-
tens.quantization = QuantizationParameters()
if quant is not None:
- tens.quantization.min = len1_array_to_scalar(quant.MinAsNumpy())
- tens.quantization.max = len1_array_to_scalar(quant.MaxAsNumpy())
- tens.quantization.scale_f32 = len1_array_to_scalar(quant.ScaleAsNumpy())
- tens.quantization.zero_point = len1_array_to_scalar(quant.ZeroPointAsNumpy())
+ tens.quantization.min = self.len1_array_to_scalar(quant.MinAsNumpy())
+ tens.quantization.max = self.len1_array_to_scalar(quant.MaxAsNumpy())
+ tens.quantization.scale_f32 = self.len1_array_to_scalar(quant.ScaleAsNumpy())
+ tens.quantization.zero_point = self.len1_array_to_scalar(quant.ZeroPointAsNumpy())
if dtype == DataType.uint8:
tens.quantization.quant_min = 0
@@ -199,6 +186,18 @@ class TFLiteSubgraph:
op.outputs[0] = intermediate_tens
act_op.inputs = [intermediate_tens]
+ @staticmethod
+ def len1_array_to_scalar(arr):
+ # The following flatbuffer quantisation fields all return a scalar value of 0 if they are not definied in
+ # the input buffer. This is represented in Vela by using None.
+ # Otherwise, the fields returned are a single or multi-element array. In which case, single element arrays
+ # are converted to scalars
+ if isinstance(arr, int) and arr == 0:
+ return None
+ if len(arr) == 1:
+ return arr[0]
+ return arr
+
class TFLiteGraph:
def __init__(