Code Monkey home page Code Monkey logo

bert4classification-ru's Introduction

[RU|EN]

Open In Colab

BERT для задачи классификации

Инструкция для быстрого старта
Бинарная классификация на текстовых данных из RuTweetCorp (https://study.mokoron.com/)
отрицательный: 0
положительный: 1

Зачем?

Когда я осваиваю какой-либо новый подход, обычно находятся статьии туториалы, в которых все чрезвычайно подробно, от первичной обработки данных, до построения кривых обучения. Мне же всегда хотелось быстро понять суть подходаы и сразу начать использовать имеющиеся наработки, а не мучительно разбираться с простыней чужого кода. Поэтому я решил сделать по возможности максимально простое и прозрачное решение, которое не будет перегружено лишним кодом, в котором можно легко и быстро разобраться.
Про BERT я писать ничего не буду - про него полно отличных статей, так что просто будем использовать его в качестве черного ящика.

Структура

Данные для обучения

Используются очищенные данные русскоязычного твиттера длинее 100 символов.
RuTweetCorp (https://study.mokoron.com/)

CustomDataset

Класс CustomDataset необходим для использования с библиотекой transformers. Наследуется от класса Dataset. В нем определяются 3 обязательные функции: init, len, getitem. основное предназначение - возвращает токенизированные данные в нужном формате.

Initialize

При инициализации классификатора выполняются следующие действия:

  • Скачиваются модель и токенизатор из репозитория huggingface;
  • Определяется наличие целевого устройства для вычислений;
  • Определяется размерность ембеддингов;
  • Задается количество классов;
  • Задается количество эпох для обучения.

Preparation

Для обучения BERT нужно инициализировать несколько вспомогательных элементов:

  • DataLoader: нужен для создания батчей;
  • Optimizer: оптимизатор градиентного спуска;
  • Scheduler: планировщик, нужен для настройки параметров оптимизатора;
  • Loss: функция потерь, считаем по ней ошибку модели.

Train

  • Обучение для одной эпохи описано в методе fit.
    • Данные в цикле батчами генерируются с помощью DataLoader;
    • Батч подается в модель;
    • На выходе получаем распределение вероятности по классам и значение ошибки;
    • Делаем шаг на всех вспомогательных функциях:
      • loss.backward: обратное распространение ошибки;
      • clip_grad_norm: обрезаем градиенты для предотвращения "взрыва" градиентов;
      • optimizer.step: шаг оптимизатора;
      • scheduler.step: шаг планировщика;
      • optimizer.zero_grad: обнуляем градиенты.
  • Проверку на валидационной выборке проводим с помощью метода eval. При этом используем метод torch.no_grad для предотвращения обучения на валидационной выборке.
  • Для обучения на нескольких эпохах используется метод train, в котором последовательно вызываются методы fit и eval.

Inference

Для предсказания класса для нового текста используется метод predict, который имеет смысл вызывать только после обучения модели.
Метод работает следующим образом:

  • Токенизируется входной текст;
  • Токенизированный текст подается в модель;
  • На выходе получаем вероятности классов;
  • Возвращаем метку наиболее вероятного класса.

Заключение

Хотелось максимально просто, но все равно получилось как-то объемно. Прошу понять и простить. Пис!

bert4classification-ru's People

Contributors

shitkov avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.