Skip to content

Commit

Permalink
feat(ml): improve the inference for any paths.txt and longer frames
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Aug 2, 2024
1 parent 2e76af1 commit 90876b0
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions scripts/gen_vid_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@
from util.util import flatten_json


def atoi(text):
    """Convert a purely numeric string to ``int``; return anything else as is.

    Args:
        text: A string chunk, typically one piece of a name that was split
            on runs of digits.

    Returns:
        ``int(text)`` when every character of ``text`` is a decimal digit,
        otherwise ``text`` itself, unchanged (this includes the empty string).
    """
    # str.isdecimal() rather than str.isdigit(): isdigit() is also true for
    # characters such as superscript digits ("\u00b2"), which int() rejects
    # with ValueError. isdecimal() accepts exactly the characters int() can
    # parse, so the conversion below cannot raise.
    return int(text) if text.isdecimal() else text


def natural_keys(text):
    """Build a sort key that orders embedded numbers by value.

    The string is split on runs of digits; each chunk is passed through
    ``atoi``, which turns digit runs into ints and leaves other text
    unchanged. Used as ``key=`` for sorting, this places ``"frame2"`` before
    ``"frame10"``.

    Args:
        text: The string to build a key for (e.g. a file path).

    Returns:
        A list alternating between text chunks and converted digit runs,
        always starting with a (possibly empty) text chunk.
    """
    # Raw string literal: "\d" in a plain string is an invalid escape
    # sequence that newer Python versions warn about and will reject.
    # The capturing group makes re.split keep the digit runs in the result.
    return [atoi(chunk) for chunk in re.split(r"(\d+)", text)]


def load_model(
model_in_dir,
model_in_filename,
Expand Down Expand Up @@ -237,14 +245,25 @@ def generate(
for i, delta_values in enumerate(mask_delta):
if len(delta_values) == 1:
mask_delta[i].append(delta_values[0])

# Load image
with open(args.paths_file, "r") as file:
lines = file.readlines()

image_bbox_pair = [
line.strip().split() for line in lines[: opt.data_temporal_number_frames]
paths_img = []
paths_bbox = []

image_bbox_pairs = []
for line in lines:
parts = line.strip().split()
image_bbox_pairs.append((parts[0], parts[1]))

image_bbox_pairs.sort(key=lambda x: natural_keys(x[0]))
startframe = random.randint(100, 10000)
limited_image_bbox_pairs = image_bbox_pairs[
startframe : startframe + opt.data_temporal_number_frames + 10
]
limited_paths_img = [pair[0] for pair in limited_image_bbox_pairs]
limited_paths_bbox = [pair[1] for pair in limited_image_bbox_pairs]

cond_image_list = []
y_t_list = []
y0_tensor_list = []
Expand All @@ -254,7 +273,7 @@ def generate(
img_tensor_list = []
out_img_list = []

for img_path, bbox_path in image_bbox_pair:
for img_path, bbox_path in zip(limited_paths_img, limited_paths_bbox):
img_in = os.path.join(args.data_root, img_path)
maskin = os.path.join(args.data_root, bbox_path)
bbox_select = None
Expand Down

0 comments on commit 90876b0

Please sign in to comment.