
deep-learning-collection's Issues

Very low training accuracies

Hello, I know this is quite an old repo, but it has been one of the most helpful ones for me in recent years when it comes to implementation issues.

I am facing an odd problem with EfficientNet. I am trying to pretrain it from scratch using four classes from CIFAR10, but the validation accuracy remains extremely low across different combinations of hyperparameters. To be honest, I have tried a couple of other popular implementations too, and the result is the same. I double-checked my data loading and preprocessing stages, and they don't seem particularly problematic:

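import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset

# batch_size was not defined in the snippet as posted; it is passed in here,
# and the default of 128 is an assumption, not the value used in the experiments.
def load_dataset(dataset_name, dataset_path, splits, batch_size=128, nworkers=4):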
    if dataset_name.upper() == 'CIFAR10':
        mean, std = [x / 255 for x in [125.3, 123.0, 113.9]],  [x / 255 for x in [63.0, 62.1, 66.7]]
        dataset = dset.CIFAR10
        num_classes = 10
   
    else:
        # Raise instead of assert so the check survives `python -O`.
        raise ValueError("Unknown dataset: {}".format(dataset_name))

    # Use dataset_path here too, so the data is not downloaded a second time into ./data.
    cifar_classes = sorted(set(dataset(root=dataset_path, train=True, download=True).targets))

    train_transform = transforms.Compose([transforms.Resize(256),
                                          transforms.CenterCrop(224),
                                          transforms.AutoAugment(torchvision.transforms.AutoAugmentPolicy.CIFAR10), 
                                          transforms.ToTensor(), transforms.Normalize(mean, std)])
    test_transform = transforms.Compose([transforms.Resize(256),
                                         transforms.CenterCrop(224), 
                                         transforms.ToTensor(), transforms.Normalize(mean, std)])

    cifar_train = dataset(dataset_path, train=True, transform=train_transform, download=True)
    cifar_test = dataset(dataset_path, train=False, transform=test_transform, download=True)

    train_loaders = []
    test_loaders = []

    for idx, split in enumerate(splits):
        classes_set = [*range(split[0], split[1], 1)] 
        print("Classes Set: ", classes_set)
        indices_train = [i for i, label in enumerate(cifar_train.targets) if label in classes_set]
        indices_test = [i for i, label in enumerate(cifar_test.targets) if label in classes_set]

        subset_train = Subset(cifar_train, indices_train)
        subset_test = Subset(cifar_test, indices_test)

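        # nworkers was accepted by the function but never passed to the loaders.
        train_loader = DataLoader(subset_train, batch_size=batch_size, shuffle=True, num_workers=nworkers)
        test_loader = DataLoader(subset_test, batch_size=batch_size, shuffle=False, num_workers=nworkers)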

        train_loaders.append(train_loader)
        test_loaders.append(test_loader)
    return num_classes, train_loaders, test_loaders
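
One thing that may be worth checking in the split logic above, sketched here in plain Python with no PyTorch dependency: `Subset` keeps the original CIFAR10 label values, so a split such as (4, 8) yields targets 4 to 7 rather than 0 to 3. That is harmless if the model has 10 outputs (as the returned `num_classes` suggests), but it would be out of range for a 4-output head. The helper name `select_and_remap` and the toy label list are hypothetical, used only for illustration; this is not a confirmed cause of the low accuracy.

```python
def select_and_remap(targets, split):
    """Return (indices, remapped_labels) for labels in [split[0], split[1])."""
    lo, hi = split
    # Same selection rule as the loader: keep samples whose label is in the range.
    indices = [i for i, label in enumerate(targets) if lo <= label < hi]
    # Shift the kept labels so the selected classes start at 0.
    remapped = [targets[i] - lo for i in indices]
    return indices, remapped


# Toy stand-in for cifar_train.targets (hypothetical values).
toy_targets = [0, 4, 7, 9, 5, 4, 6, 3]
indices, labels = select_and_remap(toy_targets, (4, 8))
print(indices)  # [1, 2, 4, 5, 6]
print(labels)   # [0, 3, 1, 0, 2]
```

If the classifier head has only as many outputs as the split has classes, the remapped labels are the ones the loss should see.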

These are the other hyperparams I am using:

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(pre_model.parameters(), lr=0.001, momentum=0.9) #TODO
exp_lr_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, eta_min=0.00001, verbose=True,  T_max = 30)
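
To make the schedule above concrete, here is a minimal sketch of the closed-form cosine annealing formula, evaluated with the same values (`lr=0.001`, `eta_min=0.00001`, `T_max=30`). The function name `cosine_annealing_lr` is hypothetical, and the sketch assumes the scheduler is stepped once per epoch and covers only epochs 0 through `T_max`. It shows that the learning rate never exceeds 0.001 and decays to 0.00001; whether that is too small for training from scratch with SGD is something to test, not a confirmed diagnosis.

```python
import math


def cosine_annealing_lr(epoch, base_lr=0.001, eta_min=0.00001, t_max=30):
    """Closed-form cosine annealing value for an epoch in [0, t_max]."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + cosine)


# Print the learning rate at the start, the midpoint, and the end of the schedule.
for epoch in (0, 15, 30):
    print(epoch, cosine_annealing_lr(epoch))
```

The midpoint value is roughly half of the starting rate, and the final value equals `eta_min`.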

My training loop is also nothing out of the ordinary.
I am just very confused about what's going on with this training scheme.
