diff options
Diffstat (limited to 'ethosu/vela/tflite_reader.py')
-rw-r--r-- | ethosu/vela/tflite_reader.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index b3b0720a..c190f7e2 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -106,7 +106,8 @@ class TFLiteSubgraph: np_shape = tens_data.ShapeAsNumpy() shape = list(np_shape) if type(np_shape) is np.ndarray else [] name = decode_str(tens_data.Name()) - dtype = datatype_map[tens_data.Type()] + tens_dtype = tens_data.Type() + dtype = datatype_map[tens_dtype] tens = Tensor(shape, dtype, name) quant = tens_data.Quantization() @@ -129,8 +130,8 @@ class TFLiteSubgraph: tens.values = None buf = self.graph.buffers[tens_data.Buffer()] - if buf is not None: - tens.values = np.array(buf.view(datatype_map_numpy[tens_data.Type()]).reshape(shape)) + if buf is not None and dtype != DataType.string: + tens.values = np.array(buf.view(datatype_map_numpy[tens_dtype]).reshape(shape)) if tens.quantization is not None: tens.quant_values = tens.values tens.values = tens.quantization.dequantize(tens.quant_values) |