
tasknet's Introduction

tasknet: simple multi-task Trainer and HuggingFace datasets.

tasknet is an interface between HuggingFace datasets and the HuggingFace Transformers Trainer.

tasknet should work with all recent versions of Transformers.

Task templates

tasknet relies on task templates to avoid boilerplate code. The task templates correspond to Transformers AutoClasses:

  • SequenceClassification
  • TokenClassification
  • MultipleChoice
  • Seq2SeqLM (experimental support)

The task templates share the same interface: each one implements preprocess_function, a data collator and compute_metrics. To implement a custom task template, look at tasks.py and use the existing templates as a starting point.
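To make the three-part interface concrete, here is a minimal, self-contained sketch of what a template could look like. It does not import tasknet. `ToyClassification`, its whitespace tokenizer and its padding collator are hypothetical stand-ins, so the real base classes and signatures in tasks.py may differ.

```python
from dataclasses import dataclass, field
from typing import Dict, List

PAD_ID, SEP_ID = 0, 1  # reserved ids; real words start at 2


def whitespace_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    # Toy tokenizer: gives each unseen lowercase word the next free id.
    return [vocab.setdefault(word, len(vocab) + 2) for word in text.lower().split()]


@dataclass
class ToyClassification:
    """Hypothetical template; the field names mirror tasknet's s1/s2/y mapping."""
    s1: str = "sentence1"
    s2: str = "sentence2"
    y: str = "label"
    vocab: Dict[str, int] = field(default_factory=dict)

    def preprocess_function(self, example: Dict) -> Dict:
        # Read the mapped dataset columns and encode the (optional) text pair.
        ids = whitespace_tokenize(example[self.s1], self.vocab)
        if self.s2 in example:
            ids += [SEP_ID] + whitespace_tokenize(example[self.s2], self.vocab)
        return {"input_ids": ids, "labels": example[self.y]}

    def data_collator(self, features: List[Dict]) -> Dict:
        # Right-pad every sequence in the batch to the longest one.
        width = max(len(f["input_ids"]) for f in features)
        return {
            "input_ids": [f["input_ids"] + [PAD_ID] * (width - len(f["input_ids"])) for f in features],
            "labels": [f["labels"] for f in features],
        }

    def compute_metrics(self, predictions: List[int], references: List[int]) -> Dict[str, float]:
        # Accuracy over already-decoded label ids.
        correct = sum(p == r for p, r in zip(predictions, references))
        return {"accuracy": correct / len(references)}
```

With the HuggingFace Trainer, compute_metrics would receive an EvalPrediction of logits and label ids rather than two plain lists. The collator would also return tensors rather than nested lists.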

Installation and example

pip install tasknet

Each task template has fields that must be matched to specific dataset columns. For example, Classification has two text fields, s1 and s2, and a label, y. To instantiate a task, pass a dataset to a template and map the template fields to the dataset columns.

import tasknet as tn
from datasets import load_dataset

rte = tn.Classification(
    dataset=load_dataset("glue", "rte"),
    s1="sentence1", s2="sentence2", y="label")  # s2 is optional for classification; it holds the second text of a pair
# See AutoTask for shorter code

class hparams:
    # model_name = 'microsoft/deberta-v3-base'  # DeBERTa models give the best results (and tasknet supports them)
    model_name = 'sileod/deberta-v3-base-tasksource-nli'  # better performance for most tasks
    learning_rate = 3e-5  # see hf.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments

model, trainer = tn.Model_Trainer([rte], hparams)
trainer.train()
trainer.evaluate()
p = trainer.pipeline()
p([{'text': 'premise here', 'text_pair': 'hypothesis here'}])  # HuggingFace pipeline for inference

tasknet is multitask by design: model.task_models_list contains one model per task, and all of these models share a single encoder.
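With a shared encoder, an update to the encoder (for example, a gradient step on one task) affects every task's model, while each task keeps its own head. The toy sketch below illustrates this hard parameter sharing with plain Python objects. `SharedEncoder` and `TaskModel` are hypothetical names, and `task_models_list` only mirrors the attribute name. The classes stand in for the Transformers models that tasknet actually builds.

```python
from typing import List


class SharedEncoder:
    """Toy encoder: a single scale weight applied to a feature vector."""
    def __init__(self) -> None:
        self.scale = 1.0

    def __call__(self, features: List[float]) -> List[float]:
        return [self.scale * x for x in features]


class TaskModel:
    """One model per task: a reference to the shared encoder plus its own linear head."""
    def __init__(self, encoder: SharedEncoder, head_weights: List[float]) -> None:
        self.encoder = encoder            # the same object for every task
        self.head_weights = head_weights  # owned by this task only

    def __call__(self, features: List[float]) -> float:
        hidden = self.encoder(features)
        return sum(w * h for w, h in zip(self.head_weights, hidden))


encoder = SharedEncoder()
task_models_list: List[TaskModel] = [
    TaskModel(encoder, [1.0, 0.0]),  # head for a first task
    TaskModel(encoder, [0.0, 2.0]),  # head for a second task
]
```

Changing `encoder.scale` changes the outputs of both task models, whereas editing one task's `head_weights` leaves the other task untouched.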

AutoTask

You can also use tasksource through tn.AutoTask for one-line access to 600+ datasets (see the implemented tasks).

rte = tn.AutoTask("glue/rte", nrows=5000)

AutoTask guesses a template based on the dataset structure. It also accepts a dataset as input, if it fits the template (e.g. after tasksource custom preprocessing).

Balancing dataset sizes

tn.Classification(dataset, nrows=5000, nrows_eval=500, oversampling=2)

You can balance multiple datasets with nrows and oversampling. nrows is the maximum number of examples per dataset. If a dataset has fewer than nrows examples, it is oversampled at most oversampling times.
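Here is a worked version of this rule, with nrows=5000 and oversampling=2:

- A 20,000-example dataset is cut to 5,000.
- A 3,000-example dataset is oversampled to 5,000.
- A 1,000-example dataset only reaches 2,000.

The sketch below computes the resulting example indices. `balanced_indices` is a hypothetical helper, not tasknet's code. It assumes that a small dataset is repeated whole and then truncated to nrows; tasknet may shuffle or sample differently.

```python
import math
from typing import List


def balanced_indices(n_examples: int, nrows: int, oversampling: int) -> List[int]:
    """Cap a dataset at nrows and oversample small ones at most `oversampling` times.

    Assumption: repetition is done with whole copies, then truncated to the target size.
    """
    if n_examples == 0:
        return []
    target = min(nrows, n_examples * oversampling)
    copies = math.ceil(target / n_examples)
    return (list(range(n_examples)) * copies)[:target]
```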

Colab examples

Minimal-ish example:

https://colab.research.google.com/drive/15Xf4Bgs3itUmok7XlAK6EEquNbvjD9BD?usp=sharing

More complex example, where tasknet was scaled to 600 tasks:

https://colab.research.google.com/drive/1iB4Oxl9_B5W3ZDzXoWJN-olUbqLBxgQS?usp=sharing

tasknet vs jiant

jiant is another library comparable to tasknet. tasknet is a minimal extension of Trainer centered on task templates, while jiant builds its own Trainer equivalent, called runner, from scratch. tasknet is leaner and closer to native HuggingFace tools. jiant is config-based and command-line focused, while tasknet is designed for interactive use and Python scripting.

Credit

This code uses some parts of the examples from the transformers library and some code from multitask-learning-transformers.

Contact

You can request features on GitHub or reach me at [email protected]

@misc{sileod22-tasknet,
  author = {Sileo, Damien},
  doi = {10.5281/zenodo.561225781},
  month = {11},
  title = {{tasknet, multitask interface between Trainer and datasets}},
  url = {https://github.com/sileod/tasknet},
  version = {1.5.0},
  year = {2022}}

tasknet's People

Contributors

sileod

tasknet's Issues

Unable to load saved model

Hello, sorry, I'm quite new to writing issues.
I trained a joint token classification and sequence classification model. To save it, I used this:
trainer.save_model("multi_task/")
However, when I try to load the same model, I get this error:
model_2 = tn.load_pipeline("/kaggle/working/multi_task","intent_classification", adapt_task_embedding=True)

Traceback (most recent call last):
  in <module>:1
  ❱ 1 model_2 =  tn.load_pipeline("/kaggle/working/multi_task","intent_classification", adapt_
    2

  /opt/conda/lib/python3.10/site-packages/tasknet/utils.py:198 in load_pipeline
    195         import tasksource
    196     except:
    197         raise ImportError("Requires tasksource.\n pip install tasksource")
  ❱ 198     task = tasksource.load_task(task_name, multilingual=multilingual)
    199
    200     model = AutoModelForSequenceClassification.from_pretrained(
    201         model_name, ignore_mismatched_sizes=True

  /opt/conda/lib/python3.10/site-packages/tasksource/access.py:102 in load_task
     99     query = dict_of(id, dataset_name, config_name, task_name,preprocessing_name)
    100     query = {k:v for k,v in query.items() if v}
    101     _tasks = (lmtasks if multilingual else tasks)
  ❱ 102     preprocessing = load_preprocessing(_tasks, **query)
    103     dataset = load_dataset(preprocessing.dataset_name, preprocessing.config_name, **load
    104     dataset= preprocessing(dataset,max_rows, max_rows_eval)
    105     dataset.task_type = preprocessing.__class__.__name__

  /opt/conda/lib/python3.10/site-packages/tasksource/access.py:90 in load_preprocessing
     87
     88 def load_preprocessing(tasks=tasks, **kwargs):
     89     _tasks_df = list_tasks(multilingual=tasks==lmtasks)
  ❱  90     y = _tasks_df.copy().query(dict_to_query(**kwargs)).iloc[0]
     91     preprocessing= copy.copy(getattr(tasks, y.preprocessing_name))
     92     for c in 'dataset_name','config_name':
     93         if not isinstance(getattr(preprocessing,c), str):

  /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1073 in __getitem__
    1070             axis = self.axis or 0
    1071
    1072             maybe_callable = com.apply_if_callable(key, self.obj)
  ❱ 1073             return self._getitem_axis(maybe_callable, axis=axis)
    1074
    1075     def _is_scalar_access(self, key: tuple):
    1076         raise NotImplementedError()

  /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1625 in _getitem_axis
    1622                 raise TypeError("Cannot index by location index with a non-integer key")
    1623
    1624             # validate the location
  ❱ 1625             self._validate_integer(key, axis)
    1626
    1627             return self.obj._ixs(key, axis=axis)
    1628

  /opt/conda/lib/python3.10/site-packages/pandas/core/indexing.py:1557 in _validate_integer
    1554         """
    1555         len_axis = len(self.obj._get_axis(axis))
    1556         if key >= len_axis or key < -len_axis:
  ❱ 1557             raise IndexError("single positional indexer is out-of-bounds")
    1558
    1559     # -------------------------------------------------------------------
    1560
IndexError: single positional indexer is out-of-bounds

I intend to load the model from this checkpoint for domain adaptation.
How do I address this issue? Also, if you could point me to some resources for understanding adapters, that would be very helpful (I found out about them from another issue posted here).
PS: Thank you for this incredible library; it saved me a lot of time.
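The traceback points to the mechanism. load_pipeline passes the task name to tasksource.load_task, which filters the tasksource task table and takes `.iloc[0]` of the result. "intent_classification" appears to be a custom task name rather than one listed in tasksource, so the query seems to match no rows and the first-row lookup raises IndexError.

The stdlib sketch below reproduces that lookup pattern. It uses a hypothetical registry (`TASK_REGISTRY`, `find_task`, made-up task names) and adds a guard that reports the missing name instead of failing with an opaque IndexError.

```python
from typing import Dict, List

# Hypothetical stand-in for the tasksource task table.
TASK_REGISTRY: List[Dict[str, str]] = [
    {"task_name": "glue/rte", "preprocessing_name": "rte"},
    {"task_name": "glue/mnli", "preprocessing_name": "mnli"},
]


def find_task(task_name: str) -> Dict[str, str]:
    matches = [row for row in TASK_REGISTRY if row["task_name"] == task_name]
    # matches[0] on an empty list would raise IndexError, like .iloc[0] on an empty DataFrame.
    if not matches:
        known = [row["task_name"] for row in TASK_REGISTRY]
        raise KeyError(f"{task_name!r} is not a registered task; known tasks: {known}")
    return matches[0]
```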

Why expect Z in Adapter?

The Adapter class expects Z in its constructor:

import torch
import transformers

class Adapter(transformers.PreTrainedModel):
    config_class = transformers.PretrainedConfig
    def __init__(self, config, classifiers=None, Z=None, labels_list=[]):
        super().__init__(config)
        self.Z = torch.nn.Embedding(len(config.classifiers_size), config.hidden_size, max_norm=1.0).weight if Z is None else Z
        self.classifiers = torch.nn.ModuleList(
            [torch.nn.Linear(config.hidden_size, size) for size in config.classifiers_size]
        ) if classifiers is None else classifiers
        self.config = self.config.from_dict(
            {**self.config.to_dict(),
            'labels_list': labels_list}
        )
    def adapt_model_to_task(self, model, task_name):
        task_index = self.config.tasks.index(task_name)
        #setattr(model,search_module(model,'linear',mode='class')[-1], self.classifiers[task_index])
        model.classifier = self.classifiers[task_index]
        return model
    def _init_weights(*args):
        pass

However, it doesn't use Z at all when adapting the model to a task. Why is Z needed?

How to save and load a tasknet model?

Hi! I tried the basic 3-task example from the README file, and the training worked fine. Then I tried to save and load the model:

Saving the model worked ok:

trainer.save_model("tasknet-model")

But loading the model gives an error:

loaded = tn.Model.from_pretrained('./tasknet-model')
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[16], line 1
----> 1 loaded = tn.Model.from_pretrained('./tasknet-model')

File ~/projects/keha/Tekoaly/trials/skillrecommendation-language-model/venv/lib/python3.10/site-packages/transformers/modeling_utils.py:2175, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
   2173 if not isinstance(config, PretrainedConfig):
   2174     config_path = config if config is not None else pretrained_model_name_or_path
-> 2175     config, model_kwargs = cls.config_class.from_pretrained(
   2176         config_path,
   2177         cache_dir=cache_dir,
   2178         return_unused_kwargs=True,
   2179         force_download=force_download,
   2180         resume_download=resume_download,
   2181         proxies=proxies,
   2182         local_files_only=local_files_only,
   2183         use_auth_token=use_auth_token,
   2184         revision=revision,
   2185         subfolder=subfolder,
   2186         _from_auto=from_auto_class,
   2187         _from_pipeline=from_pipeline,
   2188         **kwargs,
   2189     )
   2190 else:
   2191     model_kwargs = kwargs

AttributeError: 'NoneType' object has no attribute 'from_pretrained'

What is the correct way to save and load the model?

Python 3.11 dataclass compatibility

Hi,
With Python 3.11, the @dataclass decorator throws an error. Here is the traceback:

    import tasknet as tn
/venv/lib/python3.11/site-packages/tasknet/__init__.py:1: in <module>
    from .models import *
/venv/lib/python3.11/site-packages/tasknet/models.py:23: in <module>
    from .tasks import Classification
/venv/lib/python3.11/site-packages/tasknet/tasks.py:198: in <module>
    @dataclass
/usr/local/lib/python3.11/dataclasses.py:1230: in dataclass
    return wrap(cls)
/usr/local/lib/python3.11/dataclasses.py:1220: in wrap
    return _process_class(cls, init, repr, eq, order, unsafe_hash,
/usr/local/lib/python3.11/dataclasses.py:958: in _process_class
    cls_fields.append(_get_field(cls, name, type, kw_only))
/usr/local/lib/python3.11/dataclasses.py:815: in _get_field
    raise ValueError(f'mutable default {type(f.default)} for field '
E   ValueError: mutable default <class 'tasknet.tasks.DataCollatorForMultipleChoice'> for field data_collator is not allowed: use default_factory

It looks like field(default_factory=...) would be needed here.
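Since Python 3.11, dataclasses reject any unhashable default value, not only list, dict and set. A dataclass that keeps the default eq=True without being frozen gets `__hash__ = None`. As a result, using an instance of such a collator class as a field default triggers this error.

Below is a minimal sketch of the usual fix. `ToyCollator` and `ToyMultipleChoice` are hypothetical names standing in for the classes in tasks.py.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class ToyCollator:
    """Stand-in for DataCollatorForMultipleChoice; eq=True makes instances unhashable."""
    pad_id: int = 0

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        return {"batch_size": len(features)}


@dataclass
class ToyMultipleChoice:
    # On Python 3.11+, `data_collator: ToyCollator = ToyCollator()` raises ValueError.
    # default_factory builds a fresh collator for each task instance instead.
    data_collator: ToyCollator = field(default_factory=ToyCollator)
```

A side benefit is that two task instances no longer share a single collator object.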

Specify device when loading pipelines

At the moment, we cannot specify a device when loading a pipeline with the load_pipeline helper function. For fine-grained control over how pipelines are loaded, it would be nice to pass the device to the TextClassificationPipeline constructor.

Return all scores from a text classification pipeline

Currently, this option cannot be passed from the load_pipeline function to the TextClassificationPipeline. It would be useful when using the model on specific tasks for inference.
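Both requests come down to letting load_pipeline forward extra keyword arguments to the pipeline constructor. In recent Transformers versions, pipelines accept a `device` argument. TextClassificationPipeline also returns scores for every label when `top_k=None`; older versions used `return_all_scores=True`.

The sketch below shows only the forwarding pattern. `load_pipeline_sketch` and `StubPipeline` are hypothetical names. The stub replaces the real pipeline class so that the example runs without downloading a model.

```python
from typing import Any


class StubPipeline:
    """Hypothetical stand-in for transformers.TextClassificationPipeline."""
    def __init__(self, model: Any, device: Any = None, **kwargs: Any) -> None:
        self.model = model
        self.device = device
        self.kwargs = kwargs  # e.g. top_k=None to request all label scores


def load_pipeline_sketch(model: Any, task_name: str,
                         pipeline_cls: type = StubPipeline, **pipeline_kwargs: Any) -> Any:
    # A real helper would first load the model and select the head for task_name;
    # this sketch only shows caller-supplied options reaching the constructor unchanged.
    return pipeline_cls(model=model, **pipeline_kwargs)
```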
