have interpolation run in a thread

This commit is contained in:
q
2025-08-21 14:06:24 +03:00
parent 3fc383382b
commit af70258261
2 changed files with 111 additions and 54 deletions

View File

@@ -2,7 +2,7 @@ import argparse
from tsmark.video_annotator import Marker from tsmark.video_annotator import Marker
VERSION = "0.7.9" VERSION = "0.7.10"
class SmartFormatter(argparse.HelpFormatter): class SmartFormatter(argparse.HelpFormatter):

View File

@@ -49,7 +49,9 @@ class Marker:
self.points_interpolated = {} self.points_interpolated = {}
self.point_index = None self.point_index = None
self.points_interpolation_enabled = True self.points_interpolation_enabled = True
self.points_interpolation_required = False self.points_interpolation_required = {}
self.points_interpolation_thread = None
self.points_interpolation_thread_exit = False
self.message = None self.message = None
self.message_timer = time.time() self.message_timer = time.time()
@@ -257,7 +259,7 @@ class Marker:
if index == self.point_index and self.point_click == 1: if index == self.point_index and self.point_click == 1:
continue continue
current = self.get_interpolated_point(index=index) current = self.get_interpolated_point(index=index)
if current["type"] in ("pre", "post"): if current["type"] in ("pre", "post", None):
continue continue
if current["visible"] == "hidden": if current["visible"] == "hidden":
continue continue
@@ -283,14 +285,21 @@ class Marker:
) )
# Show current track # Show current track
x, y = [20, 70] x, y = [20, 70]
intrp_str = (
""
if not self.points_interpolation_enabled
else " i*" if True in self.points_interpolation_required.values() else " i"
)
self.shadow_text( self.shadow_text(
frame, frame,
"P:" + str(self.point_index), f"P:{str(self.point_index)}{intrp_str}",
(x, y), (x, y),
0.5, 0.5,
1, 1,
(255, 255, 255), (255, 255, 255),
) )
try: try:
current = self.get_interpolated_point() current = self.get_interpolated_point()
if current["type"] is not None: if current["type"] is not None:
@@ -306,12 +315,13 @@ class Marker:
cv2.circle(frame, (current["cx"], current["cy"]), 10, color, 1) cv2.circle(frame, (current["cx"], current["cy"]), 10, color, 1)
history = [] if self.points_interpolation_enabled:
for p in range(max(1, int(self.nr - self.viewer_fps)), self.nr + 1): history = []
po = self.get_interpolated_point(p) for p in range(max(1, int(self.nr - self.viewer_fps)), self.nr + 1):
history.append([po["cx"], po["cy"]]) po = self.get_interpolated_point(p)
history = np.array(history, np.int32).reshape((-1, 1, 2)) history.append([po["cx"], po["cy"]])
cv2.polylines(frame, [history], False, COLOR_INTERP, 1) history = np.array(history, np.int32).reshape((-1, 1, 2))
cv2.polylines(frame, [history], False, COLOR_INTERP, 1)
except KeyError: except KeyError:
print(self.get_interpolated_point(), self.nr) print(self.get_interpolated_point(), self.nr)
@@ -349,7 +359,7 @@ class Marker:
if direction == "previous": if direction == "previous":
for ts in reversed(sorted(list(self.points[self.point_index].keys()))): for ts in reversed(sorted(list(self.points[self.point_index].keys()))):
if ts < self.nr - 1: if ts < self.nr:
return set_nr(ts) return set_nr(ts)
except Exception: except Exception:
@@ -372,7 +382,7 @@ class Marker:
"y1": ip["y1"], "y1": ip["y1"],
"visible": POINT_VISIBILITY[0], "visible": POINT_VISIBILITY[0],
} }
self.interpolate_points() self.points_interpolation_required[self.point_index] = True
except Exception: except Exception:
pass pass
@@ -415,6 +425,17 @@ class Marker:
if index is None: if index is None:
index = self.point_index index = self.point_index
if index in self.points:
if nr in self.points[index]:
value = self.get_point(nr=nr, index=index)
value.update({"type": "key" if value["x0"] is not None else None, "age": 0})
return value
if not self.points_interpolation_enabled:
value = self.get_point(nr=nr, index=index)
value.update({"type": "key" if value["x0"] is not None else None, "age": 0})
return value
if index in self.points_interpolated: if index in self.points_interpolated:
if nr in self.points_interpolated[index]: if nr in self.points_interpolated[index]:
value = self.points_interpolated[index][nr].copy() value = self.points_interpolated[index][nr].copy()
@@ -441,6 +462,8 @@ class Marker:
def convert_interpolated_points(self): def convert_interpolated_points(self):
if self.point_click == 1 and self.point_index in self.points: if self.point_click == 1 and self.point_index in self.points:
self.toggle_interpolation(True)
for nr in range(self.frames): for nr in range(self.frames):
ip = self.get_interpolated_point(nr=nr) ip = self.get_interpolated_point(nr=nr)
if ip["type"] == "interp" and ip["visible"] == POINT_VISIBILITY[0]: if ip["type"] == "interp" and ip["visible"] == POINT_VISIBILITY[0]:
@@ -451,7 +474,7 @@ class Marker:
"y1": ip["y1"], "y1": ip["y1"],
"visible": POINT_VISIBILITY[0], "visible": POINT_VISIBILITY[0],
} }
self.interpolate_points() # self.interpolate_points()
def modify_point(self, position, x, y): def modify_point(self, position, x, y):
"""position: tl topleft, br bottomright, c center""" """position: tl topleft, br bottomright, c center"""
@@ -536,7 +559,8 @@ class Marker:
self.points[self.point_index][self.nr]["y1"], self.points[self.point_index][self.nr]["y1"],
) )
self.interpolate_points() # self.interpolate_points()
self.points_interpolation_required[self.point_index] = True
def modify_point_wh(self): def modify_point_wh(self):
@@ -558,7 +582,7 @@ class Marker:
self.points[self.point_index][self.nr]["y1"] = int(curr_point["cy"] + new_hh) self.points[self.point_index][self.nr]["y1"] = int(curr_point["cy"] + new_hh)
self.points[self.point_index][self.nr]["visible"] = POINT_VISIBILITY[0] self.points[self.point_index][self.nr]["visible"] = POINT_VISIBILITY[0]
self.interpolate_points() self.points_interpolation_required[self.point_index] = True
def toggle_point_visibility(self): def toggle_point_visibility(self):
@@ -585,6 +609,7 @@ class Marker:
except Exception as e: except Exception as e:
print(e) print(e)
pass pass
self.points_interpolation_required[self.point_index] = True
def track_point(self): def track_point(self):
@@ -594,11 +619,14 @@ class Marker:
if self.opts.output_points is None: if self.opts.output_points is None:
return return
self.toggle_interpolation(True)
tracker_gui = TrackerGUI(self) tracker_gui = TrackerGUI(self)
if len(tracker_gui.points) > 0: if len(tracker_gui.points) > 0:
for nr in tracker_gui.points: for nr in tracker_gui.points:
self.points[self.point_index][nr] = tracker_gui.points[nr] self.points[self.point_index][nr] = tracker_gui.points[nr]
self.interpolate_points() # self.interpolate_points()
self.points_interpolation_required[self.point_index] = True
self.nr = max(tracker_gui.points) - 1 self.nr = max(tracker_gui.points) - 1
self.read_next = True self.read_next = True
@@ -655,13 +683,13 @@ class World:
post: after any keyframes post: after any keyframes
""" """
if self.points_interpolation_enabled: if self.points_interpolation_thread is None:
process = threading.Thread(target=self.interpolate_points_in_thread, args=(point_index,)) self.points_interpolation_thread = threading.Thread(target=self.interpolate_points_in_thread, args=())
process.start() self.points_interpolation_thread.start()
if not self.points_interpolation_thread.is_alive():
self.points_interpolation_thread = threading.Thread(target=self.interpolate_points_in_thread, args=())
def interpolate_points_in_thread(self, point_index=None): self.points_interpolation_thread.start()
if point_index is None: if point_index is None:
point_index = self.point_index point_index = self.point_index
@@ -675,11 +703,10 @@ class World:
if not point_index in self.points: if not point_index in self.points:
return return
self.points_interpolation_required[point_index] = False
if not point_index in self.points_interpolated: if not point_index in self.points_interpolated:
self.points_interpolated[point_index] = {key: {} for key in range(self.frames)} self.points_interpolated[point_index] = {key: {} for key in range(self.frames)}
# ~ self.points_interpolation_required = False
new_points = {k: v for k, v in self.points_interpolated[point_index].items()} new_points = {k: v for k, v in self.points_interpolated[point_index].items()}
if len(self.points[point_index]) == 1: # only one point added if len(self.points[point_index]) == 1: # only one point added
@@ -742,7 +769,25 @@ class World:
self.points_interpolated[point_index] = new_points self.points_interpolated[point_index] = new_points
def toggle_interpolation(self): def interpolate_points_in_thread(self):
self.points_interpolation_frequency = 1
while True:
if self.points_interpolation_thread_exit:
return
time.sleep(self.points_interpolation_frequency)
if not self.points_interpolation_enabled:
continue
for point_index in self.points_interpolation_required:
if self.points_interpolation_required[point_index]:
self.interpolate_points(point_index)
def toggle_interpolation(self, value=None):
if value is not None:
self.points_interpolation_enabled = not value
self.points_interpolation_enabled = not self.points_interpolation_enabled self.points_interpolation_enabled = not self.points_interpolation_enabled
if self.points_interpolation_enabled: if self.points_interpolation_enabled:
@@ -1341,6 +1386,21 @@ class World:
self.modify_point_wh() self.modify_point_wh()
elif k & 0xFF == ord("u"): # toggle interpolation elif k & 0xFF == ord("u"): # toggle interpolation
self.toggle_interpolation() self.toggle_interpolation()
if self.point_click == 0:
self.shadow_text(
frame_visu,
(
"Point interpolation turned on"
if self.points_interpolation_enabled
else "Point interpolation turned off"
),
(20, 70),
0.9,
2,
(255, 255, 255),
)
cv2.imshow("tsmark", frame_visu)
k2 = cv2.waitKey(1000)
elif k & 0xFF == ord("x"): # toggle ts elif k & 0xFF == ord("x"): # toggle ts
if self.point_click == 1: if self.point_click == 1:
self.toggle_point(self.nr) self.toggle_point(self.nr)
@@ -1385,6 +1445,7 @@ class World:
except Exception as e: except Exception as e:
print(e) print(e)
self.points_interpolation_thread_exit = True
self.video_reader.release() self.video_reader.release()
cv2.destroyAllWindows() cv2.destroyAllWindows()
self.print_timestamps() self.print_timestamps()
@@ -1432,6 +1493,8 @@ class TrackerGUI:
tracked[0] = [*bbox, 1] tracked[0] = [*bbox, 1]
for i in range(max_frames): for i in range(max_frames):
# Read a new frame # Read a new frame
self.marker.video_reader.set(cv2.CAP_PROP_POS_FRAMES, self.marker.nr + i)
ok, frame = self.marker.video_reader.read() ok, frame = self.marker.video_reader.read()
frame = cv2.resize(frame.copy(), self.marker.video_res) frame = cv2.resize(frame.copy(), self.marker.video_res)
if not ok: if not ok:
@@ -1440,10 +1503,10 @@ class TrackerGUI:
ok, bbox = tracker.update(frame) ok, bbox = tracker.update(frame)
if ok: if ok:
# Tracking success # Tracking success
if self.marker.nr + i + 1 in self.marker.points[self.marker.point_index]: if self.marker.nr + i in self.marker.points[self.marker.point_index]:
point = self.marker.get_point(nr=self.marker.nr + i + 1) point = self.marker.get_point(nr=self.marker.nr + i)
bbox = tuple([point["x0"], point["y0"], point["w"], point["h"]]) bbox = tuple([point["x0"], point["y0"], point["w"], point["h"]])
tracked[i + 1] = [*bbox, 1] tracked[i] = [*bbox, 1]
show_message = f"Tracking... ({i}/{max_frames})" show_message = f"Tracking... ({i}/{max_frames})"
else: else:
# Tracking failure # Tracking failure
@@ -1471,20 +1534,16 @@ class TrackerGUI:
while True: while True:
if done: if done:
break break
self.marker.video_reader.set(cv2.CAP_PROP_POS_FRAMES, self.marker.nr)
i = -1 i = 0
while True: while True:
show_time = time.time() show_time = time.time()
self.marker.video_reader.set(cv2.CAP_PROP_POS_FRAMES, self.marker.nr + i)
if done: if done:
break break
if paused:
frame = frame_copy.copy() ok, frame = self.marker.video_reader.read()
if (not paused) or seek: frame = cv2.resize(frame.copy(), self.marker.video_res)
ok, frame = self.marker.video_reader.read()
frame = cv2.resize(frame.copy(), self.marker.video_res)
frame_copy = frame.copy()
i += 1
seek = False
self.marker.shadow_text(frame, f"Accept? ({i+1}/{max_frames})", (100, 80), 0.75, 2, (255, 255, 255)) self.marker.shadow_text(frame, f"Accept? ({i+1}/{max_frames})", (100, 80), 0.75, 2, (255, 255, 255))
if i in tracked: if i in tracked:
bbox = tracked[i] bbox = tracked[i]
@@ -1495,7 +1554,7 @@ class TrackerGUI:
cv2.rectangle(frame, p1, p2, color, thicc, 1) cv2.rectangle(frame, p1, p2, color, thicc, 1)
cv2.imshow("tsmark - tracker", frame) cv2.imshow("tsmark - tracker", frame)
# speed up fps by 2 # speed up fps by 2
time_to_wait = self.marker.viewer_spf / 2 - time.time() + show_time time_to_wait = 0.2 if paused else (self.marker.viewer_spf / 2 - time.time() + show_time)
k = cv2.waitKey(max(1, int(time_to_wait * 1000))) k = cv2.waitKey(max(1, int(time_to_wait * 1000)))
if k & 0xFF == ord("q") or k & 0xFF == 13: # accept with q or enter if k & 0xFF == ord("q") or k & 0xFF == 13: # accept with q or enter
done = True done = True
@@ -1508,46 +1567,44 @@ class TrackerGUI:
paused = not paused paused = not paused
# Movement ================= # Movement =================
elif k & 0xFF == 83 or k & 0xFF == ord("l"): # right arrow elif k & 0xFF == 83 or k & 0xFF == ord("l"): # right arrow
i += int(self.marker.fps) - 1 i += int(self.marker.fps)
seek = True seek = True
elif k & 0xFF == 81 or k & 0xFF == ord("j"): # left arrow elif k & 0xFF == 81 or k & 0xFF == ord("j"): # left arrow
i -= int(self.marker.fps) + 1 i -= int(self.marker.fps)
seek = True
# Move by frame # Move by frame
elif k & 0xFF == ord(".") or k & 0xFF == ord("c"): elif k & 0xFF == ord(".") or k & 0xFF == ord("c"):
paused = True paused = True
seek = True i += 1
elif k & 0xFF == ord(",") or k & 0xFF == ord("z"): elif k & 0xFF == ord(",") or k & 0xFF == ord("z"):
paused = True paused = True
i -= 2 i -= 1
seek = True
elif k & 0xFF == ord("x"): elif k & 0xFF == ord("x"):
cut_after = i cut_after = i + 1
# TODO: ord("h") for help! # TODO: ord("h") for help!
if i >= max_frames - 1: if i >= max_frames - 1:
i = max_frames - 2 i = max_frames - 2
paused = True paused = True
seek = True
if i < 0: if i < 0:
i = -1 i = 0
paused = True paused = True
seek = True
if seek: if not paused:
self.marker.video_reader.set(cv2.CAP_PROP_POS_FRAMES, self.marker.nr + i + 1) i += 1
cv2.destroyWindow("tsmark - tracker") cv2.destroyWindow("tsmark - tracker")
self.marker.nr = old_nr - 1
self.marker.read_next = True
self.points = {} self.points = {}
for i in sorted(list(tracked.keys())): for i in sorted(list(tracked.keys())):
if i >= cut_after: if i >= cut_after:
continue continue
self.points[self.marker.nr + i + 1] = { self.points[old_nr + i] = {
"x0": tracked[i][0], "x0": tracked[i][0],
"y0": tracked[i][1], "y0": tracked[i][1],
"x1": tracked[i][0] + tracked[i][2], "x1": tracked[i][0] + tracked[i][2],
"y1": tracked[i][1] + tracked[i][3], "y1": tracked[i][1] + tracked[i][3],
"visible": POINT_VISIBILITY[0], "visible": POINT_VISIBILITY[0],
} }
self.marker.nr = old_nr + cut_after - 1
self.marker.read_next = True