
TextGAN-PyTorch

TextGAN is a PyTorch framework for text generation models based on Generative Adversarial Networks (GANs). TextGAN serves as a benchmarking platform to support research on GAN-based text generation models. Since most GAN-based text generation models are implemented in TensorFlow, TextGAN can help those who are used to PyTorch enter the text generation field faster.

If you find any mistake in my implementation, please let me know! Also, please feel free to contribute to this repository if you want to add other models.


Requirements

  • PyTorch >= 1.0.0
  • Python 3.6
  • NumPy 1.14.5
  • CUDA 7.5+ (For GPU)
  • nltk 3.4
  • tqdm 4.32.1

To install, run pip install -r requirements.txt. In case of CUDA problems, consult the official PyTorch Get Started guide.

Implemented Models and Original Papers

Get Started

  • Clone the repository
git clone https://github.com/williamSYSU/TextGAN-PyTorch.git
cd TextGAN-PyTorch
  • For real data experiments, the Image COCO and EMNLP News datasets can be downloaded from here.
  • Run with a specific model
cd run
python3 run_[model_name].py 0 0	# The first 0 is job_id, the second 0 is gpu_id

# For example
python3 run_seqgan.py 0 0
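If you want to queue several runs, the documented command line (model name, then job_id, then gpu_id) can be generated programmatically. The sketch below is a hypothetical helper, not part of the repository: build_run_commands and the round-robin GPU assignment are assumptions for illustration, and it only builds the argument lists rather than launching them.

```python
def build_run_commands(model_name, job_ids, gpu_ids):
    """Hypothetical helper: one `python3 run_<model>.py <job_id> <gpu_id>` per job.

    GPUs are assigned round-robin, so the i-th job runs on gpu_ids[i % len(gpu_ids)].
    """
    if not gpu_ids:
        raise ValueError("need at least one GPU id")
    commands = []
    for i, job_id in enumerate(job_ids):
        gpu_id = gpu_ids[i % len(gpu_ids)]
        commands.append(["python3", "run_{}.py".format(model_name), str(job_id), str(gpu_id)])
    return commands


if __name__ == "__main__":
    for cmd in build_run_commands("seqgan", job_ids=[0, 1, 2], gpu_ids=[0, 1]):
        print(" ".join(cmd))
        # To actually start a job, you could run it from the run/ directory,
        # e.g. subprocess.Popen(cmd, cwd="run") after `import subprocess`.
```

With three jobs and two GPUs, the third job wraps around to GPU 0. Whether several jobs can share one GPU depends on your memory budget, so adjust gpu_ids accordingly.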

Features

  1. Instructor

    For each model, the entire running process is defined in its own instructor; for example, SeqGAN in the synthetic data experiment uses instructor/oracle_data/seqgan_instructor.py. Basic functions such as init_model() and optimize() are defined in the base class BasicInstructor in instructor.py. If you want to add a new GAN-based text generation model, please create a new instructor under instructor/oracle_data and define the model's training process there.

  2. Visualization

    Use utils/visualization.py to visualize the log files, including model losses and metric scores. List the log files you want to plot in log_file_list; it may contain at most len(color_list) entries. Log filenames should be given without the .txt extension.

  3. Logging

    TextGAN-PyTorch uses Python's logging module to record the running process, such as the generator's loss and metric scores. For the convenience of visualization, two identical log files are saved, in log/log_****_****.txt and save/**/log.txt respectively. Furthermore, at each log step the code automatically saves the models' state dicts and a batch of generator samples in ./save/**/models and ./save/**/samples, where ** depends on your hyper-parameters.

  4. Running Signal

    You can easily control the training process with the Signal class (see utils/helpers.py), which is driven by the dictionary file run_signal.txt.

    To use the Signal, just edit the local file run_signal.txt. For example, if you set pre_sig to False, the program stops the pre-training process and moves on to the next training phase. This is convenient for stopping training early when you think the current training is sufficient.

  5. Automatically select GPU

    In config.py, the program automatically selects the GPU device with the lowest GPU-Util reported by nvidia-smi. This feature is enabled by default. If you want to select a GPU device manually, please uncomment the --device argument in run_[model_name].py and specify a GPU device on the command line.
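The Instructor feature above is essentially the template-method pattern: shared helpers sit in a base class and each model defines its own training schedule. The sketch below illustrates that split in plain Python. Apart from BasicInstructor, init_model() and optimize(), every name (MyGANInstructor, _run, the cfg keys) is hypothetical, and the real signatures in instructor.py may differ; optimize() only assumes PyTorch-style optimizer and loss objects exposing zero_grad(), backward() and step().

```python
from abc import ABC, abstractmethod


class BasicInstructor(ABC):
    """Sketch of a base instructor: shared helpers live here (signatures are assumptions)."""

    def __init__(self, cfg):
        self.cfg = cfg
        self.history = []  # records what ran; a real instructor would use logging

    def init_model(self):
        # A real implementation would build the generator/discriminator
        # and optionally load pretrained weights here.
        self.history.append("init_model")

    @staticmethod
    def optimize(opt, loss):
        # Standard PyTorch update: clear old gradients, backpropagate, apply the step.
        opt.zero_grad()
        loss.backward()
        opt.step()

    @abstractmethod
    def _run(self):
        """Model-specific training schedule, defined by each subclass."""


class MyGANInstructor(BasicInstructor):
    """Hypothetical instructor for a new model under instructor/oracle_data/."""

    def _run(self):
        self.init_model()
        for epoch in range(self.cfg["pretrain_epochs"]):
            # ... MLE pre-training of the generator would go here, calling self.optimize(...) ...
            self.history.append("pretrain epoch {}".format(epoch))
        for step in range(self.cfg["adv_steps"]):
            # ... adversarial generator/discriminator updates would go here ...
            self.history.append("adversarial step {}".format(step))
        return self.history
```

Because _run is abstract, instantiating BasicInstructor directly raises TypeError, which forces every new model to spell out its own training process while reusing the shared helpers.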
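The two constraints in the Visualization feature (at most len(color_list) log files, names given without .txt) can be checked before any plotting happens. The sketch below is a hypothetical helper, not the code in utils/visualization.py: the log directory, the example names and the color values are assumptions. It only pairs each log file with a color and leaves parsing and plotting out, since the log line format is not described here.

```python
import os


def resolve_log_files(log_file_list, color_list, log_dir="log"):
    """Pair each log name with a plot color, enforcing the two rules above.

    Hypothetical helper: log_dir and the (path, color) return format are assumptions.
    """
    if len(log_file_list) > len(color_list):
        raise ValueError(
            "log_file_list has {} entries but color_list only has {}".format(
                len(log_file_list), len(color_list)))
    pairs = []
    for name, color in zip(log_file_list, color_list):
        if name.endswith(".txt"):
            raise ValueError("{!r}: give the name without the .txt extension".format(name))
        pairs.append((os.path.join(log_dir, name + ".txt"), color))
    return pairs
```

For example, resolve_log_files(["log_0001_0002"], ["red", "blue"]) would return a single pair pointing at log/log_0001_0002.txt with the color "red"; both the filename and the palette are made up for illustration.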
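Writing the same records to two files, as the Logging feature describes, maps naturally onto one logger with two FileHandler objects from Python's standard logging module. This is a minimal sketch under that assumption, not the project's actual logger setup; create_logger, the message format and the paths you pass in are hypothetical.

```python
import logging
import os


def create_logger(name, log_paths, fmt="%(message)s"):
    """Return a logger that writes every record to each path in log_paths."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # avoid duplicate output via the root logger
    for handler in list(logger.handlers):  # make repeated calls idempotent
        handler.close()
        logger.removeHandler(handler)
    formatter = logging.Formatter(fmt)
    for path in log_paths:
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        handler = logging.FileHandler(path, mode="w")
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
```

Calling create_logger("textgan", ["log/log_demo.txt", "save/demo/log.txt"]) and then logger.info(...) would write each message to both files, so either copy can be fed to the visualization script. Attaching handlers to one logger, rather than logging twice, guarantees the two files never drift apart.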
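The Running Signal mechanism amounts to re-reading a small file every epoch and acting on its flags. The sketch below assumes run_signal.txt holds a Python dict literal such as {'pre_sig': True} and parses it with ast.literal_eval; this Signal class and the pretrain loop are illustrations only, not the code in utils/helpers.py.

```python
import ast


class Signal:
    """Hypothetical re-implementation: flags are re-read from disk on each update()."""

    def __init__(self, signal_file):
        self.signal_file = signal_file
        self.pre_sig = True
        self.update()

    def update(self):
        with open(self.signal_file) as f:
            # literal_eval only accepts literals, so a typo in the file cannot execute code.
            signals = ast.literal_eval(f.read())
        self.pre_sig = signals.get("pre_sig", True)


def pretrain(signal, max_epochs, on_epoch_end=None):
    """Run up to max_epochs, stopping early once pre_sig is set to False."""
    completed = 0
    for epoch in range(max_epochs):
        signal.update()
        if not signal.pre_sig:
            break  # the user edited run_signal.txt: move on to the next training phase
        # ... one epoch of generator pre-training would run here ...
        completed += 1
        if on_epoch_end is not None:
            on_epoch_end(epoch)
    return completed
```

Checking the flag at the start of each epoch means an edit takes effect at the next epoch boundary rather than mid-epoch, which keeps the saved models and samples consistent.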
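The automatic GPU choice boils down to picking the index with the lowest utilization. The sketch below queries nvidia-smi with --query-gpu=index,utilization.gpu --format=csv,noheader,nounits and takes the minimum; this is an assumed approach rather than the exact logic in config.py, and the parsing is split out so it can be tried without a GPU.

```python
import subprocess

QUERY = ["nvidia-smi", "--query-gpu=index,utilization.gpu", "--format=csv,noheader,nounits"]


def parse_gpu_util(csv_text):
    """Map GPU index -> utilization (%) from nvidia-smi CSV output."""
    utils = {}
    for line in csv_text.strip().splitlines():
        index, util = (field.strip() for field in line.split(","))
        utils[int(index)] = int(util)
    return utils


def least_busy_gpu(csv_text=None):
    """Return the GPU index with the lowest utilization (ties go to the lower index)."""
    if csv_text is None:
        csv_text = subprocess.check_output(QUERY, universal_newlines=True)
    utils = parse_gpu_util(csv_text)
    if not utils:
        raise RuntimeError("no GPUs reported by nvidia-smi")
    return min(sorted(utils), key=utils.get)
```

The returned index could then be passed as the gpu_id argument of a run script. Some devices may report a non-numeric value such as [N/A] for utilization, in which case int() raises ValueError and the parser would need a fallback.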

Implementation Details

SeqGAN

LeakGAN

MaliGAN

JSDGAN

RelGAN

License

MIT License


Contributors

williamsysu, ishalyminov, rtst777, songyouwei, yupaul
