
Comments (8)

madarax64 avatar madarax64 commented on May 26, 2024 6

Hello @salbatron @sabrinazuraimi ,
I've also been working on this implementation for a short while now, but specifically for the Triplet Network, not the Siamese one. To train with your own data, I'd say the easiest way to get this working is to:

a. Implement a custom Dataset class - specifically, inherit from the base Dataset class and implement the __len__ and __getitem__ methods as described in the docs. When that's done, the DataLoaders should be instantiated with your custom Dataset class.

b. You also need a BalancedBatchSampler: either implement your own or edit the base version from the repo. In the source code, the BalancedBatchSampler constructor expects the Dataset implementation to have a train_labels or test_labels property, whose contents it stores in its own labels property. Your Dataset implementation therefore has to provide some way of getting the labels of the instances therein, and you'll need to edit the BalancedBatchSampler to populate its labels property from your chosen method/property.

c. When all that's done, the already-implemented Triplet selections should work out of the box, as far as I can tell.
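Steps a and b can be sketched as follows. This is a simplified, framework-free illustration; the class and parameter names are made up for the example, and the sampler yields just one balanced batch rather than iterating over the whole dataset like the repo's version does.

```python
import numpy as np

class LabelledDataset:
    """Minimal custom Dataset sketch (step a). Illustrative only: it exposes
    a labels attribute so a sampler can read labels without loading images."""
    def __init__(self, samples, labels):
        self.samples = samples
        self.labels = np.asarray(labels)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index], self.labels[index]

class BalancedBatchSampler:
    """Simplified sketch of step b: instead of reading
    dataset.train_labels / dataset.test_labels, the constructor takes the
    labels directly. Yields a single batch of n_classes * n_samples indices
    for brevity; the real sampler keeps yielding batches."""
    def __init__(self, labels, n_classes, n_samples):
        self.labels = np.asarray(labels)
        self.labels_set = list(set(self.labels.tolist()))
        self.label_to_indices = {label: np.where(self.labels == label)[0]
                                 for label in self.labels_set}
        self.n_classes = n_classes
        self.n_samples = n_samples

    def __iter__(self):
        rng = np.random.RandomState(0)
        # pick n_classes distinct classes, then n_samples indices per class
        classes = rng.choice(self.labels_set, self.n_classes, replace=False)
        batch = []
        for cls in classes:
            batch.extend(rng.choice(self.label_to_indices[cls],
                                    self.n_samples, replace=False).tolist())
        yield batch
```

In the real repo the sampler additionally inherits from torch.utils.data.sampler.BatchSampler and is passed to the DataLoader via its batch_sampler argument.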

I hope this was a little helpful to either of you guys.

Good luck!

from siamese-triplet.

vishalgolcha avatar vishalgolcha commented on May 26, 2024 3
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CUHK(Dataset):
    def __init__(self, train_path, test_path, train):
        # Transforms
        self.to_tensor = transforms.ToTensor()
        self.transform = None
        self.train = train
        if self.train:
            # CSV has no header: column 1 holds the image path, column 2 the label
            self.train_data_info = pd.read_csv(train_path, header=None)
            self.train_data = []

            print("train data length (CUHK):", len(self.train_data_info.index))

            for path in np.asarray(self.train_data_info.iloc[:, 1]):
                try:
                    self.train_data.append(self.to_tensor(Image.open("../data/CUHK3/" + path)))
                except OSError:
                    # report unreadable images; note that skipping one leaves
                    # train_data misaligned with the labels below
                    print(path)

            self.train_data = torch.stack(self.train_data)
            self.train_labels = torch.from_numpy(np.asarray(self.train_data_info.iloc[:, 2]))
            self.train_data_len = len(self.train_data_info.index)
        else:
            self.test_data_info = pd.read_csv(test_path, header=None)
            self.test_data = []
            for path in np.asarray(self.test_data_info.iloc[:, 1]):
                try:
                    self.test_data.append(self.to_tensor(Image.open("../data/CUHK3/" + path)))
                except OSError:
                    print(path)

            self.test_data = torch.stack(self.test_data)
            self.test_labels = torch.from_numpy(np.asarray(self.test_data_info.iloc[:, 2]))
            self.test_data_len = len(self.test_data_info.index)

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]
        return img, target

    def __len__(self):
        return self.train_data_len if self.train else self.test_data_len

This class was used for training on the CUHK dataset. The only extra step is creating a CSV file that contains the paths to your image files and their labels.
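Since the constructor above reads with header=None and takes the image path from column 1 and the label from column 2, the CSV presumably looks something like the following (the paths and labels here are made up for illustration):

```python
import csv

# Illustrative rows: column 0 is an unused running index, column 1 the image
# path relative to ../data/CUHK3/, column 2 the integer identity label.
rows = [
    [0, "person_0001/img_001.jpg", 0],
    [1, "person_0001/img_002.jpg", 0],
    [2, "person_0002/img_001.jpg", 1],
]

with open("train.csv", "w", newline="") as f:
    csv.writer(f).writerows(rows)  # no header row, matching header=None
```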

class TripletCUHK(Dataset):
    """
    Train: for each sample (anchor), randomly chooses a positive and a negative sample.
    Test: creates fixed triplets for testing.
    """

    def __init__(self, cuhk_dataset):
        # adapted from the repo's TripletMNIST wrapper, hence the structure
        self.dataset = cuhk_dataset
        self.train = self.dataset.train
        self.transform = self.dataset.transform

        if self.train:
            self.train_labels = self.dataset.train_labels
            self.train_data = self.dataset.train_data
            # .tolist() so the set holds plain ints, not tensor objects
            # (a set of tensors would not deduplicate labels)
            self.labels_set = set(self.train_labels.tolist())
            self.label_to_indices = {label: np.where(np.asarray(self.train_labels) == label)[0]
                                     for label in self.labels_set}
        else:
            self.test_labels = self.dataset.test_labels
            self.test_data = self.dataset.test_data
            # generate fixed triplets for testing
            self.labels_set = set(self.test_labels.tolist())
            self.label_to_indices = {label: np.where(np.asarray(self.test_labels) == label)[0]
                                     for label in self.labels_set}

            # seeded RandomState so the test triplets really are fixed
            random_state = np.random.RandomState(29)

            triplets = []
            for i in range(len(self.test_data)):
                anchor_label = self.test_labels[i].item()
                triplets.append([i,
                                 random_state.choice(self.label_to_indices[anchor_label]),
                                 random_state.choice(self.label_to_indices[
                                     random_state.choice(list(self.labels_set - {anchor_label}))
                                 ])])
            self.test_triplets = triplets

    def __getitem__(self, index):
        if self.train:
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            positive_index = index
            while positive_index == index:
                positive_index = np.random.choice(self.label_to_indices[label1])
            negative_label = np.random.choice(list(self.labels_set - {label1}))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.train_data[positive_index]
            img3 = self.train_data[negative_index]
        else:
            img1 = self.test_data[self.test_triplets[index][0]]
            img2 = self.test_data[self.test_triplets[index][1]]
            img3 = self.test_data[self.test_triplets[index][2]]

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)

        return (img1, img2, img3), []

    def __len__(self):
        return len(self.dataset)

This is another class, used for training the triplet model; it wraps the dataset above with similar logic.
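The triplet-selection logic in __getitem__ above can be illustrated in miniature with labels only (no images or torch); the labels array below is made up for the example:

```python
import numpy as np

# toy labels: three classes with two samples each
labels = np.array([0, 0, 1, 1, 2, 2])
labels_set = set(labels.tolist())
label_to_indices = {l: np.where(labels == l)[0] for l in labels_set}

rng = np.random.RandomState(29)
anchor = 0
anchor_label = int(labels[anchor])

# positive: same label as the anchor, resampled until it is a different index
positive = anchor
while positive == anchor:
    positive = rng.choice(label_to_indices[anchor_label])

# negative: any index whose label differs from the anchor's
negative_label = rng.choice(list(labels_set - {anchor_label}))
negative = rng.choice(label_to_indices[negative_label])
```

The (anchor, positive, negative) index triple is what the wrapper turns into an image triplet.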


sabrinazuraimi avatar sabrinazuraimi commented on May 26, 2024

I'm also trying to run this with my own image folder (which has a huge number of classes), but I'm having problems with the triplet selection. Does anyone have an idea how to implement it?


vishalgolcha avatar vishalgolcha commented on May 26, 2024

I was thinking of creating a pull request with the custom dataloader I made for this repository, if that would help some people.


WillDamon avatar WillDamon commented on May 26, 2024

@vishalgolcha
Would you mind sharing your code for training this project on our own data, organized like train/1/001.jpg, train/1/002.jpg, ..., train/100/001.jpg, test/101/001.jpg, test/101/002.jpg, ..., test/110/001.jpg...
Thank you so much.


Ella77 avatar Ella77 commented on May 26, 2024

@vishalgolcha Thanks for the custom generator for one's own data :) I have some questions about it and sent you an email (hotmail). Would you mind opening it?


adambielski avatar adambielski commented on May 26, 2024

@madarax64 explained it very well - thank you for this.
I refactored BalancedBatchSampler; it should now be easier to use with different datasets (it just needs the labels, in the same order as in the dataset).

All in all, what you need (for the case with online triplet sampling) is:

  1. A Dataset class, i.e. your own implementation of the __len__ and __getitem__(index) methods; this is PyTorch-specific, and you can use some default PyTorch loaders or customize them for your case (see the post by @madarax64). See some examples of custom Datasets at https://github.com/utkuozbulak/pytorch-custom-dataset-examples
  2. A BalancedBatchSampler, which now takes a list of all labels in the dataset, where the n-th label corresponds to the n-th index in the Dataset class. An inefficient way to get the list of labels would be labels = [dataset[i][1] for i in range(len(dataset))], assuming that __getitem__(index) returns an (image, label) pair. Most likely your dataset already stores that list, and it can be accessed without processing the images.
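Point 2 can be sketched with a toy Dataset (the class and its contents are illustrative, not from the repo):

```python
class ToyDataset:
    """Minimal Dataset-style class whose __getitem__ returns (image, label)."""
    def __init__(self):
        self.data = ["imgA", "imgB", "imgC"]
        self.labels = [0, 1, 0]  # labels stored in dataset order

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

dataset = ToyDataset()

# Inefficient but general: touches every item in the dataset.
labels_slow = [dataset[i][1] for i in range(len(dataset))]

# Preferred when the dataset already stores its labels directly.
labels_fast = dataset.labels

assert labels_slow == labels_fast
```

Either list can then be passed to the refactored BalancedBatchSampler.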


josianerodrigues avatar josianerodrigues commented on May 26, 2024

Hi @vishalgolcha,

I need to run tests with CIFAR-100, but I can't get the function that generates the triplets and image pairs to work. Can you help me? I believe little needs to change, because this dataset is also in the PyTorch library.

