Code Monkey home page Code Monkey logo

customimagefolder's Introduction

CustomImageFolder

pytorchで画像分類タスクを行うときに、不正解データの画像ファイルを出力したい!
ImageFolder クラスを使っていて、ファイルパスも返すように DataLoader を調整するために、カスタムの Dataset クラスを作成します.

from torchvision.datasets import ImageFolder

class CustomImageFolder(ImageFolder):
    def __getitem__(self, index):
        # 元の ImageFolder クラスの __getitem__ メソッドを呼び出す
        original_tuple = super(CustomImageFolder, self).__getitem__(index)
        # ファイルパスを取得
        path = self.imgs[index][0]
        # 画像データ、ラベル、ファイルパスを含むタプルを返す
        return (original_tuple + (path,))

そして、この CustomImageFolder クラスを使用してデータセットを作成します.

# カスタムデータセットのインスタンスを作成
image_datasets = {
    x: CustomImageFolder(os.path.join(data_dir, x), data_transforms[x])
    for x in ['train', 'val']
}

# DataLoader の作成
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                   shuffle=True, num_workers=4, pin_memory=True)
    for x in ['train', 'val']
}

これで、DataLoader から得られる各バッチは (inputs, labels, paths) の形式になります。
これを使って、train_model 関数内で不正解の画像のファイルパスを取得して出力することができます.

# 検証フェーズ
if phase == 'val':
    # ...
    for inputs, labels, paths in dataloaders['val']:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # ...
        incorrect_indices = (preds != labels).nonzero(as_tuple=True)[0]
        for idx in incorrect_indices:
            incorrect_samples.append(paths[idx])
    # ...

# 検証終了後に不正解データのファイル名の出力
print("Incorrect samples:")
for file_path in incorrect_samples:
    print(file_path)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.