Get latents path function improvement #50

Open
fsbarros98 opened this issue Sep 18, 2024 · 0 comments


Hi there, I just started using this repository for some cool experiments.
I noticed that some people are having trouble with the get_latents_path method in run_tokenflow_pnp.py, and I'm one of them.
If I understood the intention of that function correctly, I would suggest changing its first few lines from:

latents_path = os.path.join(config["latents_path"], f'sd_{config["sd_version"]}',
                     Path(config["data_path"]).stem, f'steps_{config["n_inversion_steps"]}')
latents_path = [x for x in glob.glob(f'{latents_path}/*') if '.' not in Path(x).name]
n_frames = [int([x for x in latents_path[i].split('/') if 'nframes' in x][0].split('_')[1]) for i in range(len(latents_path))]
latents_path = latents_path[np.argmax(n_frames)]

to

# Parent folder that holds one latents directory per n_frames value
latents_path_dir = os.path.join(config["latents_path"],
                                f'sd_{config["sd_version"]}',
                                Path(config["data_path"]).stem,
                                f'steps_{config["n_inversion_steps"]}')

# All sub-folders holding latents for a given n_frames (e.g. nframes_40)
latents_path_folders = [os.path.join(latents_path_dir, folder)
                        for folder in os.listdir(latents_path_dir)
                        if os.path.isdir(os.path.join(latents_path_dir, folder))
                        and 'nframes' in folder]

# Parse n_frames from each folder name only, so path separators never matter
n_frames = [int(os.path.basename(folder).split('_')[-1]) for folder in latents_path_folders]

# Use the folder with the highest n_frames
latents_path = latents_path_folders[np.argmax(n_frames)]
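To check the proposed logic outside the script, here is a minimal standalone sketch. It wraps the folder selection in a hypothetical find_latents_path function, assumes the latents sub-folders are named like nframes_<N> (as the original split('_')[1] implies), and uses the built-in max instead of np.argmax so it only needs the standard library. The explicit error for an empty folder list is an addition here; without it, np.argmax on an empty list raises a less descriptive ValueError.

```python
import os
import tempfile


def find_latents_path(latents_path_dir):
    """Return the nframes_<N> sub-folder of latents_path_dir with the largest N.

    Hypothetical helper: assumes folder names end in '_<N>' after 'nframes'.
    """
    # Only directories whose name contains 'nframes'; files are skipped.
    latents_path_folders = [
        os.path.join(latents_path_dir, folder)
        for folder in os.listdir(latents_path_dir)
        if os.path.isdir(os.path.join(latents_path_dir, folder)) and 'nframes' in folder
    ]
    if not latents_path_folders:
        raise FileNotFoundError(f'no nframes_* folders in {latents_path_dir}')

    # Parse N from the folder name only, so the parent path's separators
    # and underscores never affect the result.
    return max(latents_path_folders,
               key=lambda folder: int(os.path.basename(folder).split('_')[-1]))


# Usage: a throwaway tree with three latents folders and one stray file.
with tempfile.TemporaryDirectory() as root:
    for name in ('nframes_16', 'nframes_40', 'nframes_8'):
        os.makedirs(os.path.join(root, name))
    # A file whose name contains 'nframes' is still ignored by the isdir check.
    open(os.path.join(root, 'nframes_99.txt'), 'w').close()
    print(os.path.basename(find_latents_path(root)))  # nframes_40
```

Like np.argmax, max returns the first maximal element, so ties resolve the same way in both versions.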

With the proposed version you avoid OS-specific problems with split('/') (for example, on Windows the paths use backslashes, so split('/') does not separate the path components and the frame count is parsed from the wrong part of the string). The code is also more readable and easier to debug, because the same latents_path variable is no longer reused for different things.

I would also suggest describing in more detail what should go in the config.yml file. What "data_path" should point to was not obvious to me, and if it has to be a folder generated during preprocessing, filling in the config becomes less automated. So I also changed the first lines of the get_data method: I pass the input video in the config, and the method strips the extension to get the folder of the same name (assuming preprocessing created a folder with that name):

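    def get_data(self):
        # load frames
        if self.config["data_path"].endswith('.mp4'):
            # A video file was given: use the folder with the same name
            # (without the extension) created during preprocessing
            self.config["data_path"] = os.path.splitext(self.config["data_path"])[0]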
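The same idea can be pulled out into a standalone function, which makes it testable without constructing the pipeline object. This is a sketch under the issue's own assumption that preprocessing writes the frames to a folder named after the video, next to it; the helper name resolve_data_path and the video_exts parameter are hypothetical. Comparing the extension case-insensitively also accepts names like clip.MP4.

```python
import os


def resolve_data_path(data_path, video_exts=('.mp4',)):
    """Map a video path to its preprocessed frames folder (hypothetical helper).

    'data/my_video.mp4' -> 'data/my_video'; any other path is returned unchanged.
    video_exts must be given in lowercase.
    """
    root, ext = os.path.splitext(data_path)
    # Compare case-insensitively so 'clip.MP4' is handled like 'clip.mp4'.
    if ext.lower() in video_exts:
        return root
    return data_path


config = {"data_path": "data/my_video.mp4"}
print(resolve_data_path(config["data_path"]))  # data/my_video
```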