cs348k_dit's Introduction

CS348K Final Project Proposal (Kyle Sargent, Eric Ryan Chan)

Proposal Document

  • Project Title: Alternative Architectures for Diffusion Image Transformers

  • Names and SUNet IDs: ksarge, erchan

  • Summary: For years, UNets were the predominant network architecture for denoising diffusion models. Recently, transformer-based designs, including UViT and DiT, have been shown to outperform UNets for image and video generation. While transformers are a powerful architecture, the computational cost of self-attention (O(N^2) in sequence length N) makes them difficult to scale without vast computational resources. We would like to investigate whether sub-quadratic alternatives to attention, such as Mamba, are viable replacements for standard attention in image-generation diffusion models. To our knowledge, sub-quadratic attention alternatives have not been successfully used in diffusion models before, and a thorough evaluation, whether it produces a positive or a negative result, will be useful for the generative modeling community.

  • Inputs and outputs: The input to the task is a training dataset, used to train our generative model, and the output is a trained generative model that can be used to synthesize images. The central constraints are (1) model performance, measured by Frechet Inception Distance (FID, lower is better), and (2) the compute required to train the model, measured in GPU hours. Our goal is to design a diffusion model architecture that produces good images (achieving a low FID score) while requiring minimal resources to train.

  • Task list:

    • Clone the Diffusion Transformers (DiT) GitHub repository and obtain an allocation of compute resources on the SVL lab cluster.
    • Download and preprocess the FFHQ dataset.
    • Implement an evaluation protocol. We plan to use training dataset Frechet Inception Distance as our primary metric, with Kernel Inception Distance, validation loss, and validation dataset Frechet Inception Distance as useful secondary measures.
    • Set a baseline using a transformer-based architecture on the FFHQ dataset.
    • Implement an alternative architecture using Mamba.
    • Train, evaluate, and analyze the results of the Mamba-based architecture.
  • Expected deliverables and/or evaluation: Our goal is to train a DiT variant with a Mamba-based architecture to generate images on the FFHQ dataset and to compare its performance, using the metrics above, against the reference DiT implementation. In addition to visual quality, which can be benchmarked via generative metrics such as FID and KID, we intend to compare the inference throughput of the Mamba-based variant against that of DiT.

  • What are the biggest risks? Beyond implementation and training, the biggest risk of the project is that we implement novel diffusion architectures and, despite our best efforts, they underperform current state-of-the-art (SOTA) models based on standard transformers. Although less satisfying, a negative result, i.e., a conclusion that sub-quadratic attention cannot outperform traditional attention for image synthesis, would still be an important finding that would inform future research.

  • What do you need help with? Although we are comfortable executing the project plan as written, it would be helpful to get initial feedback on whether the project targets an appropriate problem and is scoped appropriately. Additionally, if any further systems-type analyses are desirable, please let us know and we can attempt them on our Mamba-based generative model of images.
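
The scaling argument in the summary (quadratic cost for attention versus sub-quadratic cost for alternatives such as Mamba) can be illustrated with a rough back-of-the-envelope sketch. The latent downsampling factor, patch size, hidden width, and state size below are illustrative assumptions, not values fixed by this proposal, and the cost functions count only dominant terms rather than modeling any particular implementation.

```python
def num_tokens(image_size: int, downsample: int = 8, patch_size: int = 2) -> int:
    """Sequence length N for a square image.

    Assumes (hypothetically) that the image is first downsampled by
    `downsample` into a latent grid and then split into square patches.
    """
    side = image_size // (downsample * patch_size)
    return side * side


def attention_cost(n: int, d: int) -> int:
    """Dominant multiply-adds of self-attention: QK^T plus weighted V, ~2 * N^2 * d."""
    return 2 * n * n * d


def linear_mixer_cost(n: int, d: int, state: int = 16) -> int:
    """Rough cost of a linear-time sequence mixer: ~N * d * state.

    `state` is a hypothetical per-channel state size.
    """
    return n * d * state


if __name__ == "__main__":
    d = 1024  # assumed hidden width
    for size in (256, 512, 1024):
        n = num_tokens(size)
        ratio = attention_cost(n, d) / linear_mixer_cost(n, d)
        print(f"image={size:4d}  N={n:5d}  attention/linear cost ratio={ratio:.0f}")
```

Under these assumptions, doubling the image side length quadruples N, which multiplies the attention term by 16 but the linear term by only 4. This estimate ignores constant factors and hardware efficiency, which is why the proposal also plans to measure throughput empirically.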

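The primary metric named in the task list, Frechet Inception Distance, is the Frechet distance between two Gaussians fitted to feature vectors of real and generated images. The sketch below computes only that distance with NumPy. It assumes the features have already been extracted; the extraction step (normally an Inception network) is omitted, and the function names are hypothetical rather than taken from the DiT codebase or any evaluation library.

```python
import numpy as np


def gaussian_stats(features: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Mean vector and covariance matrix of an (n_samples, n_dims) feature array."""
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """||mu1 - mu2||^2 + Tr(S1) + Tr(S2) - 2 * Tr(sqrt(S1 @ S2))."""
    diff = mu1 - mu2
    # For positive semi-definite S1 and S2, the eigenvalues of S1 @ S2 are real
    # and non-negative, so Tr(sqrt(S1 @ S2)) is the sum of their square roots.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_sqrt)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Random arrays stand in for real and generated image features.
    real = rng.normal(size=(1000, 8))
    fake = rng.normal(loc=0.5, size=(1000, 8))
    print(frechet_distance(*gaussian_stats(real), *gaussian_stats(fake)))
```

Identical feature statistics give a distance of zero up to numerical error. A reported FID would additionally depend on the feature extractor and the number of samples used, so numbers from this sketch are not comparable to published scores.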