Diffstat (limited to 'python/pyarmnn/examples/object_detection/style_transfer.py')
-rw-r--r-- python/pyarmnn/examples/object_detection/style_transfer.py | 138
1 file changed, 138 insertions(+), 0 deletions(-)
diff --git a/python/pyarmnn/examples/object_detection/style_transfer.py b/python/pyarmnn/examples/object_detection/style_transfer.py
new file mode 100644
index 0000000000..eda618e31a
--- /dev/null
+++ b/python/pyarmnn/examples/object_detection/style_transfer.py
@@ -0,0 +1,138 @@
+# Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+# SPDX-License-Identifier: MIT
+
+import numpy as np
+import urllib.request
+import cv2
+import network_executor_tflite
+import cv_utils
+
+
+def style_transfer_postprocess(preprocessed_frame: np.ndarray, image_shape: tuple):
+    """
+    Resizes the output frame of the style transfer network and converts the color
+    format back to the original BGR configuration.
+
+    Args:
+        preprocessed_frame: A preprocessed frame after style transfer.
+        image_shape: The shape of the original frame before preprocessing.
+
+    Returns:
+        The postprocessed frame, resized to the original frame size and converted to BGR.
+    """
+
+    # Remove the batch dimension: (1, H, W, 3) -> (H, W, 3)
+    postprocessed_frame = np.squeeze(preprocessed_frame, axis=0)
+    # Select original height and width from image_shape
+    frame_height = image_shape[0]
+    frame_width = image_shape[1]
+    # Resize to the original frame size and rescale pixel values to [0, 255]
+    postprocessed_frame = cv2.resize(postprocessed_frame, (frame_width, frame_height)).astype("float32") * 255
+    postprocessed_frame = cv2.cvtColor(postprocessed_frame, cv2.COLOR_RGB2BGR)
+
+    return postprocessed_frame
+
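+# Illustrative sketch (an assumption about typical shapes, not executed by this
+# module): the postprocess step maps a (1, H, W, 3) network output in RGB back
+# to the original frame geometry as float32 BGR in [0, 255]. The synthetic
+# input below stands in for real network output:
+#
+#   out = np.random.rand(1, 384, 384, 3).astype(np.float32)
+#   restored = style_transfer_postprocess(out, (720, 1280, 3))
+#   assert restored.shape == (720, 1280, 3)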
+
+def create_stylized_detection(style_transfer_executor, style_transfer_class, frame: np.ndarray,
+                              detections: list, resize_factor, labels: dict):
+    """
+    Performs style transfer on detected objects of a given class in a frame.
+
+    Args:
+        style_transfer_executor: The style transfer executor.
+        style_transfer_class: The detected class whose style should be changed.
+        frame: The original captured frame from the video source.
+        detections: A list of detected objects in the form [class, [box positions], confidence].
+        resize_factor: Resizing factor to scale box coordinates to output frame size.
+        labels: Dictionary of labels and colors keyed on the classification index.
+
+    Returns:
+        The frame with the stylized objects pasted back in place.
+    """
+    for detection in detections:
+        class_idx, box, confidence = detection
+        label = labels[class_idx][0]
+        if label.lower() == style_transfer_class.lower():
+            # Obtain frame size and resized bounding box positions
+            frame_height, frame_width = frame.shape[:2]
+            x_min, y_min, x_max, y_max = [int(position * resize_factor) for position in box]
+
+            # Ensure box stays within the frame
+            x_min, y_min = max(0, x_min), max(0, y_min)
+            x_max, y_max = min(frame_width, x_max), min(frame_height, y_max)
+
+            # Crop only the detected object
+            cropped_frame = cv_utils.crop_bounding_box_object(frame, x_min, y_min, x_max, y_max)
+
+            # Run style transfer on the cropped frame
+            stylized_frame = style_transfer_executor.run_style_transfer(cropped_frame)
+
+            # Paste stylized_frame onto the original frame in the correct place
+            frame[y_min + 1:y_max, x_min + 1:x_max] = stylized_frame
+
+    return frame
+
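+# Illustrative call (a sketch only; the label table, detection and resize
+# factor below are made-up values, not produced by this module):
+#
+#   labels = {0: ('person', (0, 255, 0))}
+#   detections = [[0, [10, 20, 200, 220], 0.9]]
+#   frame = create_stylized_detection(style_transfer, 'person', frame,
+#                                     detections, resize_factor=1.0,
+#                                     labels=labels)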
+
+class StyleTransfer:
+
+    def __init__(self, style_predict_model_path: str, style_transfer_model_path: str,
+                 style_image: np.ndarray, backends: list, delegate_path: str):
+        """
+        Creates inference executors for the style predict and style transfer networks,
+        and computes the style bottleneck from the given style image.
+
+        Args:
+            style_predict_model_path: Model which is used to create a style bottleneck.
+            style_transfer_model_path: Model which is used to create stylized frames.
+            style_image: An image to create the style bottleneck from.
+            backends: List of backends to optimize the networks with.
+            delegate_path: TfLite delegate file path (.so).
+        """
+
+        self.style_predict_executor = network_executor_tflite.TFLiteNetworkExecutor(style_predict_model_path,
+                                                                                    backends,
+                                                                                    delegate_path)
+        self.style_transfer_executor = network_executor_tflite.TFLiteNetworkExecutor(style_transfer_model_path,
+                                                                                     backends,
+                                                                                     delegate_path)
+        self.style_bottleneck = self.run_style_predict(style_image)
+
+    def get_style_predict_executor_shape(self):
+        """
+        Gets the input shape of the style predict network.
+
+        Returns:
+            tuple: The shape of the network input.
+        """
+        return self.style_predict_executor.get_shape()
+
+    # Creates a style_bottleneck from a preprocessed style image.
+    def run_style_predict(self, style_image):
+        """
+        Creates a bottleneck tensor for a given style image.
+
+        Args:
+            style_image: An image to create the style bottleneck from.
+
+        Returns:
+            The style bottleneck tensor.
+        """
+        # The style image has to be preprocessed to (1, 256, 256, 3)
+        preprocessed_style_image = cv_utils.preprocess(style_image, self.style_predict_executor.get_data_type(),
+                                                       self.style_predict_executor.get_shape(), True,
+                                                       keep_aspect_ratio=False)
+        # output[0] is the style bottleneck tensor
+        style_bottleneck = self.style_predict_executor.run([preprocessed_style_image])[0]
+
+        return style_bottleneck
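+
+    # Illustrative usage (a sketch, not executed here): a new style can be
+    # applied after construction by recomputing the stored bottleneck, e.g.
+    # (new_style_image is a hypothetical BGR image):
+    #
+    #   style_transfer.style_bottleneck = style_transfer.run_style_predict(new_style_image)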
+
+    # Runs style transfer on a content image using the stored style bottleneck.
+    def run_style_transfer(self, content_image):
+        """
+        Runs inference for a given content_image and the style bottleneck to create a stylized image.
+
+        Args:
+            content_image: A content image to stylize.
+
+        Returns:
+            The postprocessed stylized image.
+        """
+        # The content image has to be preprocessed to (1, 384, 384, 3)
+        preprocessed_content_image = cv_utils.preprocess(content_image, np.float32,
+                                                         self.style_transfer_executor.get_shape(), True,
+                                                         keep_aspect_ratio=False)
+
+        # Transform the content image. output[0] is the stylized image
+        stylized_image = self.style_transfer_executor.run([preprocessed_content_image, self.style_bottleneck])[0]
+
+        post_stylized_image = style_transfer_postprocess(stylized_image, content_image.shape)
+
+        return post_stylized_image
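+
+# Typical construction (an illustrative sketch; the file names below are
+# placeholders, and the backend list and delegate name are assumptions about a
+# standard Arm NN setup, not defaults shipped with this example; content_frame
+# is any BGR frame):
+#
+#   style_image = cv2.imread('style.jpg')
+#   style_transfer = StyleTransfer('style_predict.tflite', 'style_transfer.tflite',
+#                                  style_image, ['CpuAcc', 'CpuRef'],
+#                                  'libarmnnDelegate.so')
+#   stylized_crop = style_transfer.run_style_transfer(content_frame)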