
cgds-package's Introduction

Hi there, I'm Hongkai Zheng 👋

I'm a Ph.D. student in Computing and Mathematical Sciences at Caltech. I received my B.S. in Computer Science from Shanghai Jiao Tong University.

My personal website: https://devzhk.myportfolio.com/

My Research Interest

  • 💻 Machine learning
  • ✨ Deep generative models and sampling
  • ⚡ Operator learning

Publications

My Google Scholar: Google Scholar page

Academic Service

  • Journal of Machine Learning Research (JMLR)
  • Conference on Neural Information Processing Systems (NeurIPS) 2022

Photography

I'm also a photography hobbyist! My Instagram: https://www.instagram.com/devzhk. My portfolio: https://hzphotography.myportfolio.com/

cgds-package's People

Contributors

devzhk


cgds-package's Issues

Potential bug in conjugate_gradient and general_conjugate_gradient

Hello,

I was using ACGD for a constrained optimization problem via the Lagrange multiplier method, taking the primal parameters as the min player and the Lagrange multiplier as the max player.

Since the Lagrange multiplier has dimension 1, after the first iteration my nsteps equals 0 after line 168. This then causes an error at line 200 (if i > 100:).

I worked around it by inserting i = 0 at line 178, and the results I get make sense.

Without going too deep into the theory, I was wondering whether ACGD is the right choice here and whether my fix makes sense.

Thank you!
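The reported failure mode and the suggested fix can be sketched in plain Python. This is a minimal conjugate-gradient solver, not the package's actual code; `matvec`, `nsteps`, and the `i > 100` check are illustrative stand-ins for the structures the issue describes around lines 168, 178, and 200. With a 1-dimensional max player, `nsteps` can be 0, so the loop body never runs and `i` is never bound unless it is initialized beforehand, which is exactly the suggested `i = 0` fix:

```python
def conjugate_gradient(matvec, b, nsteps, tol=1e-12):
    """Solve A x = b by CG, where matvec(v) computes A @ v (lists of floats)."""
    x = [0.0] * len(b)
    r = list(b)              # residual b - A @ 0 = b
    p = list(r)
    rs_old = sum(ri * ri for ri in r)
    i = 0                    # the suggested fix: bind i before the loop
    for i in range(nsteps):  # if nsteps == 0, the body never runs
        Ap = matvec(p)
        alpha = rs_old / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        if rs_new < tol:
            break
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    if i > 100:              # without the i = 0 above, nsteps == 0 makes this raise NameError
        raise RuntimeError('CG failed to converge')
    return x
```

With the pre-loop initialization, `nsteps == 0` simply returns the starting point instead of raising `NameError` at the post-loop check.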

One TypeError when running acgd_test.py.

When I run tests/acgd_test.py, there is an error: TypeError: backward() got an unexpected keyword argument 'inputs'.


Have you ever encountered this problem? Can you tell me how to resolve it?
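A likely cause (my reading, not confirmed against this repo): the `inputs=` keyword of `Tensor.backward()` was only added in PyTorch 1.8.0, so older PyTorch versions raise exactly this `TypeError`. A minimal version guard, where `supports_backward_inputs` is a hypothetical helper name:

```python
def supports_backward_inputs(torch_version: str) -> bool:
    """Return True if this PyTorch version's backward() accepts `inputs=` (i.e. >= 1.8.0)."""
    # Strip local build suffixes such as '+cu118', then compare (major, minor).
    major, minor = (int(part) for part in torch_version.split('+')[0].split('.')[:2])
    return (major, minor) >= (1, 8)
```

In practice, upgrading to PyTorch >= 1.8 should make the `inputs=` call work; a guard like the one above only helps fail with a clearer message on older installs.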

Bug

Hi, I have a problem.


Could you help me? Here is my code.

This is my definition of the optimizer; I'm not sure whether there's a problem with it.

```python
class Network(nn.Module):

    def __init__(self, criterion, cfg, **kwargs):
        super(Network, self).__init__()
        self.HRNet = HighResolutionNet(cfg, **kwargs)
        # Load pretrained weights here for testing (already trained);
        # comment this line out to train from scratch.
        self.HRNet.init_weights('/mnt/mountA/dzr/segcgd/pretrain/hrnet_w48_pascal_context_cls59_480x480.pth')
        self.CoutHR = self.HRNet.last_inp_channels

        self.dishsi = DisHSI()

        self.respG = RespG()  # response function: returns the response spectrum already multiplied with the input MSI
        self.hsiG = HSIG()    # reconstruction network; respG and hsiG together form the generator
        # self.RgbG = RGBG()
        self.iter = 1

        self.extractor = Fea_extra(self.CoutHR + 31, cfg.DATASET.NUM_CLASSES)  # feature extractor; not involved in the cgd optimizer

        # Merge the generator parts to make training easier.
        self.generator = nn.Sequential(
            self.respG,
            self.hsiG,
        )

        # The cgd optimizer; it covers the whole model, which is somewhat inconvenient.
        rank = torch.distributed.get_rank()
        self.g = DDP(self.generator.cuda(), device_ids=[rank], broadcast_buffers=False)
        self.d = DDP(self.dishsi.cuda(), device_ids=[rank], broadcast_buffers=False)
        g_reducer = self.g.reducer
        d_reducer = self.d.reducer

        self.cgd_optimizer = CGDs.ACGD(max_params=self.g.parameters(),
                                       min_params=self.d.parameters(),
                                       lr_max=1e-3, lr_min=1e-3,
                                       max_reducer=g_reducer, min_reducer=d_reducer,
                                       tol=1e-4, atol=1e-8)
        # self.cgd_optimizer = CGDs.ACGD(max_params=itertools.chain(self.hsiG.parameters(), self.respG.parameters()),
        #                                min_params=self.dishsi.parameters(),
        #                                lr_max=1e-3, lr_min=1e-3,
        #                                # max_reducer=g_reducer, min_reducer=d_reducer,
        #                                tol=1e-4, atol=1e-8)

        # Separately defined optimizers:
        # self.gen_optimizer = torch.optim.Adam(itertools.chain(self.hsiG.parameters(), self.respG.parameters(), self.extractor.parameters()), lr=1e-4)
        # self.gen_optimizer = torch.optim.Adam(itertools.chain(self.hsiG.parameters(), self.respG.parameters()), lr=1e-4)  # generator is too weak
        # self.dis_optimizer = torch.optim.Adam(itertools.chain(self.dishsi.parameters(), self.dishsi_line.parameters()), lr=1e-3)
        # self.dis_optimizer = torch.optim.Adam(self.dishsi.parameters(), lr=1e-6)  # discriminator is too strong (was lr=5e-4)

        self.criterion_class = criterion
        # self.criterion = nn.BCELoss()
        self.criterion = nn.BCEWithLogitsLoss()  # changed
        self.mean = np.array([0.485, 0.456, 0.406])
        self.std = np.array([0.229, 0.224, 0.225])

        self.gradnorm = GradNorm().apply
```

This is the code for the training loop.

```python
def train(config, epoch, num_epoch, epoch_iters, base_lr, num_iters,
          trainloader, optimizer, model, writer_dict, device, Logger=None):
    # Training
    model.train()
    # model.HRNet.eval()
    # model.hsiG.eval()
    # model.respG.eval()

    batch_time = AverageMeter()
    ave_loss = AverageMeter()
    tic = time.time()
    cur_iters = epoch * epoch_iters
    writer = writer_dict['writer']
    global_steps = writer_dict['train_global_steps']
    rank = get_rank()
    world_size = get_world_size()

    # # Merge the generator parts to make training easier.
    # generator = nn.Sequential(
    #     model.module.respG,
    #     model.module.hsiG,
    # )
    # cgd_optimizer = CGDs.ACGD(max_params=generator.parameters(),
    #                           min_params=model.module.dishsi.parameters(),
    #                           lr_max=1e-3, lr_min=1e-3,
    #                           # max_reducer=g_reducer, min_reducer=d_reducer,
    #                           tol=1e-4, atol=1e-8)

    loss_g = []
    loss_d = []
    GP_List = []
    smooth_List = []
    res_List = []
    # counters
    gen_train_count = 0
    dis_train_count = 0
    for i_iter, batch in enumerate(trainloader):
        images, labels, _, _, MSI, HSI = batch  # unpack the batch
        # images, labels, _, _, MSI = batch
        model.zero_grad()                       # zero the gradients
        model.module.cgd_optimizer.zero_grad()
        # cgd_optimizer.zero_grad()

        images = images.to(device).float()  # RGB images to GPU float tensors
        MSI = MSI.to(device).float()        # multispectral images to GPU float tensors
        labels = labels.long().to(device)   # labels to GPU long tensors
        HSI = HSI.to(device).float()        # hyperspectral images to GPU float tensors

        loss, _, _ = model(images, MSI, labels)
        loss_d_, GP_loss = model.module.update_discriminator(MSI=MSI, HSI=HSI, rank=0)  # multi-GPU: .module added; update the discriminator
        loss_g_, smooth_loss, res_loss, gen_loss = model.module.update_generator(MSI=MSI, HSI=HSI, img=images, rank=0, seg_label=labels)  # update the generator
        # loss = np.mean(np.array(loss_g_))

        model.module.cgd_optimizer.step(loss_d_)  # cgd update step
        # cgd_optimizer.step(loss_d_)

        loss_d_ = loss_d_.item()
        loss_g.append(loss_g_.item() - gen_loss.item())
        loss_d.append(loss_d_)
        GP_List.append(GP_loss)
        smooth_List.append(smooth_loss)
        res_List.append(res_loss)

        optimizer.step()  # sgd optimizer update

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        # update average loss
        ave_loss.update(loss_g_.item() - gen_loss.item())

        # gen_lr = adjust_learning_rate(model.gen_optimizer, 1e-4, num_iters, 0)

        dis_lr = adjust_learning_rate(optimizer,
                                      base_lr,
                                      num_iters,
                                      i_iter + cur_iters)
```
