Code Monkey home page Code Monkey logo

pytorch-metric-learning-template's Introduction

pytorch-metric-learning-template

基于pytorch-metric-learning开源工具,实现了包括模型训练、模型验证、模型推理的相关代码。


模型训练

  • 使用cifar10数据集训练模型
python train_embedding_model_cifar10.py
  • 使用cifar100数据集训练模型
python train_embedding_model_cifar100.py
  • 使用flower花朵数据集训练模型,下载数据并解压后放在datasets目录下。flower花朵数据集见(链接:https://pan.baidu.com/s/1TfzLYZrkfwLy8wShy7nyMA 提取码:gxei)
    • 修改./config/embedding.yaml配置文件里面的train_dataset_dir指向数据集的位置
python train_embedding_model.py
  • 使用pytorch-metric-learning提供的API训练模型
python trainer_model.py
  • results目录下提供了几个训练好的模型文件
    • model_cifar10_SupervisedContrastiveLoss.pth 使用SupervisedContrastiveLoss在cifar10上训练的模型
    • model_cifar100_CircleLoss.pth 使用CircleLoss在cifar100上训练的模型
    • model_flower_photos_SupervisedContrastiveLoss.pth 使用SupervisedContrastiveLoss在flower花朵数据集上训练的模型

备注:上面这些模型使用的损失函数可以通过模型的名字得到,训练过程中做了embedding归一化,使用余弦相似度计算特征之间的距离

模型推理

使用训练好的模型,以及pytorch-metric-learning工具提供的接口进行模型推理。

python model_inference.py

Embedding特征提取

使用训练好的模型,将读入的数据转化为embedding特征。

python feature_extraction.py

Embedding特征可视化

使用训练好的模型,将读入的数据转化为embedding特征,并对embedding降维后可视化。

python visualizer.py

自定义训练数据

将数据按照类别ID存放在不同的目录中,具体格式可以参考flower花朵数据集那样。

模型效果展示

  • cifar10
embedding特征之间的相似度可视化 embedding特征降维之后可视化
  • cifar100
embedding特征之间的相似度可视化 embedding特征降维之后可视化
  • 花朵数据集(共5类)
embedding特征之间的相似度可视化 embedding特征降维之后可视化


度量学习相关的损失函数介绍:


基于度量学习方法实现音乐特征匹配的系列文章

pytorch-metric-learning-template's People

Contributors

xxcheng0708 avatar

Stargazers

 avatar  avatar  avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.