arldm's Introduction

Hi there 👋

🤗 I'm Flash

🌎 CS Ph.D. Student @ NYU Courant

🏠 More about me, find out at my homepage!

arldm's Issues

ckpt is not saved after training?

I ran training with the following config file. Everything looked fine during training. However, when training ended, I found no ckpt file in ckpt_dir.
Did I miss anything?

# device
mode: train  # train sample
gpu_ids: [ 0,1,2,3 ]  # gpu ids
batch_size: 1  # batch size each item denotes one story
num_workers: 16  # number of workers
num_cpu_cores: -1  # number of cpu cores
seed: 0  # random seed
ckpt_dir: ./result/flintstones # checkpoint directory
run_name: 5epoch_visualization # name for this run

# task
dataset: flintstones  # pororo flintstones vistsis vistdii
task: visualization  # continuation visualization

# train
init_lr: 1e-5  # initial learning rate
warmup_epochs: 1  # warmup epochs
max_epochs: 5  # max epochs
train_model_file:  # model file for resume, none for train from scratch
freeze_clip: False #True  # whether to freeze clip
freeze_blip: False  # whether to freeze blip
freeze_resnet: False  # whether to freeze resnet

# sample
# test_model_file:  # model file for test
# calculate_fid: True  # whether to calculate FID scores
# scheduler: ddim  # ddim pndm
# guidance_scale: 6  # guidance scale
# num_inference_steps: 250  # number of inference steps
# sample_output_dir: /path/to/save_samples # output directory

# pororo:
#   hdf5_file: /path/to/pororo.h5
#   max_length: 85
#   new_tokens: [ "pororo", "loopy", "eddy", "harry", "poby", "tongtong", "crong", "rody", "petty" ]
#   clip_embedding_tokens: 49416
#   blip_embedding_tokens: 30530

flintstones:
  hdf5_file: /root/autodl-tmp/dataset/flintstones.h5
  max_length: 91
  new_tokens: [ "fred", "barney", "wilma", "betty", "pebbles", "dino", "slate" ]
  clip_embedding_tokens: 49412
  blip_embedding_tokens: 30525

# vistsis:
#   hdf5_file: /path/to/vist.h5
#   max_length: 100
#   clip_embedding_tokens: 49408
#   blip_embedding_tokens: 30524

# vistdii:
#   hdf5_file: /path/to/vist.h5
#   max_length: 65
#   clip_embedding_tokens: 49408
#   blip_embedding_tokens: 30524

hydra:
  run:
    dir: .
  output_subdir: null
hydra/job_logging: disabled
hydra/hydra_logging: disabled
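One way to narrow this down, sketched under assumptions: with PyTorch Lightning, a ModelCheckpoint callback without an explicit dirpath typically writes under the trainer's default_root_dir (for example lightning_logs/version_N/checkpoints/ when the default logger is active) rather than into a custom directory, so it can help to search for *.ckpt files recursively and, if needed, pass dirpath explicitly. The <ckpt_dir>/<run_name> layout and the helper names below are hypothetical and are not taken from ARLDM's main.py.

```python
import glob
import os


def expected_ckpt_dir(ckpt_dir, run_name):
    # Hypothetical layout: one sub-directory per run under ckpt_dir.
    return os.path.join(ckpt_dir, run_name)


def find_checkpoints(root="."):
    # Search recursively: without an explicit dirpath, Lightning usually writes
    # under <default_root_dir>/lightning_logs/version_N/checkpoints/.
    pattern = os.path.join(root, "**", "*.ckpt")
    return sorted(glob.glob(pattern, recursive=True))


def make_checkpoint_callback(ckpt_dir, run_name):
    # Imported lazily so the helpers above work without pytorch_lightning installed.
    from pytorch_lightning.callbacks import ModelCheckpoint

    # save_last=True keeps a last.ckpt; every_n_epochs=1 saves after each epoch.
    return ModelCheckpoint(
        dirpath=expected_ckpt_dir(ckpt_dir, run_name),
        save_last=True,
        every_n_epochs=1,
    )
```

The callback would be passed as Trainer(callbacks=[make_checkpoint_callback(...)]); running find_checkpoints() from the working directory after training shows where checkpoint files actually landed.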

Implementation of classifier-free guidance

Hi, I have a few small questions about how classifier-free guidance generation is implemented in your code.

As far as I know, classifier-free guidance involves two steps.

  1. Training: randomly select samples (e.g., with p=0.1) and mask the entire context of the selected samples. This means we jointly train two models with the same architecture (with p=0.1 an unconditional model is trained on the null context, and with p=0.9 a conditional model).
  2. Sampling: use both models simultaneously to generate Noise1 from the conditional model and Noise2 from the unconditional model, then combine them as Noise = Noise1 + w * (Noise1 - Noise2).

But in your code, I'm confused about why you randomly discard some frames of the context rather than all of them, because in the sampling stage you seem to generate Noise2 from a null context.

[Screenshot: training code]

In the training stage above, you only fill random frames with null, rather than the entire sample. (I think it should be classifier_free_idx = np.random.rand(B) rather than classifier_free_idx = np.random.rand(B*V).)

In the sampling stage, I think your code is correct.

[Screenshot: sampling code]

We should align the behaviour of training and sampling by discarding all frames of a sample in both, right?
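To make the difference concrete, here is a small NumPy sketch contrasting the two dropout schemes discussed above: one Bernoulli draw per frame (like np.random.rand(B*V)) versus one draw per story broadcast to all V frames (like np.random.rand(B)). The (B, V, L, D) context shape, the null_context array, and the function names are assumptions for illustration, not ARLDM's code.

```python
import numpy as np


def per_frame_mask(B, V, p, rng):
    # One draw per frame, as with np.random.rand(B * V) < p:
    # a story can end up with only some of its frames nulled.
    return (rng.random(B * V) < p).reshape(B, V)


def per_sample_mask(B, V, p, rng):
    # One draw per story, as with np.random.rand(B) < p, repeated across
    # all V frames so a story is either fully nulled or fully kept.
    drop = rng.random(B) < p
    return np.repeat(drop[:, None], V, axis=1)


def apply_null_context(context, mask, null_context):
    # context: (B, V, L, D); null_context: (L, D); mask: (B, V) booleans.
    # Masked frames are replaced by the null embedding.
    return np.where(mask[:, :, None, None], null_context[None, None], context)
```

With the per-sample mask, each dropped story sees exactly the fully null context that the unconditional branch sees at sampling time; with the per-frame mask, most dropped stories are only partially nulled.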

Besides, I wonder whether there is a small typo at line 355 of main.py:

[Screenshot: main.py, line 355]

Is it supposed to be noise_pred = noise_pred_text + guidance_scale * (noise_pred_text - noise_pred_uncond)?
That is what the formula in the paper suggests:
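If line 355 uses the form noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond), which is the convention in Hugging Face diffusers pipelines, it need not be a typo: the two forms differ only by a shift of the scale, since uncond + s * (text - uncond) = text + (s - 1) * (text - uncond). A quick NumPy check of that identity (function names are illustrative):

```python
import numpy as np


def cfg_paper_form(eps_text, eps_uncond, w):
    # Form quoted in the question: eps_text + w * (eps_text - eps_uncond).
    return eps_text + w * (eps_text - eps_uncond)


def cfg_diffusers_form(eps_text, eps_uncond, s):
    # diffusers-style form: eps_uncond + s * (eps_text - eps_uncond).
    return eps_uncond + s * (eps_text - eps_uncond)
```

So a guidance_scale of 6 in the diffusers form corresponds to w = 5 in the paper form, and with s = 1 the diffusers form reduces to the purely conditional prediction.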

[Screenshot: classifier-free guidance formula from the paper]

Error

Why does the code give me an error (AttributeError: 'Namespace' object has no attribute 'save_dir') on dii = [image for image in dii if not os.path.exists('{}/{}.jpg'.format(args.save_dir, list(image)[0]))]?
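The error most likely means the args object was built by argparse without a --save_dir option, so the attribute the filtering line reads never exists. A minimal sketch of a fix, assuming the script parses its arguments with argparse; build_parser, filter_missing, and the help text are hypothetical, and the real script may already define the directory under a different option name.

```python
import argparse
import os


def build_parser():
    parser = argparse.ArgumentParser()
    # Hypothetical: register the attribute that the filtering line reads.
    parser.add_argument("--save_dir", type=str, required=True,
                        help="directory where downloaded .jpg files are stored")
    return parser


def filter_missing(dii, save_dir):
    # Keep only entries whose image file is not on disk yet.
    # Assumes each entry's first key is the image id, as list(image)[0] implies.
    return [image for image in dii
            if not os.path.exists(os.path.join(save_dir, "{}.jpg".format(list(image)[0])))]
```

The script would then be run with, for example, --save_dir /path/to/images (a placeholder path).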

License of the codebase

Thank you for releasing the codebase for ARLDM. It is really wonderful work with a lot of insights into using auto-regression for story visualization. I was wondering what the license for this work would be. We would like to suggest the MIT license if possible.

Training issue.

Hi, my run also fails right at the beginning... Could you tell me how to fix this bug?

Here is the output.

"distributed_backend=nccl
All distributed processes registered. Starting with 2 processes

micl-libyws6:2583202:2583202 [2] NCCL INFO Bootstrap : Using [0]eth0:10.96.80.81<0> [1]enxb03af2b6059f:169.254.3.1<0>
micl-libyws6:2583202:2583202 [2] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation

micl-libyws6:2583202:2583202 [2] misc/ibvwrap.cc:63 NCCL WARN Failed to open libibverbs.so[.1]
micl-libyws6:2583202:2583202 [2] NCCL INFO NET/Socket : Using [0]eth0:10.96.80.81<0> [1]enxb03af2b6059f:169.254.3.1<0>
micl-libyws6:2583202:2583202 [2] NCCL INFO Using network Socket
NCCL version 2.7.8+cuda10.2
micl-libyws6:2583381:2583381 [3] NCCL INFO Bootstrap : Using [0]eth0:10.96.80.81<0> [1]enxb03af2b6059f:169.254.3.1<0>
micl-libyws6:2583381:2583381 [3] NCCL INFO NET/Plugin : No plugin found (libnccl-net.so), using internal implementation

micl-libyws6:2583381:2583381 [3] misc/ibvwrap.cc:63 NCCL WARN Failed to open libibverbs.so[.1]
micl-libyws6:2583381:2583381 [3] NCCL INFO NET/Socket : Using [0]eth0:10.96.80.81<0> [1]enxb03af2b6059f:169.254.3.1<0>
micl-libyws6:2583381:2583381 [3] NCCL INFO Using network Socket
micl-libyws6:2583202:2587764 [2] NCCL INFO Channel 00/02 : 0 1
micl-libyws6:2583381:2587765 [3] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
micl-libyws6:2583202:2587764 [2] NCCL INFO Channel 01/02 : 0 1
micl-libyws6:2583381:2587765 [3] NCCL INFO Trees [0] -1/-1/-1->1->0|0->1->-1/-1/-1 [1] -1/-1/-1->1->0|0->1->-1/-1/-1
micl-libyws6:2583202:2587764 [2] NCCL INFO threadThresholds 8/8/64 | 16/8/64 | 8/8/64
micl-libyws6:2583202:2587764 [2] NCCL INFO Trees [0] 1/-1/-1->0->-1|-1->0->1/-1/-1 [1] 1/-1/-1->0->-1|-1->0->1/-1/-1
micl-libyws6:2583381:2587765 [3] NCCL INFO Channel 00 : 1[43000] -> 0[41000] via P2P/IPC
micl-libyws6:2583202:2587764 [2] NCCL INFO Channel 00 : 0[41000] -> 1[43000] via P2P/IPC
micl-libyws6:2583381:2587765 [3] NCCL INFO Channel 01 : 1[43000] -> 0[41000] via P2P/IPC
micl-libyws6:2583202:2587764 [2] NCCL INFO Channel 01 : 0[41000] -> 1[43000] via P2P/IPC
micl-libyws6:2583381:2587765 [3] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
micl-libyws6:2583381:2587765 [3] NCCL INFO comm 0x7f96640010d0 rank 1 nranks 2 cudaDev 3 busId 43000 - Init COMPLETE
micl-libyws6:2583202:2587764 [2] NCCL INFO 2 coll channels, 2 p2p channels, 2 p2p channels per peer
micl-libyws6:2583202:2587764 [2] NCCL INFO comm 0x7fe9a00010d0 rank 0 nranks 2 cudaDev 2 busId 41000 - Init COMPLETE
micl-libyws6:2583202:2583202 [2] NCCL INFO Launch mode Parallel

micl-libyws6:2583381:2583381 [3] enqueue.cc:215 NCCL WARN Cuda failure 'invalid device function'
micl-libyws6:2583381:2583381 [3] NCCL INFO group.cc:282 -> 1

micl-libyws6:2583202:2583202 [2] enqueue.cc:215 NCCL WARN Cuda failure 'invalid device function'
micl-libyws6:2583202:2583202 [2] NCCL INFO group.cc:282 -> 1
Error executing job with overrides: []
Traceback (most recent call last):
File "/home/houyi/projects/ARLDM/main.py", line 489, in <module>
main()
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/main.py", line 94, in decorated_main
_run_hydra(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
_run_app(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 457, in _run_app
run_and_report(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
raise ex
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
lambda: hydra.run(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
File "/home/houyi/projects/ARLDM/main.py", line 482, in main
train(args)
File "/home/houyi/projects/ARLDM/main.py", line 437, in train
trainer.fit(model, dataloader, ckpt_path=args.train_model_file)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
self._call_and_handle_interrupt(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1172, in _run
self.__setup_profiler()
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1797, in __setup_profiler
self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 2249, in log_dir
dirpath = self.strategy.broadcast(dirpath)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 319, in broadcast
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1681, in broadcast_object_list
broadcast(object_sizes_tensor, src=src, group=group)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1039, in broadcast
work = default_pg.broadcast([tensor], opts)
RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1627336343171/work/torch/lib/c10d/ProcessGroupNCCL.cpp:33, unhandled cuda error, NCCL version 2.7.8
ncclUnhandledCudaError: Call to CUDA function failed.
Error executing job with overrides: []
Traceback (most recent call last):
File "main.py", line 489, in <module>
main()
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/main.py", line 94, in decorated_main
_run_hydra(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
_run_app(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 457, in _run_app
run_and_report(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
raise ex
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
return func()
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
lambda: hydra.run(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/_internal/hydra.py", line 132, in run
_ = ret.return_value
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/core/utils.py", line 260, in return_value
raise self._return_value
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/hydra/core/utils.py", line 186, in run_job
ret.return_value = task_function(task_cfg)
File "main.py", line 482, in main
train(args)
File "main.py", line 437, in train
trainer.fit(model, dataloader, ckpt_path=args.train_model_file)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
self._call_and_handle_interrupt(
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
return function(*args, **kwargs)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1172, in _run
self.__setup_profiler()
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1797, in __setup_profiler
self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 2249, in log_dir
dirpath = self.strategy.broadcast(dirpath)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 319, in broadcast
torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1681, in broadcast_object_list
broadcast(object_sizes_tensor, src=src, group=group)
File "/home/houyi/anaconda3/envs/arldm2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 1039, in broadcast
work = default_pg.broadcast([tensor], opts)
RuntimeError: NCCL error in: /opt/conda/conda-bld/pytorch_1627336343171/work/torch/lib/c10d/ProcessGroupNCCL.cpp:33, unhandled cuda error, NCCL version 2.7.8
ncclUnhandledCudaError: Call to CUDA function failed."
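Both ranks fail with NCCL's 'invalid device function', and the log shows NCCL 2.7.8 built against CUDA 10.2. One common cause of that error (an assumption here, not a confirmed diagnosis) is that the installed PyTorch/NCCL binaries were not compiled for the GPU's compute capability. The sketch below compares each device's capability with torch.cuda.get_arch_list(); arch_supported and report are hypothetical helpers.

```python
def _parse_arch(entry):
    # "sm_86" -> ("sm", (8, 6)); "compute_75" -> ("compute", (7, 5)).
    kind, _, num = entry.partition("_")
    digits = "".join(ch for ch in num if ch.isdigit())
    return kind, (int(digits[:-1]), int(digits[-1]))


def arch_supported(capability, arch_list):
    cap = tuple(capability)
    for entry in arch_list:
        kind, arch = _parse_arch(entry)
        # Native SASS (sm_XY) runs on devices with the same major version
        # and an equal or higher minor version.
        if kind == "sm" and arch[0] == cap[0] and arch[1] <= cap[1]:
            return True
        # PTX (compute_XY) can be JIT-compiled by the driver for newer devices.
        if kind == "compute" and arch <= cap:
            return True
    return False


def report():
    import torch  # imported lazily; requires a CUDA build of PyTorch

    print("torch", torch.__version__, "built with CUDA", torch.version.cuda)
    arch_list = torch.cuda.get_arch_list()
    for i in range(torch.cuda.device_count()):
        cap = torch.cuda.get_device_capability(i)
        ok = arch_supported(cap, arch_list)
        print(i, torch.cuda.get_device_name(i), cap, "supported" if ok else "NOT in arch list")
```

Calling report() on the training machine shows whether any GPU falls outside the compiled architectures; if one does, installing a PyTorch build compiled for a newer CUDA version is the usual remedy.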

source images contain not only the first image?

I was trying to integrate my own model with your code, and I wondered whether the current dataset code for Pororo works correctly.

I haven't looked into the code and paper thoroughly, especially the modified Unet, so I thought I would rather ask you. 😅

In this line, source_images receives all the images (the first through the fifth), and they are then fed in sequentially through masking in main.py.

square_mask is a triangular matrix of ones. I assume this eventually passes all the encoded source images and source captions to the Unet.

Am I correct?
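For reference, a causal frame mask of the kind described above can be sketched as a lower-triangular matrix of ones, where row i marks which frames frame i may attend to. Passing all five source images at once and restricting each frame through such a mask is the usual way to get sequential conditioning from a single batched input. Whether ARLDM's square_mask includes the diagonal, uses 0/1 or additive -inf values, or is upper- rather than lower-triangular is not confirmed here; causal_frame_mask and visible_frames are hypothetical.

```python
import numpy as np


def causal_frame_mask(V, include_self=True):
    # Lower-triangular ones: mask[i, j] == 1 means frame i may use frame j.
    k = 0 if include_self else -1
    return np.tril(np.ones((V, V), dtype=np.int64), k=k)


def visible_frames(mask, i):
    # Indices of the frames that frame i is allowed to attend to.
    return np.nonzero(mask[i])[0].tolist()
```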

Char-F1 and F-Acc score

Previous works (VLCStoryGan, StoryDALL-E) also use Char-F1 and F-Acc scores to evaluate the character information of generated images. Have you evaluated these two metrics, and if so, what were the results?
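For context, in the story visualization literature these metrics are typically computed by running a pretrained character classifier on generated frames: Char-F1 as an F1 score over per-frame character labels, and F-Acc (frame accuracy) as the fraction of frames whose predicted character set exactly matches the ground truth. The sketch below assumes binary (N, C) label matrices and micro-averaging; the averaging scheme and classifier used by VLCStoryGan and StoryDALL-E may differ.

```python
import numpy as np


def char_f1_and_frame_acc(pred, target):
    # pred, target: (N, C) arrays; entry [n, c] is 1 if character c appears in frame n.
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    # Micro-averaged counts over all frames and characters.
    tp = np.logical_and(pred, target).sum()
    fp = np.logical_and(pred, ~target).sum()
    fn = np.logical_and(~pred, target).sum()
    denom = 2 * tp + fp + fn
    # Convention: if no character is predicted or present anywhere, F1 is 1.0.
    char_f1 = 2 * tp / denom if denom else 1.0
    # Frame accuracy: a frame counts only if every character label matches.
    frame_acc = np.all(pred == target, axis=1).mean()
    return float(char_f1), float(frame_acc)
```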

Regarding the data of the VIST Dataset

Hi @xichenpan
When I tried to reproduce the experiment on the VIST dataset, I noticed that the test set contains numerous duplicate story images, as illustrated in the figure below, although their text descriptions differ. Is this because some image URLs were inaccessible during the download? I used the vist_img_download.py script to download a total of 184011 images, but I am not sure whether some images are missing. Would it be possible for you to share the vist.h5 file via Google Drive?
[Screenshot 20230306160245: duplicate story images in the VIST test set]
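One way to check whether the duplicates exist in the data itself is to hash the raw image bytes and group identical hashes. The (key, bytes) input format is an assumption; how images are stored inside vist.h5 is not shown here, and find_duplicate_images is a hypothetical helper.

```python
import hashlib
from collections import defaultdict


def find_duplicate_images(items):
    # items: iterable of (key, raw_bytes); key could be (story_id, frame_idx).
    groups = defaultdict(list)
    for key, data in items:
        groups[hashlib.md5(data).hexdigest()].append(key)
    # Only hashes shared by more than one key indicate duplicates.
    return [keys for keys in groups.values() if len(keys) > 1]
```

This only catches byte-identical files; the same picture re-encoded or resized would hash differently.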

updating Stable Diffusion to 2.1?

Thank you for your repository. It has helped me a lot.

Is there a plan to update the current version of stable diffusion used in your code to 2.1?

I've just tried naively changing the path to "stabilityai/stable-diffusion-2-1-base", and it failed in "models.diffusers_override.unet_2d_blocks.py" at this line: "out_channels // attn_num_head_channels,"

class CrossAttnUpBlock2D(nn.Module):
    def __init__(
            [omitted]

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            attentions.append(
                Transformer2DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,  # the line that fails with SD 2.1
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                )
            )

There seem to be differences in "attn_num_head_channels" between the two versions.

Sorry, I cannot provide the error message because I have already reverted my code.

In v2.1, "attn_num_head_channels" is a list, not an int.

Before I hack into it, I thought it would be a good time to ask whether this has been tried. I would appreciate any advice on this.
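A sketch of how the int-versus-list difference could be normalized before building each block, under assumptions: it treats attn_num_head_channels as the number of attention heads (as its first-positional use in the Transformer2DModel call above suggests), expands an int to one value per block, and reverses the list for up blocks, mirroring how diffusers' UNet2DConditionModel reverses per-block values for its up path as far as I know. per_block_heads and transformer_dims are hypothetical helpers, and [5, 10, 20, 20] is only an example per-block value.

```python
def per_block_heads(attn_num_head_channels, num_blocks, reverse=False):
    # An int applies to every block; a list/tuple gives one value per block.
    if isinstance(attn_num_head_channels, int):
        values = [attn_num_head_channels] * num_blocks
    else:
        values = list(attn_num_head_channels)
        if len(values) != num_blocks:
            raise ValueError("expected one value per block")
    # Up blocks run in the opposite order to down blocks.
    return values[::-1] if reverse else values


def transformer_dims(out_channels, num_heads):
    # (num_heads, head_dim) as passed positionally to Transformer2DModel above.
    if out_channels % num_heads:
        raise ValueError("out_channels must be divisible by num_heads")
    return num_heads, out_channels // num_heads
```

Each block would then receive its own integer instead of the raw list, so out_channels // attn_num_head_channels divides by an int again.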

StoryDALL-E results

As you mentioned in your paper, you ran StoryDALL-E inference experiments ("experimental results reproduced by us" in Table 1).

I'm also trying to run the code but am having difficulty: it keeps giving me out-of-VRAM errors on an A100 (80GB).

I have opened an issue on the StoryDALL-E repo, but the author has not replied yet. 😢

If I may ask, how did you run it? Or could you upload the output results (StoryDALL-E generated images)?

I want to benchmark the results (ARLDM and StoryDALL-E) with my own custom model. 😅

best fid score?

Hi,
Thank you so much for your work.
How should I choose the checkpoint (weights) with the best FID score? I used your code to reproduce the results, but none of my checkpoints reached the performance reported in the original paper: my model's FID on FlintstonesSV is 30.
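For reference, FID is the Fréchet distance between Gaussian fits of Inception features of real and generated images, ||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 Tr((S1 S2)^(1/2)). A NumPy sketch of that formula, plus picking the checkpoint with the lowest score, is below; feature extraction is omitted and the checkpoint names are placeholders. FID values also depend on the feature-extractor implementation and the number of samples, so scores from different toolkits are not always directly comparable.

```python
import numpy as np


def _sqrtm_psd(a):
    # Matrix square root of a symmetric positive semi-definite matrix.
    w, v = np.linalg.eigh(a)
    return (v * np.sqrt(np.clip(w, 0, None))) @ v.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    # Tr((S1 S2)^(1/2)) equals the sum of square roots of the eigenvalues of
    # S1^(1/2) S2 S1^(1/2), which is symmetric and therefore numerically stable.
    s1_half = _sqrtm_psd(sigma1)
    middle = s1_half @ sigma2 @ s1_half
    middle = (middle + middle.T) / 2
    covmean_trace = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0, None)).sum()
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * covmean_trace)


def best_checkpoint(fid_by_ckpt):
    # Lower FID is better.
    return min(fid_by_ckpt, key=fid_by_ckpt.get)
```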

Training gets stuck at the beginning..

I'm training this model on the FlintstonesSV dataset. I run the training script on a GPU server with 8x RTX 3080 Ti (12GB of memory per card). Is this server able to train this model? What is the maximum memory usage during training?

The training process seems to get stuck at "trainer.fit(model, dataloader, ckpt_path=args.train_model_file)". Here is the log:

Global seed set to 0
clip 4 new tokens added
blip 1 new tokens added
clip 4 new tokens added
blip 1 new tokens added
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 0
clip 4 new tokens added
blip 1 new tokens added
Global seed set to 0
clip 4 new tokens added
blip 1 new tokens added
clip 4 new tokens added
blip 1 new tokens added
clip 4 new tokens added
blip 1 new tokens added
Global seed set to 0
Global seed set to 0
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
clip 4 new tokens added
blip 1 new tokens added
clip 4 new tokens added
blip 1 new tokens added
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth
Global seed set to 0
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
[2022-12-26 18:54:48,402][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 1
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large.pth
Global seed set to 0
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
[2022-12-26 18:54:51,472][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 2
Global seed set to 0
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
[2022-12-26 18:54:55,093][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 3
[2022-12-26 18:54:55,097][torch.distributed.distributed_c10d][INFO] - Added key: store_based_barrier_key:1 to store for rank: 0
[2022-12-26 18:54:55,098][torch.distributed.distributed_c10d][INFO] - Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------

[2022-12-26 18:54:55,103][torch.distributed.distributed_c10d][INFO] - Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
[2022-12-26 18:54:55,106][torch.distributed.distributed_c10d][INFO] - Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
[2022-12-26 18:54:55,106][torch.distributed.distributed_c10d][INFO] - Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.

No further output after waiting for 30 minutes...

The config.yaml is:

# device
mode: train  # train sample
gpu_ids: [ 0, 1, 2, 3 ]  # gpu ids
batch_size: 1  # batch size each item denotes one story
num_workers: 4  # number of workers
num_cpu_cores: -1  # number of cpu cores
seed: 0  # random seed
ckpt_dir: results/ # checkpoint directory
run_name: first_try # name for this run

# task
dataset: flintstones  # pororo flintstones vistsis vistdii
task: visualization  # continuation visualization

# train
init_lr: 1e-5  # initial learning rate
warmup_epochs: 1  # warmup epochs
max_epochs: 50  # max epochs
train_model_file:  # model file for resume, none for train from scratch
freeze_clip: False  # whether to freeze clip
freeze_blip: False  # whether to freeze blip
freeze_resnet: False  # whether to freeze resnet

# # sample
# test_model_file:  # model file for test
# calculate_fid: True  # whether to calculate FID scores
# scheduler: ddim  # ddim pndm
# guidance_scale: 6  # guidance scale
# num_inference_steps: 250  # number of inference steps
# sample_output_dir: /path/to/save_samples # output directory

# pororo:
#   hdf5_file: /path/to/pororo.h5
#   max_length: 85
#   new_tokens: [ "pororo", "loopy", "eddy", "harry", "poby", "tongtong", "crong", "rody", "petty" ]
#   clip_embedding_tokens: 49416
#   blip_embedding_tokens: 30530

flintstones:
  hdf5_file: Downloads/save_hdf5_files/flintstones.hdf5
  max_length: 91
  new_tokens: [ "fred", "barney", "wilma", "betty", "pebbles", "dino", "slate" ]
  clip_embedding_tokens: 49412
  blip_embedding_tokens: 30525

# vistsis:
#   hdf5_file: /path/to/vist.h5
#   max_length: 100
#   clip_embedding_tokens: 49408
#   blip_embedding_tokens: 30524

# vistdii:
#   hdf5_file: /path/to/vist.h5
#   max_length: 65
#   clip_embedding_tokens: 49408
#   blip_embedding_tokens: 30524

hydra:
  run:
    dir: .
  output_subdir: null
hydra/job_logging: disabled
hydra/hydra_logging: disabled
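To see where a run is stuck after the store-based barrier, it can help to turn on NCCL and torch.distributed debug logging before launching. On consumer GPUs, disabling NCCL peer-to-peer transport (NCCL_P2P_DISABLE=1) is a commonly tried workaround for hangs at the first collective; treat it as a diagnostic step rather than a confirmed fix. Peak memory can be read with torch.cuda.max_memory_allocated. The helper names below are hypothetical.

```python
import os


def set_distributed_debug_env(disable_p2p=False):
    # Must run before the process group is created (e.g. at the top of main.py).
    os.environ.setdefault("NCCL_DEBUG", "INFO")
    os.environ.setdefault("TORCH_DISTRIBUTED_DEBUG", "DETAIL")
    if disable_p2p:
        os.environ["NCCL_P2P_DISABLE"] = "1"
    keys = ("NCCL_DEBUG", "TORCH_DISTRIBUTED_DEBUG", "NCCL_P2P_DISABLE")
    return {k: os.environ[k] for k in keys if k in os.environ}


def peak_memory_gib(device=0):
    import torch  # imported lazily; requires PyTorch with CUDA

    # Peak memory allocated by tensors on this device since the process started
    # (or since torch.cuda.reset_peak_memory_stats was last called).
    return torch.cuda.max_memory_allocated(device) / 1024**3
```

The same variables can also be set on the command line, e.g. NCCL_DEBUG=INFO NCCL_P2P_DISABLE=1 python main.py.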

hello

I like this project very much. When will the pretrained weights be released?

About the image size

Hi, thank you for your wonderful work. I have some questions.
1. Due to the limited memory of each GPU (8*A100 40G), I can only resize images to 256x256 (instead of 512x512) so that each GPU can fit one story when training AR-LDM. What impact will this change have on the FID score?
2. Besides, sampling is very slow. Can we perform sampling on multiple GPUs?
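On the first question, a quick piece of arithmetic, assuming the Stable Diffusion VAE's 8x spatial downsampling: 512x512 images give 64x64 latents and 256x256 images give 32x32 latents, i.e. 4x fewer latent positions per frame, which is where the memory saving comes from; the effect on FID is not something this arithmetic can answer. On the second, one simple approach is to launch one process per GPU and give each a disjoint shard of the test set. Both helpers are hypothetical sketches, not part of ARLDM.

```python
def latent_side(image_side, downsample=8):
    # SD-style VAE: each latent position covers downsample x downsample pixels.
    if image_side % downsample:
        raise ValueError("image size must be divisible by the downsampling factor")
    return image_side // downsample


def shard_indices(num_items, world_size, rank):
    # Strided split: rank r gets r, r + world_size, ... so every index is
    # assigned to exactly one process and shard sizes differ by at most one.
    return list(range(rank, num_items, world_size))
```

Each process could be started with CUDA_VISIBLE_DEVICES set to its own GPU and write its samples to a shared output directory, with FID computed once over the merged outputs.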

dataset for VIST

Thanks for open-sourcing this!
I'm focusing on reproducing this work and noticed that the datasets include both VIST-SIS and VIST-DII. However, the download script (vist_img_download.py) only suggests downloading DII, while vist_hdf5.py only converts SIS.

My question is how exactly to organize VIST-SIS/DII for training.

Adaptive AR-LDM

Hello,

I am interested in the Adaptive AR-LDM section of your paper. Could you please provide more details on how it was implemented? Specifically, I am curious whether DreamBooth was used to fine-tune the pretrained AR-LDM model or whether another method was employed. Additionally, is the code for this implementation available in the published repository?
Thank you for your time and I look forward to hearing back from you.

Best regards.

Is the generation text guided?

First, thanks for your wonderful work. However, when I tested the model I trained, it seemed to output results directly, without any guiding text. Also, each test sample in the Flintstones dataset has 5 frames. I wonder how the test data are used, and where the guiding text comes in.
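As background, AR-LDM generates a story frame by frame, each frame conditioned on its own caption together with the previous captions and images, so the text guidance comes from the test sample's captions rather than from a separate prompt. The loop below sketches that flow under assumptions: generate_frame is a hypothetical stand-in for the model call, and treating the continuation task as "the first frame is given" is an assumption based on the task names in config.yaml, not confirmed code.

```python
def autoregressive_story(captions, generate_frame, first_image=None):
    # captions: one caption per frame (5 for a Flintstones test sample).
    # generate_frame(prev_captions, prev_images, caption) -> image  (hypothetical model call)
    images = [] if first_image is None else [first_image]
    for i in range(len(images), len(captions)):
        # Frame i sees caption i plus captions 0..i-1 and images 0..i-1.
        images.append(generate_frame(captions[:i], images[:i], captions[i]))
    return images
```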

stuck in hydra when running on a slurm cluster

Hi, when I tried to run this program on our Slurm cluster, the only output I got was the seed message Global seed set to 1 and nothing else. After checking, I think the problem may be the Hydra decorator: it seems the main function is not even executed. Do you have any experience with or recommendations for this problem? It would be really helpful, and thanks in advance :)
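Two things that may be worth ruling out, offered as assumptions rather than a diagnosis. First, when Slurm redirects output to a file, Python block-buffers stdout, so prints can appear much later than expected; python -u or PYTHONUNBUFFERED=1 avoids that. Second, PyTorch Lightning's Slurm documentation asks for one Slurm task per GPU (--ntasks-per-node matching the Trainer's devices); a mismatch can lead to hangs or errors depending on the version. The hypothetical helper below checks the second condition from the Slurm environment variables.

```python
import os


def slurm_tasks_match(num_devices, env=None):
    # Returns (ok, message). Reads SLURM_NTASKS_PER_NODE, falling back to SLURM_NTASKS.
    env = os.environ if env is None else env
    raw = env.get("SLURM_NTASKS_PER_NODE") or env.get("SLURM_NTASKS")
    if raw is None:
        return True, "no Slurm task count found; nothing to compare"
    ntasks = int(raw)
    if ntasks != num_devices:
        return False, f"Slurm launched {ntasks} task(s) but {num_devices} device(s) were requested"
    return True, "task count matches device count"
```

In a batch script this corresponds to, for example, #SBATCH --ntasks-per-node=4 together with #SBATCH --gres=gpu:4 for four devices, launched with srun python -u main.py.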

Training Cannot Start

I am running training with the VIST datasets on a supercomputer GPU node. The code is unable to proceed past rank 0:

/people/$USER/.conda/envs/arldm/lib/python3.8/site-packages/pl_bolts/models/self_supervised/amdim/amdim_module.py:34: UnderReviewWarning: The feature generate_power_seq is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
/people/$USER//.conda/envs/arldm/lib/python3.8/site-packages/pl_bolts/models/self_supervised/amdim/amdim_module.py:92: UnderReviewWarning: The feature FeatureMapContrastiveTask is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
/people/$USER//.conda/envs/arldm/lib/python3.8/site-packages/pl_bolts/losses/self_supervised_learning.py:228: UnderReviewWarning: The feature AmdimNCELoss is currently marked under review. The compatibility with other Lightning projects is not guaranteed and API may change at any time. The API and functionality may change without warning in future releases. More details: https://lightning-bolts.readthedocs.io/en/latest/stability.html
  self.nce_loss = AmdimNCELoss(tclip)
[rank: 0] Global seed set to 0
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Global seed set to 0
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[W socket.cpp:426] [c10d] The server socket cannot be initialized on [::]:19369 (errno: 97 - Address family not supported by protocol).
[W socket.cpp:601] [c10d] The client socket cannot be initialized to connect to [dlt04.local]:19369 (errno: 97 - Address family not supported by protocol).

I noticed that the error may be related to this DDP issue, and I've tried all the solutions mentioned in that thread, but none of them solved my problem.

So my question is: are there other ways to run the code that bypass this issue?
Ideally I'd like to run this program on GPUs. CPU is also fine, but are there any adjustments that would reduce the computation time?


My config: https://github.com/candiceT233/ARLDM/blob/main/config.yaml

My conda environment packages:

accelerate==0.20.3
diffusers==0.7.2
ftfy==6.1.1
hydra-core==1.3.2
lightning-bolts==0.7.0
pytorch-lightning==1.9.5
timm==0.5.4
torch==2.0.1
torchaudio==2.0.2
torchmetrics==1.1.0
torchvision==0.15.2
transformers==4.24.0

Others:

Python 3.8.17
CentOS7
3.10.0-1127.18.2.el7.x86_64
GPU : 8 x RTX 2080 Ti GPUs 384GB memory

Sbatch command:

python main.py &> "$SCRIPT_DIR/$JOB_NAME.log"
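The errno 97 lines are warnings that an IPv6 socket ([::]) could not be created, so the hang may have another cause; still, one thing to try is pinning the rendezvous to an IPv4 address through MASTER_ADDR before launching. Whether Lightning's Slurm environment honors a preset MASTER_ADDR depends on the version, so treat this as an assumption; ipv4_address and pin_master_addr are hypothetical helpers. Note also that the sbatch command runs a single python main.py while the log shows MEMBER: 1/2, so Lightning's one-task-per-GPU requirement under Slurm may be relevant as well.

```python
import os
import socket


def ipv4_address(hostname):
    # Resolve hostname to its first IPv4 address (AF_INET only, no IPv6).
    infos = socket.getaddrinfo(hostname, None, family=socket.AF_INET)
    return infos[0][4][0]


def pin_master_addr(hostname=None, port=None):
    # Point torch.distributed's env:// rendezvous at an IPv4 address.
    host = hostname or socket.gethostname()
    os.environ["MASTER_ADDR"] = ipv4_address(host)
    if port is not None:
        os.environ["MASTER_PORT"] = str(port)
    return os.environ["MASTER_ADDR"]
```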

a problem with Google Drive

I am absolutely enamored with the ARLDM project! It's an incredibly skilled and thoughtful implementation that truly stands out in its field. The intelligent design and comprehensive approach taken by the team are remarkable and deeply appreciated. What I love the most about ARLDM is its commitment to solving complex problems with innovative and effective solutions. The project's functionality, efficiency, and usability make it a joy to use. It's clear that a lot of dedication and hard work has been poured into this project, and it certainly does not go unnoticed. Keep up the excellent work!
However, I've encountered a problem when trying to download from Google Drive:
gdown --id 11Io1_BufAayJ1BpdxxV2uJUvCcirbrNc -O pororo.zip
/usr/local/lib/python3.6/site-packages/gdown/cli.py:130: FutureWarning: Option --id was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.
category=FutureWarning,
Access denied with the following error:

    Too many users have viewed or downloaded this file recently. Please
    try accessing the file again later. If the file you are trying to
    access is particularly large or is shared with many people, it may
    take up to 24 hours to be able to view or download the file. If you
    still can't access a file after 24 hours, contact your domain
    administrator. 

You may still be able to access the file from the browser:

     https://drive.google.com/uc?id=11Io1_BufAayJ1BpdxxV2uJUvCcirbrNc 
