Comments (8)
Hello @salbatron @sabrinazuraimi,
I've also been working on this implementation for a short while now, but specifically for the Triplet Network rather than the Siamese one. To train with your own data, I'd say the easiest way to get this working would be to:
a. Implement a custom Dataset class - specifically, inherit from the base Dataset class and implement the __len__ and __getitem__ methods as described in the docs. When that's done, the DataLoaders should be instantiated using your custom Dataset class.
b. You will also need to implement a BalancedBatchSampler class, or simply edit the base version from the repo. This is because in the source code, the constructor of BalancedBatchSampler expects the Dataset implementation to have a property called train_labels or test_labels, the contents of which are stored in the labels property of the BalancedBatchSampler. Therefore your Dataset implementation will have to include some way of getting the labels of the instances therein, and you'll need to edit the BalancedBatchSampler to populate its labels property using your chosen method/property.
c. When all that's done, the already-implemented triplet selections should work out of the box, as far as I can tell.
I hope this was a little helpful to either of you guys.
Good luck!
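To make steps (a) and (b) concrete, here is a minimal sketch. All names are illustrative, and it assumes (per the comment above) that BalancedBatchSampler reads a `train_labels` attribute off the dataset:

```python
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Hypothetical minimal custom Dataset for steps (a)/(b) above."""

    def __init__(self, data, labels):
        self.data = data                # e.g. a tensor of images
        self.train_labels = labels     # exposed because BalancedBatchSampler reads it

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.train_labels[index]

# Six dummy 3x8x8 "images" across three classes:
ds = MyDataset(torch.zeros(6, 3, 8, 8), torch.tensor([0, 0, 1, 1, 2, 2]))
```

With `__len__`/`__getitem__` in place, the dataset can be handed to a DataLoader directly, and the label attribute is what the sampler edit in step (b) would consume.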
from siamese-triplet.
```python
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class CUHK(Dataset):
    def __init__(self, train_path, test_path, train):
        # Transforms
        self.to_tensor = transforms.ToTensor()
        self.transform = None
        self.train = train
        # Read the CSV file (column 1: image path, column 2: label)
        if self.train:
            self.train_data_info = pd.read_csv(train_path, header=None)
            train_data, train_labels = [], []
            for path, label in zip(self.train_data_info.iloc[:, 1],
                                   self.train_data_info.iloc[:, 2]):
                try:
                    train_data.append(self.to_tensor(Image.open("../data/CUHK3/" + path)))
                    train_labels.append(label)
                except Exception:
                    # Skip unreadable images; collecting the label only on
                    # success keeps data and labels aligned.
                    print(path)
            self.train_data = torch.stack(train_data)
            self.train_labels = torch.from_numpy(np.asarray(train_labels))
            self.train_data_len = len(self.train_data)
        else:
            self.test_data_info = pd.read_csv(test_path, header=None)
            test_data, test_labels = [], []
            for path, label in zip(self.test_data_info.iloc[:, 1],
                                   self.test_data_info.iloc[:, 2]):
                try:
                    test_data.append(self.to_tensor(Image.open("../data/CUHK3/" + path)))
                    test_labels.append(label)
                except Exception:
                    print(path)
            self.test_data = torch.stack(test_data)
            self.test_labels = torch.from_numpy(np.asarray(test_labels))
            self.test_data_len = len(self.test_data)

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]
        return img, target

    def __len__(self):
        return self.train_data_len if self.train else self.test_data_len
```
This class was used for training over the CUHK dataset; the only caveat is that you have to create a CSV file containing the paths to your image files and their labels.
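For reference, a small sketch of building such a CSV. The paths and labels below are made up; the class above expects no header row, with image paths in column 1 and labels in column 2 (column 0 can hold a row id, since only columns 1 and 2 are read):

```python
import csv
import io

# Hypothetical rows: (row id, image path relative to ../data/CUHK3/, identity label).
rows = [
    (0, "person_0001/cam1_001.png", 0),
    (1, "person_0001/cam2_001.png", 0),
    (2, "person_0002/cam1_001.png", 1),
]

buf = io.StringIO()                # stand-in for open("train.csv", "w", newline="")
csv.writer(buf).writerows(rows)    # no header row, matching header=None in read_csv
csv_text = buf.getvalue()
```

In practice you would write to a real file and pass its path as `train_path` or `test_path`.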
```python
class TripletCUHK(Dataset):
    """
    Train: for each sample (anchor), randomly chooses a positive and a negative sample.
    Test: creates fixed triplets for testing.
    """

    def __init__(self, mnist_dataset):
        # Named mnist_dataset after the repo's TripletMNIST; pass a CUHK instance here.
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.train_labels = self.mnist_dataset.train_labels
            self.train_data = self.mnist_dataset.train_data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}
        else:
            self.test_labels = self.mnist_dataset.test_labels
            self.test_data = self.mnist_dataset.test_data
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}
            # Generate fixed triplets for testing; drawing everything from the
            # seeded random_state keeps them reproducible across runs.
            random_state = np.random.RandomState(29)
            triplets = []
            for i in range(len(self.test_data)):
                anchor_label = self.test_labels[i].item()
                negative_label = random_state.choice(
                    list(self.labels_set - set([anchor_label])))
                triplets.append([i,
                                 random_state.choice(self.label_to_indices[anchor_label]),
                                 random_state.choice(self.label_to_indices[negative_label])])
            self.test_triplets = triplets

    def __getitem__(self, index):
        if self.train:
            img1, label1 = self.train_data[index], self.train_labels[index].item()
            positive_index = index
            while positive_index == index:
                positive_index = np.random.choice(self.label_to_indices[label1])
            negative_label = np.random.choice(list(self.labels_set - set([label1])))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.train_data[positive_index]
            img3 = self.train_data[negative_index]
        else:
            img1 = self.test_data[self.test_triplets[index][0]]
            img2 = self.test_data[self.test_triplets[index][1]]
            img3 = self.test_data[self.test_triplets[index][2]]
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        return (img1, img2, img3), []

    def __len__(self):
        return len(self.mnist_dataset)
```
This is another class, used for triplet training. Similar logic.
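Once batches of (anchor, positive, negative) images come out of this dataset and through an embedding network, they feed a triplet margin loss. A small self-contained sketch of the plain triplet formulation (not the repo's online-mining loss; PyTorch also ships this as `torch.nn.TripletMarginLoss`):

```python
import torch
import torch.nn.functional as F

def triplet_loss(anchor, positive, negative, margin=1.0):
    # max(0, d(a, p) - d(a, n) + margin), averaged over the batch
    d_pos = F.pairwise_distance(anchor, positive)
    d_neg = F.pairwise_distance(anchor, negative)
    return F.relu(d_pos - d_neg + margin).mean()

# Toy embeddings: positives sit far from anchors, negatives coincide with them,
# so the loss comes out to roughly d_pos + margin = sqrt(8) + 1.
a = torch.zeros(4, 8)
p = torch.ones(4, 8)
n = torch.zeros(4, 8)
loss = triplet_loss(a, p, n)
```

During training the three tensors would be the network's embeddings of img1, img2, and img3 from the `__getitem__` above.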
from siamese-triplet.
I'm also trying to run this with my own image folder (which has a huge number of classes), but I'm having problems with the triplet selections. Does anyone have any idea how to implement the triplet selection?
from siamese-triplet.
I was thinking of creating a pull request with a custom dataloader that I made with this repository, if that could help some people.
from siamese-triplet.
@vishalgolcha
Would you mind sharing your code with us, showing how to train this project with our own data, which is laid out like train/1/001.jpg, train/1/002.jpg, ..., train/100/001.jpg, test/101/001.jpg, test/101/002.jpg, ..., test/110/001.jpg...
Thank you so much.
from siamese-triplet.
@vishalgolcha Thanks for the custom generator for our own data :) I have some questions regarding this and sent an email (hotmail) to you. Would you mind opening it?
from siamese-triplet.
@madarax64 explained it very well - thank you for this.
I refactored BalancedBatchSampler; it should now be easier to use with different datasets (it just needs the labels in the same order as in the dataset).
All in all, what you need (for the case with online triplet sampling) is:
- A Dataset class, i.e. your own implementation of the __len__ and __getitem__(index) methods; this is PyTorch-specific, and you can use some default PyTorch loaders or customize them for your case (see the post by @madarax64). See some examples of custom Datasets at https://github.com/utkuozbulak/pytorch-custom-dataset-examples
- A BalancedBatchSampler, which now takes a list of all labels in the dataset, where the n-th label corresponds to the n-th index in the Dataset class. The inefficient way to get the list of labels would be labels = [dataset[i][1] for i in range(len(dataset))], assuming that __getitem__(index) returns a pair of an image and a label. Most likely your dataset already stores that list and it can be accessed without processing the images.
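Both ways of obtaining that label list can be sketched with a toy stand-in dataset (the class and its `samples` attribute are hypothetical; any object implementing `__len__`/`__getitem__` works the same way):

```python
class ToyDataset:
    """Hypothetical minimal dataset whose items are (image, label) pairs."""

    def __init__(self):
        # Strings stand in for image tensors here.
        self.samples = [("img0", 0), ("img1", 0), ("img2", 1), ("img3", 1)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]

dataset = ToyDataset()
# Inefficient but general: index every item and keep only the label.
labels = [dataset[i][1] for i in range(len(dataset))]
# Cheaper when the dataset already stores its labels internally:
labels_fast = [label for _, label in dataset.samples]
```

The first form touches every item (and would decode every image in a real dataset); the second reads the stored list directly, which is why reusing an existing label attribute is preferable.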
from siamese-triplet.
Hi @vishalgolcha,
I need to run tests with CIFAR-100, but I can't make the function that generates the triplets and image pairs work. Can you help me? I believe little needs to change, because this dataset is also available in the PyTorch library.
from siamese-triplet.