
lstm-siamese-text-similarity's Introduction

Text Similarity Using Siamese Deep Neural Network

A Siamese neural network is a class of neural network architectures that contain two or more identical subnetworks. Identical here means that the subnetworks have the same configuration with the same parameters and weights. Parameter updates are mirrored across both subnetworks.
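To make the weight sharing concrete, here is a framework-free toy sketch, not the repository's code. The hypothetical `encode` function stands in for the BiLSTM subnetwork, and a single `shared` weight dictionary is used by both branches. Because there is only one copy of the weights, a single update is seen by both branches at once.

```python
# Toy illustration of Siamese weight sharing (hypothetical encoder, not the repo's model).

def encode(tokens, weights):
    """Toy encoder: sum the shared weight of each known token."""
    return sum(weights.get(tok, 0.0) for tok in tokens)


def siamese_distance(tokens_a, tokens_b, weights):
    """Both branches call the same encoder with the same weights object."""
    return abs(encode(tokens_a, weights) - encode(tokens_b, weights))


shared = {"physics": 1.0, "easy": 0.5, "learn": 0.25}
a = ["make", "physics", "easy"]
b = ["physics", "easy", "learn"]

before = siamese_distance(a, b, shared)  # 1.5 vs 1.75 -> 0.25
shared["learn"] = 0.0                    # one update to the single weight store...
after = siamese_distance(a, b, shared)   # ...changes both branches: 1.5 vs 1.5 -> 0.0
```

In Keras the same effect comes from creating one layer object and calling it on both inputs, which the snippet in the first issue below does with `lstm_layer`.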

This repository is a Keras-based implementation of a deep Siamese bidirectional LSTM network that captures phrase/sentence similarity using word embeddings.

The model architecture is shown below.

[Figure: model architecture diagram]

Install dependencies

pip install -r requirements.txt

Usage

Training

from model import SiameseBiLSTM
from inputHandler import word_embed_meta_data, create_test_data
from config import siamese_config
import pandas as pd

############ Data Preparation ##########

df = pd.read_csv('sample_data.csv')

sentences1 = list(df['sentences1'])
sentences2 = list(df['sentences2'])
is_similar = list(df['is_similar'])
del df

######## Word Embedding ############

tokenizer, embedding_matrix = word_embed_meta_data(sentences1 + sentences2,  siamese_config['EMBEDDING_DIM'])

embedding_meta_data = {
	'tokenizer': tokenizer,
	'embedding_matrix': embedding_matrix
}

## creating sentence pairs
sentences_pair = [(x1, x2) for x1, x2 in zip(sentences1, sentences2)]
del sentences1
del sentences2

######## Training ########

class Configuration(object):
    """Simple attribute container for the hyperparameters below."""

CONFIG = Configuration()

CONFIG.embedding_dim = siamese_config['EMBEDDING_DIM']
CONFIG.max_sequence_length = siamese_config['MAX_SEQUENCE_LENGTH']
CONFIG.number_lstm_units = siamese_config['NUMBER_LSTM']
CONFIG.rate_drop_lstm = siamese_config['RATE_DROP_LSTM']
CONFIG.number_dense_units = siamese_config['NUMBER_DENSE_UNITS']
CONFIG.activation_function = siamese_config['ACTIVATION_FUNCTION']
CONFIG.rate_drop_dense = siamese_config['RATE_DROP_DENSE']
CONFIG.validation_split_ratio = siamese_config['VALIDATION_SPLIT']

siamese = SiameseBiLSTM(CONFIG.embedding_dim, CONFIG.max_sequence_length,
                        CONFIG.number_lstm_units, CONFIG.number_dense_units,
                        CONFIG.rate_drop_lstm, CONFIG.rate_drop_dense,
                        CONFIG.activation_function, CONFIG.validation_split_ratio)

best_model_path = siamese.train_model(sentences_pair, is_similar, embedding_meta_data, model_save_directory='./')
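The training script reads `sample_data.csv` and uses three columns: `sentences1`, `sentences2` and `is_similar`. The stdlib sketch below shows that layout and how it becomes `sentences_pair` and labels. The two rows are made-up examples, and treating `is_similar` as a 0/1 integer is an assumption, consistent with a sigmoid output trained with binary cross-entropy.

```python
import csv
import io

# Hypothetical two-row file in the column layout the training script reads.
SAMPLE_CSV = """sentences1,sentences2,is_similar
What can make Physics easy to learn?,How can you make physics easy to learn?,1
How do I cook rice?,What is the capital of France?,0
"""


def load_pairs(text):
    """Return (list of sentence pairs, list of integer labels) from CSV text."""
    reader = csv.DictReader(io.StringIO(text))
    pairs, labels = [], []
    for row in reader:
        pairs.append((row["sentences1"], row["sentences2"]))
        labels.append(int(row["is_similar"]))  # assumed 0/1 label
    return pairs, labels


sentences_pair, is_similar = load_pairs(SAMPLE_CSV)
```

With pandas, as in the README, the same columns are read with `pd.read_csv('sample_data.csv')`.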

Testing

from operator import itemgetter
from keras.models import load_model

model = load_model(best_model_path)

test_sentence_pairs = [
    ('What can make Physics easy to learn?', 'How can you make physics easy to learn?'),
    ('How many times a day do a clocks hands overlap?',
     'What does it mean that every time I look at the clock the numbers are the same?'),
]

test_data_x1, test_data_x2, leaks_test = create_test_data(tokenizer,test_sentence_pairs,  siamese_config['MAX_SEQUENCE_LENGTH'])

preds = list(model.predict([test_data_x1, test_data_x2, leaks_test], verbose=1).ravel())
results = [(x, y, z) for (x, y), z in zip(test_sentence_pairs, preds)]
results.sort(key=itemgetter(2), reverse=True)
print(results)
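`create_test_data` returns a third model input, `leaks_test`, next to the two padded sequences, but the README does not define it. A common choice in Siamese duplicate-question models, and only an assumption here, is a short vector of token-overlap counts per pair: unique tokens in each sentence and the number of unique tokens they share. The hypothetical `leak_features` helper below sketches that idea on unpadded token-id sequences (padding ids would otherwise be counted).

```python
# Hypothetical "leak" features: overlap statistics computed per sentence pair.

def leak_features(seq1, seq2):
    """Unique tokens in each sequence, plus the count of unique tokens they share."""
    s1, s2 = set(seq1), set(seq2)
    return [len(s1), len(s2), len(s1 & s2)]


def batch_leaks(seqs1, seqs2):
    """Apply leak_features to aligned lists of token-id sequences."""
    return [leak_features(a, b) for a, b in zip(seqs1, seqs2)]


x1 = [[4, 7, 7, 9], [1, 2]]
x2 = [[7, 9, 11], [3]]
leaks = batch_leaks(x1, x2)  # [[3, 3, 2], [2, 1, 0]]
```

Check `inputHandler.py` for the exact features the repository computes before relying on this description.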

References:

  1. Siamese Recurrent Architectures for Learning Sentence Similarity (Mueller and Thyagarajan, AAAI 2016)
  2. Inspired by the TensorFlow implementation at https://github.com/dhwajraj/deep-siamese-text-similarity

lstm-siamese-text-similarity's People

Contributors

amanhaptik, amansrivastava17, dependabot[bot]


lstm-siamese-text-similarity's Issues

Different prediction when changing args order

I noticed this strange behaviour on my model. The results of doing:
preds = model.predict([test_data_x1, test_data_x2, leaks_test])
are different from:
preds = model.predict([test_data_x2, test_data_x1, leaks_test])
when the only thing that changes is the order of the arguments.
Now I understand that the two networks can have different weights. But is there a way to understand when it's working properly and when it's not?

I see the same behaviour when using another embedding (Sentence-BERT); for this I slightly modified the network as follows:

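    # imports assumed by this snippet (tf.keras 2.x)
    import tensorflow as tf
    from tensorflow.keras.layers import (LSTM, BatchNormalization, Bidirectional,
                                         Dense, Dropout, concatenate)
    from tensorflow.keras.models import Model

    # define inputs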
    input1 = tf.keras.Input(shape=shape)
    input2 = tf.keras.Input(shape=shape)

    # Creating LSTM Encoder
    lstm_layer = Bidirectional(LSTM(number_lstm_units, 
                                    dropout=rate_drop_lstm, 
                                    recurrent_dropout=rate_drop_lstm))

    x1 = lstm_layer(input1)
    x2 = lstm_layer(input2)
    merged = concatenate([x1, x2])
    merged = BatchNormalization()(merged)
    merged = Dropout(rate_drop_dense)(merged)

    merged = Dense(number_dense_units, activation=activation_function)(merged)
    merged = BatchNormalization()(merged)
    merged = Dropout(rate_drop_dense)(merged)
    preds = Dense(1, activation='sigmoid')(merged)

    model = Model(inputs=[input1, input2], outputs=preds)
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc']) #nadam

Do you have some suggestions on possible strategies to solve this?
Thanks a lot
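In this snippet the BiLSTM encoder is shared, but `concatenate([x1, x2])` places the two encodings in fixed positions before the dense layers, so swapping the inputs changes what those layers see. One common remedy, a swapped-in technique rather than something this repository does, is to merge with order-invariant operations such as |x1 - x2| and x1 * x2 (in Keras, for example, `Subtract`, `Multiply` and an absolute-value `Lambda` layer). The plain-Python toy below, with hypothetical numbers and a hypothetical linear `toy_head` standing in for the dense layers, shows the difference.

```python
# Order-sensitive vs order-invariant merging of two sentence encodings (toy values).

def concat_merge(h1, h2):
    """Order-sensitive merge, like concatenate([x1, x2])."""
    return h1 + h2


def symmetric_merge(h1, h2):
    """Order-invariant merge: element-wise |h1 - h2| followed by h1 * h2."""
    return [abs(a - b) for a, b in zip(h1, h2)] + [a * b for a, b in zip(h1, h2)]


def toy_head(features, weights):
    """Stand-in for the dense layers: a fixed linear score."""
    return sum(f * w for f, w in zip(features, weights))


h1, h2 = [1.0, 2.0], [3.0, 0.5]
w = [0.1, 0.2, 0.3, 0.4]

concat_ab = toy_head(concat_merge(h1, h2), w)
concat_ba = toy_head(concat_merge(h2, h1), w)   # differs from concat_ab
sym_ab = toy_head(symmetric_merge(h1, h2), w)
sym_ba = toy_head(symmetric_merge(h2, h1), w)   # identical to sym_ab
```

The symmetric merge makes the score independent of argument order by construction, at the cost of no longer letting the dense layers treat the two sentences asymmetrically.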

What are the "leaks features"?

I was trying to understand your code and I couldn't figure out what the "leaks features" are. I've read some papers about Siamese text similarity, and now I'm trying to implement a character-level or phrase-based text similarity model.

Why did the training stop early?

Epoch 8/200

64/450 [===>..........................] - ETA: 0s - loss: 0.8007 - acc: 0.5781
192/450 [===========>..................] - ETA: 0s - loss: 0.7411 - acc: 0.5781
320/450 [====================>.........] - ETA: 0s - loss: 0.7564 - acc: 0.5656
448/450 [============================>.] - ETA: 0s - loss: 0.7578 - acc: 0.5491
450/450 [==============================] - 0s 813us/step - loss: 0.7569 - acc: 0.5489 - val_loss: 0.8013 - val_acc: 0.4490

it stopped!
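The README does not show the training callbacks, but stopping long before epoch 200 is consistent with an early-stopping callback such as Keras's `EarlyStopping(monitor='val_loss', patience=...)`, which ends training once validation loss has not improved for `patience` consecutive epochs. The hypothetical `early_stop_epoch` function below sketches that rule in plain Python; the loss values and patience are made up, not taken from this repository.

```python
# Patience-based early stopping on validation loss (mode='min', min_delta=0).

def early_stop_epoch(val_losses, patience):
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:          # improvement: remember it and reset the counter
            best = loss
            wait = 0
        else:                    # no improvement: count towards patience
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses: best at epoch 4, then three epochs without improvement.
stop = early_stop_epoch([0.78, 0.70, 0.63, 0.56, 0.60, 0.61, 0.59, 0.55], patience=3)
```

Here training stops at epoch 7, so the later value 0.55 is never reached; a larger `patience` trades training time for tolerance to noisy validation loss.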

TypeError: list indices must be integers or slices, not str

Traceback (most recent call last):
File "lstm-siamese-text-similarity/controller.py", line 64, in <module>
best_model_path = siamese.train_model(sentences_pair, is_similar, embedding_meta_data)
File "lstm-siamese-text-similarity/model.py", line 49, in train_model
tokenizer, embedding_matrix = embedding_meta_data['tokenizer'], embedding_meta_data['embedding_matrix']
TypeError: list indices must be integers or slices, not str

I get this error and am unable to debug it.
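The traceback says `embedding_meta_data` reached `train_model` as a list, while the lookup on `model.py` line 49 indexes it by the string keys `'tokenizer'` and `'embedding_matrix'`, as in the dictionary the README's Training section builds. The sketch below reproduces the error and the fix; the hypothetical `unpack_meta` mirrors that lookup, and the placeholder values stand in for the real tokenizer and embedding matrix.

```python
# Reproducing "list indices must be integers or slices, not str" and fixing it.

def unpack_meta(embedding_meta_data):
    """Mimic the lookup in train_model: string keys on a mapping."""
    return embedding_meta_data["tokenizer"], embedding_meta_data["embedding_matrix"]


tokenizer, embedding_matrix = "TOKENIZER", [[0.0, 0.1]]  # placeholders

# Fails: a list only accepts integer or slice indices.
message = None
try:
    unpack_meta([tokenizer, embedding_matrix])
except TypeError as err:
    message = str(err)

# Works: pass a dict with the two expected keys.
ok = unpack_meta({"tokenizer": tokenizer, "embedding_matrix": embedding_matrix})
```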

What are the "Leaks features" and how is the accuracy measured?

Hi, first thanks for your repo! I have two questions:

1- I understand that "leaks features" are something like the common characters between the sentences in each pair? I'm trying to train the model with vector similarity directly instead of sentence similarity; do you think I can remove the leaks part from the model?

2- How is the accuracy measured? I guess the sigmoid outputs a number between 0 and 1, so is accuracy measured by counting a prediction as 1 if the value is > 0.5 and 0 otherwise?

Thanks!
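Assuming the model is compiled with `loss='binary_crossentropy'` and `metrics=['acc']` (as in the snippet in the first issue), Keras resolves `'acc'` to binary accuracy, which thresholds the sigmoid output at 0.5 and compares the result with the 0/1 label. The hypothetical `binary_accuracy` function below is a plain-Python sketch of that computation, with made-up labels and predictions.

```python
# Binary accuracy: threshold sigmoid outputs, then compare with 0/1 labels.

def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of pairs whose thresholded prediction matches the label.
    A pair is predicted 'similar' only when its probability is strictly above threshold."""
    hits = sum(int(p > threshold) == int(t) for t, p in zip(y_true, y_prob))
    return hits / len(y_true)


# Two hits (0.9 -> 1, 0.2 -> 0) and two misses (0.4 -> 0, 0.6 -> 1).
acc = binary_accuracy([1, 0, 1, 0], [0.9, 0.2, 0.4, 0.6])
```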

NameError: name 'activation_function' is not defined

I get the error "name 'activation_function' is not defined" while trying to run the code below:

siamese = SiameseBiLSTM(CONFIG.embedding_dim , CONFIG.max_sequence_length, CONFIG.number_lstm_units , CONFIG.number_dense_units, CONFIG.rate_drop_lstm, CONFIG.rate_drop_dense, CONFIG.activation_function, CONFIG.validation_split_ratio)

best_model_path = siamese.train_model(sentences_pair, is_similar, embedding_meta_data, model_save_directory='./')

Why is the accuracy of the training set lower than the validation set?

Hi, I used sample_data.csv to train the model, but I get lower accuracy on the training data than on the validation set. I am confused.

Epoch 1/200
450/450 [==============================] - 5s 11ms/step - loss: 0.8987 - acc: 0.4933 - val_loss: 0.7807 - val_acc: 0.4286
Epoch 2/200
450/450 [==============================] - 1s 2ms/step - loss: 0.7921 - acc: 0.5356 - val_loss: 0.6995 - val_acc: 0.5306
Epoch 3/200
450/450 [==============================] - 1s 2ms/step - loss: 0.7451 - acc: 0.5644 - val_loss: 0.6261 - val_acc: 0.5918
Epoch 4/200
450/450 [==============================] - 1s 2ms/step - loss: 0.6697 - acc: 0.6178 - val_loss: 0.5605 - val_acc: 0.7143
Epoch 5/200
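Two usual explanations, neither confirmed for this repository: dropout (configured here via `RATE_DROP_LSTM` and `RATE_DROP_DENSE`) is active while training metrics are collected but disabled during validation, and Keras reports training metrics as an average over the epoch's batches while the weights are still improving, whereas validation metrics are computed once with the end-of-epoch weights. The second effect can be sketched in plain Python; the per-batch accuracies and batch sizes below are hypothetical.

```python
# Why an epoch-averaged training metric can trail the end-of-epoch model's quality.

def reported_epoch_metric(batch_values, batch_sizes):
    """Size-weighted mean of per-batch values, i.e. the running epoch average."""
    total = sum(v * n for v, n in zip(batch_values, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical per-batch training accuracies that improve during one epoch.
batch_acc = [0.40, 0.50, 0.60, 0.70]
batch_sizes = [64, 64, 64, 64]

train_acc_reported = reported_epoch_metric(batch_acc, batch_sizes)  # about 0.55
last_batch_acc = batch_acc[-1]                                      # 0.70
```

With a small validation split, validation accuracy is also noisy from epoch to epoch, so a gap in either direction over a few epochs is not unusual.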
