

Conditional-Similarity-Network-MNIST

This is a toy example of Conditional Similarity Networks on the MNIST dataset. It is based on the paper "Conditional Similarity Networks" by A. Veit, S. Belongie and T. Karaletsos.

Overview

In this paper, the authors propose a network called the "Conditional Similarity Network" to measure the similarity between images with respect to various attributes. The network consists of two parts. One is a convolutional network that extracts a feature vector from an image. The other is a set of masks, each corresponding to one attribute (color, shape, category and so on), that act as element-wise gating functions selecting the features relevant to that attribute from the feature vector. Importantly, the masks are also trainable: they learn by themselves which features are actually needed to distinguish images with respect to the corresponding attribute. By applying a mask to the feature vector, we can measure the conditional similarity of images.
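
The masking idea can be sketched in a few lines of plain Python. This is a minimal sketch, not the paper's code: it assumes the conditional distance is the Euclidean distance between the two feature vectors after both are gated by the same mask, and the embeddings and masks below are hypothetical hand-written values.

```python
import math

def masked_distance(f_x, f_y, mask):
    """Euclidean distance between two embeddings after element-wise gating by `mask`.

    Gating both vectors with the same mask is the same as gating their difference.
    """
    if not (len(f_x) == len(f_y) == len(mask)):
        raise ValueError("embeddings and mask must have the same length")
    return math.sqrt(sum((m * (a - b)) ** 2 for a, b, m in zip(f_x, f_y, mask)))

# Hypothetical 4-dimensional embeddings of two images.
f_a = [1.0, 0.0, 2.0, 5.0]
f_b = [1.0, 3.0, 2.0, 1.0]

# Hypothetical masks: "color" looks at dimensions 0-1, "digit" only at dimension 2.
mask_color = [1.0, 1.0, 0.0, 0.0]
mask_digit = [0.0, 0.0, 1.0, 0.0]

print(masked_distance(f_a, f_b, mask_color))  # 3.0: different under the color condition
print(masked_distance(f_a, f_b, mask_digit))  # 0.0: identical under the digit condition
```

The same pair of images can be far apart under one condition and identical under another, which is exactly what a single unmasked embedding cannot express.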

The characteristics of the network

It is based on deep metric learning, which trains a network to map similar inputs to nearby feature vectors. In other words, a deep metric network embeds data from a high-dimensional input space into a low-dimensional space while preserving the distances between the data. The most common way to implement deep metric learning is a triplet network. First, we pick an input x, called the anchor. Then we choose a positive sample x+, which is similar to x (for example, x+ is in the same category as x), and a negative sample x-, which is not similar to x. We then build three parallel networks that share the same weights and feed the triplet (x, x+, x-) into them. Next, we measure the distance d+ between the anchor output and the positive output, and the distance d- between the anchor output and the negative output. We want d+ to be small and d- to be large, so we use the hinge loss max(0, d+ - d- + margin) as the objective function.
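
The hinge loss above can be written directly. This sketch assumes Euclidean distance between the three (shared-weight) network outputs and a default margin of 1.0; the README does not state which margin or distance was used, and the embeddings are hypothetical.

```python
import math

def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))

def triplet_hinge_loss(anchor, positive, negative, margin=1.0):
    """max(0, d+ - d- + margin) for one triplet of embeddings (margin is an assumed default)."""
    d_pos = euclidean(anchor, positive)
    d_neg = euclidean(anchor, negative)
    return max(0.0, d_pos - d_neg + margin)

anchor = [0.0, 0.0]
positive = [0.0, 1.0]   # d+ = 1
negative = [3.0, 4.0]   # d- = 5

print(triplet_hinge_loss(anchor, positive, negative))              # 0.0: d- already beats d+ by more than the margin
print(triplet_hinge_loss(anchor, positive, negative, margin=5.0))  # 1.0: a larger margin still penalizes this triplet
```

Triplets that already satisfy the margin contribute zero loss and zero gradient, so only "hard enough" triplets drive training.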

According to the paper, a Conditional Similarity Network learns multiple similarities that share some features better than a standard triplet network does.

Question

The network performs better at learning multiple similarities that are correlated with each other. In that case, the masks for those similarities are active on some of the same indices. But what if the similarities are unrelated? For example, consider a digit image drawn in a colored font. It has two attributes, digit and color, and there is no relation between them. This means that, in the feature vector, some dimensions should represent the color attribute and others the digit attribute, but no index should represent both. I wanted to show experimentally that, if the network is trained on unrelated attributes, the mask of each attribute does not share indices with the other masks.

Experiment setting

  1. Data Set

I made a new dataset from MNIST. For each image, I picked random RGB values in [0, 200] (so that no digit comes out white) and applied that color to every pixel whose grayscale value is nonzero. The result is an image with a black background and a colored digit. Because my computer is very slow, I used only 5000 images from MNIST. I assigned 50% of the images to the training set, 30% to the validation set and 10% to the test set.
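
A minimal sketch of the dataset construction, using plain Python lists instead of real MNIST arrays. The helper names are hypothetical, and one detail is an assumption: the chosen color is scaled by each pixel's grayscale intensity, whereas the original may paint nonzero pixels with the flat color. The split percentages are taken as stated (they sum to 90%, so the sketch leaves the remaining images unused).

```python
import random

def colorize(gray, rng):
    """Turn a grayscale image (rows of ints 0..255) into rows of (r, g, b) tuples.

    Each channel is drawn uniformly from 0..200 so the digit is never white.
    ASSUMPTION: the color is scaled by the pixel intensity; zero pixels stay black.
    """
    color = tuple(rng.randint(0, 200) for _ in range(3))
    rgb = [[tuple(c * p // 255 for c in color) for p in row] for row in gray]
    return rgb, color

def split(items, rng, train_pct=50, val_pct=30, test_pct=10):
    """Shuffle and split by integer percentages (integer math avoids float rounding)."""
    items = list(items)
    rng.shuffle(items)
    n = len(items)
    n_train = n * train_pct // 100
    n_val = n * val_pct // 100
    n_test = n * test_pct // 100
    return (items[:n_train],
            items[n_train:n_train + n_val],
            items[n_train + n_val:n_train + n_val + n_test])

rng = random.Random(0)
gray = [[0, 255, 128], [0, 0, 255]]          # hypothetical 2x3 grayscale patch
rgb, color = colorize(gray, rng)
train_ids, val_ids, test_ids = split(range(5000), rng)
print(color, rgb[0], len(train_ids), len(val_ids), len(test_ids))
```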

  2. Network structure

I used LeNet as the encoder of the network: 2 convolutional layers followed by 2 dense layers. The output dimension of the encoder is 20. I used two masks, one for each attribute (color and digit), so the masks together form a [20, 2] matrix. The network was trained with the triplet-based deep metric learning described above.
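
The README does not give kernel sizes or channel counts, so the shape arithmetic below is only a sketch under assumed LeNet-style hyperparameters: 28x28x3 colored inputs, two 5x5 valid convolutions with 32 and 64 filters, each followed by 2x2 max-pooling, a 128-unit hidden dense layer, and the stated 20-dimensional output with two masks.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_params(kernel, in_ch, out_ch):
    """Weights plus biases of a square convolution layer."""
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# ASSUMED hyperparameters (only EMBED = 20 and N_MASKS = 2 come from the README).
H, C_IN, K = 28, 3, 5
C1, C2, HIDDEN = 32, 64, 128
EMBED, N_MASKS = 20, 2

h = conv_out(H, K)             # 24 after conv1
h = conv_out(h, 2, stride=2)   # 12 after pool1
h = conv_out(h, K)             # 8 after conv2
h = conv_out(h, 2, stride=2)   # 4 after pool2
flat = h * h * C2              # features entering the dense layers

params = {
    "conv1": conv_params(K, C_IN, C1),
    "conv2": conv_params(K, C1, C2),
    "dense1": dense_params(flat, HIDDEN),
    "dense2": dense_params(HIDDEN, EMBED),
    "masks": EMBED * N_MASKS,  # the [20, 2] mask matrix
}
print(flat, params, sum(params.values()))
```

Under these assumptions the masks add only 40 trainable parameters on top of the encoder, so they are cheap to learn.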

  3. Training

I used AdamOptimizer with a learning rate of 1e-3 and a batch size of 100: for each minibatch, I randomly picked 100 triplets from the training set. For color, I chose a positive sample whose RGB value is close to the anchor's and a negative sample whose RGB value is far from the anchor's. For digit, I chose a positive sample with the same digit as the anchor and a negative sample with a different digit. I used the hinge loss as the loss function.
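
Triplet sampling for both conditions might look like the sketch below. The function names and the toy labels and colors are hypothetical, and the color rule is an assumption: the README only says "closer" and "far", so here two distinct candidates are drawn and the one with the smaller RGB distance to the anchor becomes the positive.

```python
import random

def rgb_distance(c1, c2):
    """Euclidean distance between two RGB triples."""
    return sum((a - b) ** 2 for a, b in zip(c1, c2)) ** 0.5

def sample_digit_triplet(labels, rng):
    """(anchor, positive, negative): positive shares the anchor's digit, negative does not.

    Requires at least two distinct labels; otherwise this loops forever.
    """
    while True:
        a = rng.randrange(len(labels))
        same = [i for i, y in enumerate(labels) if i != a and y == labels[a]]
        diff = [i for i, y in enumerate(labels) if y != labels[a]]
        if same and diff:
            return a, rng.choice(same), rng.choice(diff)

def sample_color_triplet(colors, rng):
    """ASSUMPTION: of two random candidates, the one closer in RGB to the anchor is the positive."""
    a, i, j = rng.sample(range(len(colors)), 3)
    if rgb_distance(colors[a], colors[i]) <= rgb_distance(colors[a], colors[j]):
        return a, i, j
    return a, j, i

def sample_batch(labels, colors, condition, rng, batch_size=100):
    """Draw `batch_size` random triplets for the condition "digit" or "color"."""
    if condition == "digit":
        return [sample_digit_triplet(labels, rng) for _ in range(batch_size)]
    return [sample_color_triplet(colors, rng) for _ in range(batch_size)]

rng = random.Random(0)
labels = [i % 10 for i in range(50)]                                        # hypothetical digits
colors = [tuple(rng.randint(0, 200) for _ in range(3)) for _ in range(50)]  # hypothetical colors
digit_batch = sample_batch(labels, colors, "digit", rng)
color_batch = sample_batch(labels, colors, "color", rng)
print(digit_batch[:3], color_batch[:3])
```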

  4. Test

To test the model, I picked an anchor from the test set, passed all the inputs in the test set through the network, measured the distance between the anchor output and each of the other outputs, and selected the 10 inputs closest to the anchor.
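
The retrieval step can be sketched as a sort by masked distance. This assumes the distance is Euclidean on mask-gated embeddings and that the anchor itself is excluded from its own neighbors; `top_k_neighbors` and the 2-D embeddings are hypothetical.

```python
import math

def top_k_neighbors(embeddings, anchor_idx, mask, k=10):
    """Indices of the k embeddings closest to the anchor under `mask`, anchor excluded."""
    anchor = embeddings[anchor_idx]

    def dist(e):
        return math.sqrt(sum((m * (a - b)) ** 2 for a, b, m in zip(anchor, e, mask)))

    candidates = [i for i in range(len(embeddings)) if i != anchor_idx]
    # sorted() is stable, so ties keep their original order.
    return sorted(candidates, key=lambda i: dist(embeddings[i]))[:k]

# Hypothetical test-set embeddings; the anchor is index 0.
embeddings = [[0, 0], [5, 0], [1, 9], [2, 0], [0, 3]]
print(top_k_neighbors(embeddings, 0, [1, 0], k=2))  # [4, 2]: nearest along dimension 0
print(top_k_neighbors(embeddings, 0, [0, 1], k=2))  # [1, 3]: nearest along dimension 1
```

Changing only the mask changes which samples are retrieved, which is how the color and digit cases below are produced from the same embeddings.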

Result

First, I trained the network for 2000 iterations. The results are shown below: the first box shows the anchor, and the second box shows the 10 closest samples.


Case 1. color


Case 2. digit

The network seems to work well. However, something is strange in case 1: the retrieved samples tend to have the same digit as the anchor, and half of them are 8s. This suggests that the color mask and the digit mask are correlated. Indeed, the coefficients of the two masks are as follows.

row    color            digit
  1    1.3572024e+00    9.2015648e-01
  2   -3.4946132e-01    4.3405625e-01
  3    1.0715414e+00    5.0198692e-01
  4    2.5438932e-01    9.0543813e-01
  5    9.8532408e-01    9.0611883e-02
  6    1.2879860e+00   -1.5302801e-01
  7    3.0987355e-01    1.0871087e+00
  8    7.4810916e-01    1.9085248e+00
  9    1.4328172e+00    4.6207100e-01
 10    1.5345807e+00    9.2612886e-01
 11    1.0465456e+00    2.1864297e+00
 12   -8.5371196e-02    1.8422171e-01
 13    2.2996366e-03    3.0973017e-01
 14   -2.6021469e-01    9.8107708e-01
 15    9.2680907e-01    6.4489403e-05
 16    1.5980076e+00    3.8404701e-07
 17    8.7253183e-01    1.9499277e-01
 18    1.8782541e+00    1.8202764e-01
 19    4.7374862e-01    2.1356693e-06
 20    9.1653597e-01    1.3017328e+00

You can see that the color and digit masks share several rows; for example, rows 1, 11 and 20 have large coefficients in both.
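
A small helper makes the "shared rows" reading reproducible from the 2000-iteration coefficients in the table. The threshold is an arbitrary assumption, and absolute values are used because, if the conditional distance is Euclidean on the masked features, a coefficient's sign does not change the distance. At a threshold of 0.9, row 10 also qualifies alongside rows 1, 11 and 20.

```python
# (color, digit) mask coefficients after 2000 iterations, copied from the table.
MASKS_2000 = [
    (1.3572024e+00, 9.2015648e-01), (-3.4946132e-01, 4.3405625e-01),
    (1.0715414e+00, 5.0198692e-01), (2.5438932e-01, 9.0543813e-01),
    (9.8532408e-01, 9.0611883e-02), (1.2879860e+00, -1.5302801e-01),
    (3.0987355e-01, 1.0871087e+00), (7.4810916e-01, 1.9085248e+00),
    (1.4328172e+00, 4.6207100e-01), (1.5345807e+00, 9.2612886e-01),
    (1.0465456e+00, 2.1864297e+00), (-8.5371196e-02, 1.8422171e-01),
    (2.2996366e-03, 3.0973017e-01), (-2.6021469e-01, 9.8107708e-01),
    (9.2680907e-01, 6.4489403e-05), (1.5980076e+00, 3.8404701e-07),
    (8.7253183e-01, 1.9499277e-01), (1.8782541e+00, 1.8202764e-01),
    (4.7374862e-01, 2.1356693e-06), (9.1653597e-01, 1.3017328e+00),
]

def shared_rows(masks, threshold):
    """1-based rows where both |color| and |digit| exceed `threshold` (an assumed cutoff)."""
    return [i for i, (c, d) in enumerate(masks, start=1)
            if abs(c) > threshold and abs(d) > threshold]

print(shared_rows(MASKS_2000, 0.9))  # [1, 10, 11, 20]
print(shared_rows(MASKS_2000, 0.5))  # [1, 3, 8, 10, 11, 20]
```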

Does training longer help? I retrained the network for 5000 iterations. The results are as follows.


Case 1. color


Case 2. digit

In the first case, the retrieved digits appear roughly equally often, so the masks seem not to share any dimension. However, as the coefficients below show, the masks still share some rows.

row    color            digit
  1    3.1746492e-01   -3.0449569e-02
  2    9.7776592e-01    1.9820638e-01
  3    3.0899182e-01    1.2614999e+00
  4    5.0264919e-01    4.7458627e-04
  5   -2.0565987e-02    1.0122026e+00
  6    1.2532429e+00   -5.0570053e-04
  7    1.1761242e+00    7.0249259e-01
  8    4.3930151e-03    1.7211841e+00
  9    1.5434515e+00    1.0000439e-03
 10    7.5177276e-01    1.2705497e+00
 11    9.5747131e-01    1.0597352e+00
 12    2.0184386e-01    1.1644945e+00
 13   -3.3016729e-01    7.8291142e-01
 14    2.8982167e-05    1.6994121e+00
 15   -2.2581347e-12    1.5167643e+00
 16    1.5507573e-01    1.5918788e+00
 17    1.8356254e+00   -6.9651306e-02
 18    1.5352018e+00    7.4336982e-01
 19    3.6188364e-01    8.4463215e-01
 20    5.5576164e-01    7.6898432e-01
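
To compare overlap without picking a threshold, one option (swapped in here, not used in the original experiment or the paper) is the cosine similarity between the absolute values of the two mask columns. For non-negative vectors it is 0 exactly when no row is nonzero in both masks and 1 when the columns are proportional, so values in between indicate partial sharing. The sketch applies it to the 5000-iteration coefficients from the table.

```python
import math

# (color, digit) mask coefficients after 5000 iterations, copied from the table.
MASKS_5000 = [
    (3.1746492e-01, -3.0449569e-02), (9.7776592e-01, 1.9820638e-01),
    (3.0899182e-01, 1.2614999e+00), (5.0264919e-01, 4.7458627e-04),
    (-2.0565987e-02, 1.0122026e+00), (1.2532429e+00, -5.0570053e-04),
    (1.1761242e+00, 7.0249259e-01), (4.3930151e-03, 1.7211841e+00),
    (1.5434515e+00, 1.0000439e-03), (7.5177276e-01, 1.2705497e+00),
    (9.5747131e-01, 1.0597352e+00), (2.0184386e-01, 1.1644945e+00),
    (-3.3016729e-01, 7.8291142e-01), (2.8982167e-05, 1.6994121e+00),
    (-2.2581347e-12, 1.5167643e+00), (1.5507573e-01, 1.5918788e+00),
    (1.8356254e+00, -6.9651306e-02), (1.5352018e+00, 7.4336982e-01),
    (3.6188364e-01, 8.4463215e-01), (5.5576164e-01, 7.6898432e-01),
]

def mask_cosine(masks):
    """Cosine similarity between the absolute color and digit mask columns."""
    color = [abs(c) for c, _ in masks]
    digit = [abs(d) for _, d in masks]
    dot = sum(c * d for c, d in zip(color, digit))
    norm = math.sqrt(sum(c * c for c in color)) * math.sqrt(sum(d * d for d in digit))
    return dot / norm

print(round(mask_cosine(MASKS_5000), 3))
```

Running the same function on the 2000-iteration coefficients gives a direct, threshold-free way to check whether longer training reduced the overlap.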

Conclusion

I expected the active indices of the two masks not to coincide. However, even after 5000 iterations, they still overlap. The paper states that the network is good at learning multiple attributes simultaneously; however, the results of this experiment do not support that claim.
