aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKristofer Jonsson <kristofer.jonsson@arm.com>2020-11-25 10:58:38 +0100
committerKristofer Jonsson <kristofer.jonsson@arm.com>2020-11-26 08:20:23 +0100
commit5e632c242e4e6445c067f74ccc3bf8f14da7f60b (patch)
tree4ebf2f392468d9aeaff58a4bbcb5ff2a1c62a711
parent716546ac63b42a3d6faa70d4a18418c0d020df3d (diff)
downloadethos-u-linux-driver-stack-5e632c242e4e6445c067f74ccc3bf8f14da7f60b.tar.gz
16 bit network support20.11
Change-Id: I65cd027dea115d2f50f302ebe997d6d2525e0d7e
-rw-r--r--driver_library/src/ethosu.cpp13
1 files changed, 13 insertions, 0 deletions
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp
index 6b30827..2c498a8 100644
--- a/driver_library/src/ethosu.cpp
+++ b/driver_library/src/ethosu.cpp
@@ -57,12 +57,25 @@ size_t getShapeSize(const flatbuffers::Vector<int32_t> *shape) {
return size;
}
+size_t getTensorTypeSize(const enum tflite::TensorType type) {
+ switch (type) {
+ case tflite::TensorType::TensorType_UINT8:
+ case tflite::TensorType::TensorType_INT8:
+ return 1;
+ case tflite::TensorType::TensorType_INT16:
+ return 2;
+ default:
+ throw EthosU::Exception("Unsupported tensor type");
+ }
+}
+
vector<size_t> getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap) {
vector<size_t> dims;
for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) {
auto tensor = subgraph->tensors()->Get(*index);
size_t size = getShapeSize(tensor->shape());
+ size *= getTensorTypeSize(tensor->type());
if (size > 0) {
dims.push_back(size);