
matanby / fast-style-transfer-pytorch


A simple and minimalistic implementation of the fast neural style transfer method presented in "Perceptual Losses for Real-Time Style Transfer and Super-Resolution" by Johnson et al. (2016).



Fast Neural Style-Transfer PyTorch

This is a simple and minimalistic PyTorch implementation of the fast neural style transfer method introduced in "Perceptual Losses for Real-Time Style Transfer and Super-Resolution" by Johnson et al. (2016).

The original neural style transfer method by Gatys et al. ("A Neural Algorithm of Artistic Style") generates the stylized image by iteratively optimizing it against a loss function that combines content and style terms.

Unlike this approach, which is inherently slow, the method by Johnson et al. trains a convolutional neural network that takes a content image as input and generates a stylized version of it in a single forward pass. This makes image generation orders of magnitude faster. The downside is that each network is trained on one specific style and is therefore not generic: a new network must be trained for every new style.
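To make the contrast concrete, here is a minimal PyTorch sketch of the two generation modes. `iterative_stylize` optimizes the output image's pixels directly, in the spirit of the Gatys et al. method, while `feedforward_stylize` runs one forward pass of an already trained network, as in Johnson et al. Both function names are hypothetical and not part of this repository. The toy loss and the one-layer network in the demo are stand-ins for the real VGG-based loss and the paper's transformation network, and Adam is used only for brevity (it is not necessarily the optimizer used by Gatys et al.).

```python
import torch


def iterative_stylize(loss_fn, init_image, steps=200, lr=0.05):
    """Gatys-style generation: optimize the output pixels directly.

    loss_fn maps an image tensor to a scalar loss; in real use it would be
    the combined content + style loss computed from VGG features.
    The whole loop has to be repeated for every new image.
    """
    image = init_image.clone().requires_grad_(True)
    optimizer = torch.optim.Adam([image], lr=lr)  # Adam chosen here for brevity
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(image)
        loss.backward()
        optimizer.step()
    return image.detach()


def feedforward_stylize(network, content_image):
    """Johnson-style generation: a single forward pass of a trained network."""
    network.eval()
    with torch.no_grad():
        return network(content_image)


if __name__ == "__main__":
    torch.manual_seed(0)
    target = torch.rand(1, 3, 32, 32)

    def toy_loss(img):
        # Toy stand-in loss: pull the image toward a fixed target.
        return ((img - target) ** 2).mean()

    result = iterative_stylize(toy_loss, torch.zeros_like(target))
    print("iterative, final loss:", toy_loss(result).item())

    # Placeholder network; the paper's transformation network is much deeper.
    net = torch.nn.Sequential(torch.nn.Conv2d(3, 3, kernel_size=3, padding=1))
    print("feed-forward output shape:", tuple(feedforward_stylize(net, target).shape))
```

The iterative version pays the full optimization cost for every image it produces, whereas the feed-forward version pays that cost once, during training.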


Prerequisites:

  • Python 3
  • CUDA + cuDNN (for GPU acceleration)
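Assuming PyTorch is already installed, a quick sanity check that the GPU prerequisites are met might look like the following. It is not part of this repository; `pick_device` is a hypothetical helper built only on standard `torch.cuda` and `torch.backends.cudnn` calls.

```python
import torch


def pick_device() -> torch.device:
    """Return a CUDA device if PyTorch can use a GPU, otherwise the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


if __name__ == "__main__":
    device = pick_device()
    print("Using device:", device)
    # cuDNN only matters when running on the GPU.
    print("cuDNN available:", torch.backends.cudnn.is_available())
    if device.type == "cuda":
        print("GPU:", torch.cuda.get_device_name(0))
```

If this reports the CPU on a machine with a GPU, the installed PyTorch build most likely lacks CUDA support.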

Installation:

  1. Clone this repository:
git clone https://github.com/matanby/fast-style-transfer-pytorch
  2. Install the pip requirements (here inside a virtual environment):
python3 -m virtualenv .venv
source .venv/bin/activate
pip install -r fast-style-transfer-pytorch/requirements.py

Usage:

You can use one of the three pre-trained models bundled with this repository, or train custom models on your own style images (see the training instructions below).

From the command line:

python run.py [PATH_TO_PRETRAINED_MODEL] [PATH_TO_CONTENT_IMAGE] [PATH_TO_STYLIZED_OUTPUT]
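`run.py` stylizes one image per call. To process a whole folder from Python, one option is to loop over the images and issue the same command for each. The sketch below assumes it is run from the repository root and that content images use `.jpg`, `.jpeg` or `.png` extensions; `build_run_commands` and `stylize_folder` are hypothetical helper names.

```python
import subprocess
import sys
from pathlib import Path

# Assumed set of image extensions to pick up from the content folder.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def build_run_commands(model_path, content_dir, output_dir):
    """Build one `run.py` invocation (as an argv list) per image in content_dir."""
    content_dir, output_dir = Path(content_dir), Path(output_dir)
    commands = []
    for image in sorted(content_dir.iterdir()):
        if image.is_file() and image.suffix.lower() in IMAGE_SUFFIXES:
            output = output_dir / image.name  # keep the original file name
            commands.append(
                [sys.executable, "run.py", str(model_path), str(image), str(output)]
            )
    return commands


def stylize_folder(model_path, content_dir, output_dir):
    """Run `run.py` on every image, stopping at the first failure."""
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    for cmd in build_run_commands(model_path, content_dir, output_dir):
        subprocess.run(cmd, check=True)
```

For example, `stylize_folder('models/style1.pt', 'images/content', 'images/stylized/style1')`. Because every call starts a new process that reloads the model, the `Stylizer` class described below is likely the faster choice for large batches.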

Programmatically:

Use the Stylizer class to create stylized images programmatically. For example:

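import os

import image_utils
from stylizer import Stylizer

# Load one of the bundled pre-trained models.
stylizer = Stylizer('models/style1.pt')

# Load a content image, stylize it, and save the result.
image = image_utils.load('images/content/1.jpg')
stylized = stylizer.stylize(image)
os.makedirs('images/stylized/style1', exist_ok=True)  # make sure the output folder exists
image_utils.save(stylized, 'images/stylized/style1/1.jpg')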

Training the model:

You can train a custom model on your own style image. To do so, you'll need a dataset of content images to train on; the authors of the paper used the MS-COCO 2014 dataset.

To initiate the training process, run the train.py script as follows:

python train.py --dataset_path [PATH_TO_DATASET] --style_image_path [PATH_TO_STYLE_IMAGE]

See below for more info on how the dataset folder should be structured.
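Before starting a long training run, it can be worth checking that the dataset folder has the `train/images` and `validation/images` layout described in the configuration list. The helper below is a hypothetical pre-flight check, not part of `train.py`, and the set of accepted image extensions is an assumption.

```python
from pathlib import Path

# Assumed image extensions; adjust to whatever your dataset contains.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def check_dataset_layout(dataset_path):
    """Count images in <dataset>/train/images and <dataset>/validation/images.

    Raises FileNotFoundError if a split folder is missing and ValueError
    if a split folder contains no images.
    """
    counts = {}
    for split in ("train", "validation"):
        folder = Path(dataset_path) / split / "images"
        if not folder.is_dir():
            raise FileNotFoundError(f"missing folder: {folder}")
        n = sum(
            1 for p in folder.iterdir()
            if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
        )
        if n == 0:
            raise ValueError(f"no images found in {folder}")
        counts[split] = n
    return counts
```

Note that the MS-COCO 2014 archives unpack into `train2014/` and `val2014/` folders, so their contents need to be moved or symlinked into `train/images` and `validation/images` first.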

You can also override the default configuration entries and hyper-parameter values by passing additional CLI arguments, for example:

python train.py \
  --dataset_path ms-coco \
  --style_image_path images/style/1.jpg \
  --batch_size 8 \
  --lambda_style 200
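The configuration list below gives an empirically good range of 10 to 100,000 for `lambda_style`, so sweeping over a few values is a natural use of these overrides. The sketch only builds the command lines. It assumes, as the example above suggests, that every configuration entry (including `root_logdir`) can be passed as `--<name> <value>`, and it gives each run its own log directory; `sweep_commands` is a hypothetical helper.

```python
import shlex


def sweep_commands(dataset_path, style_image_path, lambda_styles, base_logdir="models"):
    """Return one shell-quoted train.py command line per lambda_style value.

    Each run writes to its own root_logdir so that snapshots and
    TensorBoard logs from different runs do not mix.
    """
    commands = []
    for value in lambda_styles:
        args = [
            "python", "train.py",
            "--dataset_path", dataset_path,
            "--style_image_path", style_image_path,
            "--lambda_style", str(value),
            "--root_logdir", f"{base_logdir}/lambda_style_{value}",
        ]
        commands.append(shlex.join(args))  # requires Python 3.8+
    return commands


if __name__ == "__main__":
    for cmd in sweep_commands("ms-coco", "images/style/1.jpg", [10, 100, 1000, 10000]):
        print(cmd)
```

Pointing TensorBoard at the shared base directory then makes it easy to compare the runs side by side.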

Complete list of configuration entries and hyper-parameters:

  • dataset_path: the path to the folder containing the training and validation sets. This folder should be structured as follows:
    • train
      • images
        • image_1.jpg
        • ...
    • validation
      • images
        • image_1.jpg
        • ...
  • style_image_path: the path to the target style image.
  • root_logdir: the root directory in which model snapshots and TensorBoard logs will be saved. Default = 'models'.
  • weights_snapshot_path: a path to a snapshot of the model's weights, used when resuming a previous training job. Default = ''.
  • lambda_content: the weight of the content term in the total loss. Empirically good range: 1 to 100. Default = 10.
  • lambda_style: the weight of the style term in the total loss. Empirically good range: 10 to 100,000. Default = 100.
  • lambda_tv: the weight of the generated image's total variation in the total loss. Empirically good range: 0 to 1,000. Default = 10.
  • learning_rate: the size of each step of the optimization process. Default = 1e-3.
  • epochs: the number of training epochs to perform. Default = 2.
  • content_block_weights: the weight of each convolutional block in the content loss. These five numbers refer to the following five activations of the VGG19 model: conv1_1, conv2_1, conv3_1, conv4_1, conv5_1. Default = (0.0, 1.0, 0.0, 0.0, 0.0).
  • style_block_weights: the weight of each convolutional block in the style loss. These five numbers refer to the following five activations of the VGG19 model: conv1_1, conv2_1, conv3_1, conv4_1, conv5_1. Default = (1/5, 1/5, 1/5, 1/5, 1/5).
  • input_images_dim: the dimension (in pixels) of the model's input images. Default = 256.
  • visualization_interval: the interval (in training iterations) at which intermediate stylized results are visualized in TensorBoard. Default = 50.
  • snapshot_interval: the interval (in training iterations) at which an intermediate snapshot of the model is saved to disk. Default = 1000.
  • batch_size: the mini-batch size to use for each training iteration. Default = 4.
  • num_data_loader_workers: the number of workers used to load images from the dataset in the background. Default = 5.
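To show how the weighting entries fit together, here is a minimal PyTorch sketch of a total loss that combines per-block content and style terms with a total-variation term, using the default values above. It is an illustration, not this repository's implementation: the Gram-matrix normalization (by C·H·W, following Johnson et al.), the use of mean squared error, and the mean-absolute-difference form of total variation are choices made here, and the five feature maps are assumed to come from the VGG19 activations listed above.

```python
import torch
import torch.nn.functional as F

# Default weights, taken from the configuration list above.
LAMBDA_CONTENT = 10.0
LAMBDA_STYLE = 100.0
LAMBDA_TV = 10.0
CONTENT_BLOCK_WEIGHTS = (0.0, 1.0, 0.0, 0.0, 0.0)  # conv1_1 ... conv5_1
STYLE_BLOCK_WEIGHTS = (1 / 5,) * 5


def gram(x: torch.Tensor) -> torch.Tensor:
    # (b, c, h, w) -> (b, c, c) channel correlations, normalized by c*h*w
    b, c, h, w = x.shape
    f = x.reshape(b, c, h * w)
    return f @ f.transpose(1, 2) / (c * h * w)


def total_variation(img: torch.Tensor) -> torch.Tensor:
    # Mean absolute difference between vertically and horizontally adjacent pixels
    dh = (img[:, :, 1:, :] - img[:, :, :-1, :]).abs().mean()
    dw = (img[:, :, :, 1:] - img[:, :, :, :-1]).abs().mean()
    return dh + dw


def total_loss(gen_img, gen_feats, content_feats, style_feats):
    """Weighted sum of content, style and total-variation terms.

    Each *_feats argument is a list of five feature maps, one per VGG19
    block. style_feats may have batch size 1; its Gram matrices are
    broadcast over the generated batch.
    """
    content = sum(
        w * F.mse_loss(g, c)
        for w, g, c in zip(CONTENT_BLOCK_WEIGHTS, gen_feats, content_feats)
    )
    style = 0.0
    for w, g, s in zip(STYLE_BLOCK_WEIGHTS, gen_feats, style_feats):
        g_gram = gram(g)
        style = style + w * F.mse_loss(g_gram, gram(s).expand_as(g_gram))
    return (
        LAMBDA_CONTENT * content
        + LAMBDA_STYLE * style
        + LAMBDA_TV * total_variation(gen_img)
    )
```

Raising `lambda_style` relative to `lambda_content` shifts the balance toward the style term, while `lambda_tv` penalizes high-frequency noise in the generated image.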

Examples:

[Image grid: content images vs. style images (images not reproduced here)]
