
StitchDiffusion (Continuously Updated)

Customizing 360-Degree Panoramas through Text-to-Image Diffusion Models
Hai Wang, Xiaoyu Xiang, Yuchen Fan, Jing-Hao Xue

Project Page | arXiv

The Colab notebook was implemented by @lshus.

StitchDiffusion Code

StitchDiffusion is a tailored generation (denoising) process for synthesizing 360-degree panoramas; we provide its core code here.

## following MultiDiffusion: https://github.com/omerbt/MultiDiffusion/blob/master/panorama.py ##
## the window size is changed for 360-degree panorama generation ##
def get_views(panorama_height, panorama_width, window_size=[64,128], stride=16):
    # operate in latent space: the VAE downsamples images by a factor of 8
    panorama_height /= 8
    panorama_width /= 8
    num_blocks_height = (panorama_height - window_size[0]) // stride + 1
    num_blocks_width = (panorama_width - window_size[1]) // stride + 1
    total_num_blocks = int(num_blocks_height * num_blocks_width)
    views = []
    for i in range(total_num_blocks):
        h_start = int((i // num_blocks_width) * stride)
        h_end = h_start + window_size[0]
        w_start = int((i % num_blocks_width) * stride)
        w_end = w_start + window_size[1]
        views.append((h_start, h_end, w_start, w_end))
    return views
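## worked example (not part of the released code): for a 512x2048 panorama the
## latent canvas is 64x256, so get_views(512, 2048) returns 9 windows of size
## 64x128 sliding horizontally with stride 16:
## (0, 64, 0, 128), (0, 64, 16, 144), ..., (0, 64, 128, 256)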
#####################
## StitchDiffusion ##
#####################

views_t = get_views(height, width) # height = 512; width = 4*height = 2048
count_t = torch.zeros_like(latents)
value_t = torch.zeros_like(latents)
# latents are sampled from a standard normal distribution (torch.randn()) with a size of Bx4x64x256,
# where B denotes the batch size.

for i, t in enumerate(tqdm(timesteps)):

    count_t.zero_()
    value_t.zero_()

    # initialize the stitch block (width 128 in latent space); clone it so the
    # in-place assignments below do not write back into `latents` through a view
    latent_view_t = latents[:, :, :, 64:192].clone()

    #### pre-denoising operations twice on the stitch block ####
    for ii_md in range(2):

        latent_view_t[:, :, :, 0:64] = latents[:, :, :, 192:256]  # left part of the stitch block
        latent_view_t[:, :, :, 64:128] = latents[:, :, :, 0:64]  # right part of the stitch block

        # expand the latents if we are doing classifier free guidance
        latent_model_input = latent_view_t.repeat((2, 1, 1, 1))

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the denoising step with the reference (customized) model
        latent_view_denoised = self.scheduler.step(noise_pred, t, latent_view_t)['prev_sample']

        value_t[:, :, :, 192:256] += latent_view_denoised[:, :, :, 0:64]
        count_t[:, :, :, 192:256] += 1

        value_t[:, :, :, 0:64] += latent_view_denoised[:, :, :, 64:128]
        count_t[:, :, :, 0:64] += 1

    # same denoising operations as what MultiDiffusion does
    for h_start, h_end, w_start, w_end in views_t:

        latent_view_t = latents[:, :, h_start:h_end, w_start:w_end]
    
        # expand the latents if we are doing classifier free guidance
        latent_model_input = latent_view_t.repeat((2, 1, 1, 1))
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the denoising step with the reference (customized) model
        latent_view_denoised = self.scheduler.step(noise_pred, t, latent_view_t)['prev_sample']
        value_t[:, :, h_start:h_end, w_start:w_end] += latent_view_denoised
        count_t[:, :, h_start:h_end, w_start:w_end] += 1

    latents = torch.where(count_t > 0, value_t / count_t, value_t)

# scale the latents back by the Stable Diffusion VAE scaling factor before decoding
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)


#### global cropping operation ####
# keep the central 2:1 region (512x1024) as the final equirectangular panorama
image = image[:, :, :, 512:1536]
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
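
The snippet above is meant to run inside a Stable Diffusion pipeline, so self.unet, self.vae and self.scheduler, as well as text_embeddings, timesteps, latents and guidance_scale, are assumed to exist already. The following is a minimal, hypothetical sketch of that surrounding setup using Hugging Face diffusers; the checkpoint name, prompt and step count are placeholders, and the paper actually uses a customized (fine-tuned) model rather than the base checkpoint.

import torch
from tqdm import tqdm
from PIL import Image
from diffusers import StableDiffusionPipeline, DDIMScheduler

# placeholder base checkpoint; StitchDiffusion is intended to run with a customized model
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

prompt = "a 360-degree panorama of a snowy mountain village"  # example prompt
height, width = 512, 2048
guidance_scale = 7.5

def encode(texts):
    # encode a list of prompts into CLIP text embeddings
    ids = pipe.tokenizer(
        texts, padding="max_length", max_length=pipe.tokenizer.model_max_length,
        truncation=True, return_tensors="pt",
    ).input_ids.to("cuda")
    with torch.no_grad():
        return pipe.text_encoder(ids)[0]

# unconditional embeddings first, matching the order of noise_pred.chunk(2) in the loop above
text_embeddings = torch.cat([encode([""]), encode([prompt])])

pipe.scheduler.set_timesteps(50)
timesteps = pipe.scheduler.timesteps
latents = torch.randn(
    (1, pipe.unet.config.in_channels, height // 8, width // 8),
    device="cuda", dtype=torch.float16,
) * pipe.scheduler.init_noise_sigma

# ... run the StitchDiffusion loop above here, with `self` replaced by `pipe` ...

# after the global cropping operation, `image` is a B x 512 x 1024 x 3 array in [0, 1]
Image.fromarray((image[0] * 255).round().astype("uint8")).save("panorama.png")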

Useful Tools

360 panoramic image viewer: it can be used to view the synthesized 360-degree panoramas.

Seamless Texture Checker: it can be used to check the continuity between the leftmost and rightmost sides of the generated image (a small programmatic check is sketched after this list).

clip-interrogator: it provides a Google Colab notebook that uses BLIP to generate text prompts.

CLIP: it provides a Google Colab notebook to calculate the CLIP score.

FID: it provides a Google Colab notebook to calculate FID.
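
For a quick programmatic sanity check of seam continuity, complementing the Seamless Texture Checker, one can compare the two edge columns of the panorama and roll it by half its width so the wrap-around seam lands in the middle for visual inspection. A minimal sketch, assuming the generated panorama was saved as panorama.png (a placeholder filename):

import numpy as np
from PIL import Image

pano = np.asarray(Image.open("panorama.png")).astype(np.float32)  # H x W x 3

# mean absolute difference between the leftmost and rightmost pixel columns;
# a small value suggests the wrap-around seam is continuous
seam_error = np.abs(pano[:, 0, :] - pano[:, -1, :]).mean()
print(f"mean |left - right| edge difference: {seam_error:.2f} (0-255 scale)")

# roll the panorama by half its width so the original seam appears in the middle
rolled = np.roll(pano, pano.shape[1] // 2, axis=1)
Image.fromarray(rolled.astype(np.uint8)).save("panorama_rolled.png")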

Statement

This research was done by Hai Wang at University College London. The code and released models are owned by Hai Wang.

Citation

If you find the code helpful in your research or work, please cite our paper:

@inproceedings{wang2024customizing,
  title={Customizing 360-Degree Panoramas through Text-to-Image Diffusion Models},
  author={Wang, Hai and Xiang, Xiaoyu and Fan, Yuchen and Xue, Jing-Hao},
  booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
  pages={4933--4943},
  year={2024}
}

Acknowledgments

We thank the authors of MultiDiffusion; our work builds on their excellent code.


Issues

About the fine-tuning code and datasets

Hello, what a great idea. I learned from the paper that you fine-tuned the model on panoramas. Could you share the fine-tuning code and dataset? I would be very grateful.

Is this a real equirectangular panorama image?

Hi, I was wondering whether the equirectangular projection is always ensured, or whether the fine-tuned model just knows how to generate images that look like equirectangular patches. If the equirectangular projection is ensured, the result can be decomposed into multiple perspective-view images, right?

I ask because I tried something similar: making MultiDiffusion work with camera rotations. However, I found the latents are very sensitive to warping operations, and the results become terrible after several steps. Have you met the same problem? Did you find that directly learning and generating equirectangular images helps to solve it?
