aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJacob Bohlin <jacob.bohlin@arm.com>2020-08-13 15:32:45 +0200
committerpatrik.gustavsson <patrik.gustavsson@arm.com>2020-08-18 13:51:46 +0000
commitf767b937c12935be3cb1f9ee406fbb796176a40c (patch)
tree3d27a0d63b582f2a1c4d23866a2823ce40d6f1ff
parent0b8268a0dac80aa22133ca83ed6912d3b565439a (diff)
downloadethos-u-vela-f767b937c12935be3cb1f9ee406fbb796176a40c.tar.gz
MLBEDSW-2732: Added complex64 to datatypes
Added complex64 datatype to allow pass through without crashing. Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com> Change-Id: I8beeceafb32182d4877a9880d21d51ba21033030
-rw-r--r--ethosu/vela/data_type.py3
-rw-r--r--ethosu/vela/tflite_mapping.py3
2 files changed, 5 insertions, 1 deletions
diff --git a/ethosu/vela/data_type.py b/ethosu/vela/data_type.py
index bb4c5589..4d05fef5 100644
--- a/ethosu/vela/data_type.py
+++ b/ethosu/vela/data_type.py
@@ -35,6 +35,7 @@ class BaseType(enum.Flag):
String = 128
Resource = 256
Variant = 512
+ Complex = 1024
class DataType:
@@ -78,6 +79,7 @@ class DataType:
BaseType.String: ("string", False),
BaseType.Resource: ("resource", False),
BaseType.Variant: ("variant", False),
+ BaseType.Complex: ("complex%s", True),
}
@@ -112,3 +114,4 @@ DataType.string = DataType(BaseType.String, 64)
DataType.bool = DataType(BaseType.Bool, 8)
DataType.resource = DataType(BaseType.Resource, 8)
DataType.variant = DataType(BaseType.Variant, 8)
+DataType.complex64 = DataType(BaseType.Complex, 64)
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py
index 79521680..55351cb9 100644
--- a/ethosu/vela/tflite_mapping.py
+++ b/ethosu/vela/tflite_mapping.py
@@ -143,7 +143,7 @@ datatype_map = {
TensorType.FLOAT32: DataType.float32,
TensorType.STRING: DataType.string,
TensorType.BOOL: DataType.bool,
- # no TensorType.COMPLEX64 for now
+ TensorType.COMPLEX64: DataType.complex64,
}
datatype_inv_map = inverse_map(datatype_map)
@@ -163,6 +163,7 @@ datatype_map_numpy = {
TensorType.FLOAT16: np.float16,
TensorType.FLOAT32: np.float32,
TensorType.BOOL: np.bool,
+ TensorType.COMPLEX64: np.complex64,
}