
ntscc_jsac22's Introduction

Nonlinear Transform Source-Channel Coding for Semantic Communications

PyTorch implementation of the JSAC 2022 paper "Nonlinear Transform Source-Channel Coding for Semantic Communications"

Arxiv Link: https://arxiv.org/abs/2112.10961

Project Page: https://semcomm.github.io/ntscc/

Prerequisites

  • Python 3.8 and Conda
  • CUDA 11.0
  • Environment
    conda create -n $YOUR_PY38_ENV_NAME python=3.8
    conda activate $YOUR_PY38_ENV_NAME
    
    pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
    python -m pip install -r requirements.txt
    

Usage

Example of testing the PSNR model:

python main.py --phase test --checkpoint path_to_checkpoint

Pretrained Models

The pretrained models (optimized for MSE) were trained from scratch on 500k images randomly chosen from the OpenImages dataset.

Other pretrained models will be released successively.

Note: We reorganized the code, so the performance differs slightly from the paper's.

RD curves on Kodak under the AWGN channel at SNR = 10 dB (figure: kodak_rd).

Citation

If you find the code helpful in your research or work, please cite:

@ARTICLE{9791398,
  author={Dai, Jincheng and Wang, Sixian and Tan, Kailin and Si, Zhongwei and Qin, Xiaoqi and Niu, Kai and Zhang, Ping},
  journal={IEEE Journal on Selected Areas in Communications}, 
  title={Nonlinear Transform Source-Channel Coding for Semantic Communications}, 
  year={2022},
  volume={40},
  number={8},
  pages={2300-2316},
  doi={10.1109/JSAC.2022.3180802}
  }

Acknowledgements

The NTSCC model is partially built upon the Swin Transformer and CompressAI. We thank the authors for sharing their code.

ntscc_jsac22's People

Contributors

wsxtyrdd


ntscc_jsac22's Issues

Question about the channel and channel coding

Thank you for your great work!!
I have some questions about the channel coding and the channel. First, I haven't seen an implementation of channel coding in your code. Second, the channel seems to add complex Gaussian noise directly to the quantized vector, without coding or modulation. Both points differ from the paper. I wonder if there are some details I haven't noticed.
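
For readers following this question, here is a minimal sketch of the behaviour the issue describes: power-normalizing a vector of complex symbols and adding complex Gaussian noise at a target SNR, with no separate channel code or modulation. It is an illustration under stated assumptions, not the repository's implementation; the function name `awgn_channel` and its arguments are hypothetical, and it uses only the Python standard library rather than PyTorch.

```python
import math
import random


def awgn_channel(symbols, snr_db, rng=None):
    """Hypothetical helper: send complex symbols over an AWGN channel.

    Assumes the SNR is defined against unit average transmit power.
    Returns the normalized transmit symbols and the noisy received symbols.
    """
    rng = rng or random.Random()
    # Normalize so that the average power per complex symbol is 1.
    power = sum(abs(s) ** 2 for s in symbols) / len(symbols)
    scale = 1.0 / math.sqrt(power)
    tx = [s * scale for s in symbols]
    # Noise variance per complex symbol for the requested SNR.
    noise_var = 10 ** (-snr_db / 10)
    # Split the variance evenly between real and imaginary parts.
    sigma = math.sqrt(noise_var / 2)
    rx = [t + complex(rng.gauss(0, sigma), rng.gauss(0, sigma)) for t in tx]
    return tx, rx


if __name__ == "__main__":
    rng = random.Random(0)
    latent = [complex(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(8)]
    tx, rx = awgn_channel(latent, snr_db=10, rng=rng)
    for t, r in zip(tx, rx):
        print(f"{t:.3f} -> {r:.3f}")
```

In a setup like this the learned encoder output itself plays the role of the channel input, which is why no explicit channel-coding or modulation step appears in the sketch.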

Settings for Cifar10 dataset

Hi sixian:

Thanks a lot for the nice work! I have recently been trying to re-implement it for the CIFAR10 dataset. I found that setting lambda = {1024, 256, 64, 16, 4} for CIFAR10, as given in the JSAC paper, does not reproduce the performance in Fig. 9 (a). Specifically, lambda = 4 gives me a PSNR of about 34 dB at SNR = 10 dB, while lambda = 64 yields a PSNR of about 33.5 dB.

Could you please tell me how to obtain the CIFAR10 performance in Fig. 9 (a)? Should I set an even smaller lambda (e.g., 0.1) for good PSNR and a very large lambda (e.g., 10k) for a low bit rate? Or should one simply change eta from 0.2 to some other value?

Thanks a lot in advance!
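
When comparing runs like the two above, it can help to convert between PSNR and MSE, since the training objective is stated in terms of MSE. The sketch below uses the standard definition PSNR = 10 log10(MAX^2 / MSE); the function names and the default peak value of 255 are assumptions for illustration and may not match how the repository scales its images.

```python
import math


def psnr_from_mse(mse, max_val=255.0):
    """PSNR in dB for a given mean squared error and peak pixel value."""
    return 10.0 * math.log10(max_val ** 2 / mse)


def mse_from_psnr(psnr_db, max_val=255.0):
    """Inverse of psnr_from_mse: the MSE that corresponds to a PSNR in dB."""
    return max_val ** 2 / (10 ** (psnr_db / 10))


if __name__ == "__main__":
    # The two operating points quoted in the issue (34 dB and 33.5 dB).
    for psnr_db in (34.0, 33.5):
        print(psnr_db, "dB ->", mse_from_psnr(psnr_db))
```

A 0.5 dB gap corresponds to an MSE ratio of 10^0.05, roughly 1.12, so the two lambda settings quoted above land on nearly the same distortion.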

About datasets

train_data_dir = ['path to DIV2K_train_HR or OpenImages']
test_data_dir = ['Path to kodak']
The paper mentions that "the dataset for training the proposed NTSCC model for large images consists of 500,000 images sampled from the Open Images Dataset". What is the difference in performance between training on DIV2K_train_HR and on OpenImages? DIV2K_train_HR contains about 900 images, while the OpenImages subset contains 500,000.

Train error

When I run
python main.py -p train, there is an error:
RuntimeError: Expected isFloatingType(grad.scalar_type()) || (input_is_complex == grad_is_complex) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
I checked whether any gradients exist, but the check printed False.

for batch_idx, input_image in enumerate(train_loader):
    optimizer_G.zero_grad()
    aux_optimizer.zero_grad()
    start_time = time.time()
    input_image = input_image.to(device)
    global_step += 1
    mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc = net(input_image)
    if config.use_side_info:  # False
        cbr_z = bpp_snr_to_kdivn(bpp_z, 10)
        loss = mse_loss_ntscc + mse_loss_ntc + config.train_lambda * (bpp_y * config.eta + cbr_z)
        cbrs.update(cbr_y + cbr_z)
    else:
        # add ntc_loss to improve the training convergence stability
        ntc_loss = mse_loss_ntc + config.train_lambda * (bpp_y + bpp_z)
        loss = ntc_loss + mse_loss_ntscc
        cbrs.update(cbr_y)
    print("Input Image Type:", input_image.dtype)
    print("Gradient Exists:", any(p.grad is not None for p in net.parameters()))

How can I solve this problem? Thx a lot :)

"When I run python main.py --phase train, the following error occurs."

Traceback (most recent call last):
  File "/home/jn/SC_work/NTSCC_JSAC22-master/main.py", line 175, in <module>
    main(sys.argv[1:])
  File "/home/jn/SC_work/NTSCC_JSAC22-master/main.py", line 163, in main
    loss = test(net, test_loader, logger)
  File "/home/jn/SC_work/NTSCC_JSAC22-master/main.py", line 25, in test
    mse_loss_ntc, bpp_y, bpp_z, mse_loss_ntscc, cbr_y, x_hat_ntc, x_hat_ntscc = net(input_image)
  File "/home/jn/miniconda3/envs/ntscc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jn/SC_work/NTSCC_JSAC22-master/net/NTSCC_Hyperior.py", line 119, in forward
    self.forward_NTC(input_image, require_probs=True)
  File "/home/jn/SC_work/NTSCC_JSAC22-master/net/NTSCC_Hyperior.py", line 152, in forward_NTC
    return super(NTSCC_Hyperprior, self).forward(input_image, **kwargs)
  File "/home/jn/SC_work/NTSCC_JSAC22-master/net/NTSCC_Hyperior.py", line 53, in forward
    y = self.ga(input_image)
  File "/home/jn/miniconda3/envs/ntscc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jn/SC_work/NTSCC_JSAC22-master/layer/analysis_transform.py", line 98, in forward
    x = layer(x)
  File "/home/jn/miniconda3/envs/ntscc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jn/SC_work/NTSCC_JSAC22-master/layer/analysis_transform.py", line 35, in forward
    x = self.downsample(x)
  File "/home/jn/miniconda3/envs/ntscc/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/jn/SC_work/NTSCC_JSAC22-master/layer/layers.py", line 363, in forward
    assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
AssertionError: x size (681*1024) are not even.
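
The assertion above fires because one of the two reported dimensions (681) is odd. A common workaround, offered here as an assumption rather than the repository's own fix, is to pad or crop each image so that both dimensions are a multiple of the network's total downsampling factor. The helper name `pad_amounts` and the `multiple` parameter are hypothetical, and the correct multiple depends on the number of downsampling stages in the model, which is not stated on this page.

```python
def pad_amounts(height, width, multiple=2):
    """Hypothetical helper: padding needed to reach the next multiple.

    Returns (left, right, top, bottom), the order that
    torch.nn.functional.pad expects for the last two dimensions.
    """
    pad_h = (-height) % multiple
    pad_w = (-width) % multiple
    # Split the padding as evenly as possible between the two sides.
    top = pad_h // 2
    bottom = pad_h - top
    left = pad_w // 2
    right = pad_w - left
    return (left, right, top, bottom)


if __name__ == "__main__":
    # The size reported in the traceback, with two example multiples.
    print(pad_amounts(681, 1024, multiple=2))
    print(pad_amounts(681, 1024, multiple=64))
```

After the forward pass, the reconstruction would need to be cropped back to the original size before computing PSNR so the padded pixels do not affect the metric.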

How to install requirements.txt

Hi, I'm having some issues with installing the requirements for your project. When I run python -m pip install -r requirements.txt, I get this error: ERROR: Invalid requirement: '_libgcc_mutex=0.1=conda_forge' (from line 4 of requirements.txt). And when I run conda create --name zxy --file requirements.txt, I get this error: PackagesNotFoundError: The following packages are not available from current channels:.
Could you please help me resolve these errors? Thank you.
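
The entry `_libgcc_mutex=0.1=conda_forge` quoted in the pip error has the `name=version=build` shape of a conda export, which pip cannot parse because it expects `name==version`. As a rough sketch, assuming every line of the file follows that conda shape, the conversion below rewrites entries into pip style and skips conda-internal packages whose names start with an underscore. The function name `conda_export_to_pip` is hypothetical, and the result would still need manual pruning because many conda packages have no PyPI equivalent.

```python
def conda_export_to_pip(lines):
    """Hypothetical helper: turn conda 'name=version=build' lines into
    pip 'name==version' lines. Assumes conda export format throughout."""
    pip_lines = []
    for raw in lines:
        line = raw.strip()
        # Skip blanks, comments, and conda directives such as '@EXPLICIT'.
        if not line or line.startswith(("#", "@")):
            continue
        parts = line.split("=")
        if len(parts) < 2 or not parts[1]:
            continue  # not in conda 'name=version[=build]' form
        name, version = parts[0], parts[1]
        # Names starting with '_' are conda-internal and not on PyPI.
        if name.startswith("_"):
            continue
        pip_lines.append(f"{name}=={version}")
    return pip_lines


if __name__ == "__main__":
    # Illustrative input; only the first entry is quoted from the issue.
    sample = [
        "# platform: linux-64",
        "_libgcc_mutex=0.1=conda_forge",
        "numpy=1.19.2=py38h54aff64_0",
    ]
    print("\n".join(conda_export_to_pip(sample)))
```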
