attorch

Introduction
Installation
Layers
PyTorch Fallback
Tests

Introduction

attorch is a subset of PyTorch's nn module, written purely in Python using OpenAI's Triton. Its goal is to be an easily hackable, self-contained, and readable collection of neural network modules whilst maintaining or improving upon the efficiency of PyTorch. In other words, it intends to be a forkable project endowed with a simple, intuitive design that can serve as an accessible starting point for those who are seeking to develop custom deep learning operations but are not satisfied with the speed of a pure PyTorch implementation and do not have the technical expertise or resources to write CUDA kernels.

There already exist a number of wonderful PyTorch-like frameworks powered by Triton, but most concentrate solely on Transformers and NLP applications, whereas attorch aims to be more inclusive by also presenting a variety of layers pertaining to areas besides NLP such as computer vision. Moreover, attorch is not an inference-only package and fully supports both forward and backward passes, meaning it can be used during training as well as inference, though its performance for the latter is generally not on par with dedicated inference engines.

Installation

The only dependencies of attorch are torch==2.2.0 and triton==2.2.0. Please install the specified versions of these two libraries and clone this repository to get started.
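
Since the versions are pinned exactly, it can help to confirm what is actually installed before importing attorch. The following is a minimal sketch using only the standard library; the helper name `check_pins` and the `REQUIRED` mapping are illustrative and not part of attorch. It assumes that a local build label such as `+cu121` should be ignored when comparing versions.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the installation note above.
REQUIRED = {"torch": "2.2.0", "triton": "2.2.0"}


def check_pins(required, get_version=version):
    """Return {package: found_version_or_None} for every pin that is not met."""
    problems = {}
    for name, wanted in required.items():
        try:
            found = get_version(name)
        except PackageNotFoundError:
            found = None  # package is not installed at all
        # Assumption: local build labels (e.g. "2.2.0+cu121") still satisfy the pin.
        if found is None or found.split("+")[0] != wanted:
            problems[name] = found
    return problems


for name, found in check_pins(REQUIRED).items():
    print(f"{name}: expected {REQUIRED[name]}, found {found}")
```

Nothing is printed when both pins are satisfied.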

Layers

The currently implemented layers, all of which support automatic mixed precision (AMP), are:

  • attorch.Conv2d: 2D-convolves over the input using weights, optionally adding bias.
  • attorch.GELU: Applies GELU to the input.
  • attorch.ReLU: Applies ReLU to the input.
  • attorch.SiLU: Applies SiLU to the input.
  • attorch.Sigmoid: Applies sigmoid to the input.
  • attorch.Tanh: Applies tanh to the input.
  • attorch.LogSoftmax: Normalizes the input using softmax and takes its log.
  • attorch.Softmax: Normalizes the input using softmax.
  • attorch.Softmin: Normalizes the input using softmin.
  • attorch.BatchNorm1d: Batch-normalizes the 2D or 3D input, optionally fusing an activation function and adding a residual to the pre-activation result.
  • attorch.BatchNorm2d: Batch-normalizes the 4D input, optionally fusing an activation function and adding a residual to the pre-activation result.
  • attorch.LayerNorm: Layer-normalizes the input.
  • attorch.Linear: Linearly transforms the input using weights, optionally adding bias and fusing an activation function.
  • attorch.Dropout: Randomly zeroes elements in the input during training.
  • attorch.L1Loss: Measures the mean absolute error between the input and target.
  • attorch.MSELoss: Measures the mean squared error between the input and target.
  • attorch.CrossEntropyLoss: Measures the mean cross entropy loss between the input and target, with optional reweighting of each class.
  • attorch.NLLLoss: Measures the negative log likelihood loss between the input and target, with optional reweighting of each class.

Unless otherwise noted in their docstrings, the aforementioned layers behave identically to their PyTorch equivalents.
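
To make a few of the one-line descriptions above concrete, here is a pure-Python reference sketch of how softmax, softmin, log-softmax, NLL loss, and cross entropy relate to one another. It is not attorch's Triton code and ignores batching details, class weights, and AMP. The function names are illustrative only.

```python
import math


def softmax(xs):
    # Subtract the maximum first so exp() cannot overflow.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def softmin(xs):
    # Softmin is softmax applied to the negated input.
    return softmax([-x for x in xs])


def log_softmax(xs):
    # log(softmax(x)) computed directly via the log-sum-exp trick.
    m = max(xs)
    log_total = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_total for x in xs]


def nll_loss(log_probs, targets):
    # Mean negative log likelihood of the target class in each row.
    picked = [row[t] for row, t in zip(log_probs, targets)]
    return -sum(picked) / len(picked)


def cross_entropy(logits, targets):
    # Cross entropy on raw logits equals NLL loss on log-softmax outputs.
    return nll_loss([log_softmax(row) for row in logits], targets)


scores = [1.0, 2.0, 3.0]
print(softmax(scores))
print(softmin(scores))
print(cross_entropy([scores], [2]))
```

The last function mirrors the usual relationship in PyTorch, where cross entropy loss combines a log-softmax with a negative log likelihood loss.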

PyTorch Fallback

To enable easier integration of attorch and PyTorch layers, attorch.nn is offered, which provides an interface to attorch's modules with PyTorch fallback should a desired layer not be available, as seen below.

from attorch import nn


lin = nn.Linear(10, 20) # Uses attorch's linear layer
gap = nn.AdaptiveAvgPool2d(1) # Uses PyTorch's global pooling since GAP is not available in attorch
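
The fallback idea can be illustrated without attorch or PyTorch installed. The sketch below is not how attorch.nn is actually implemented; it only shows one plausible way to resolve a name against a primary namespace and fall back to a second one. `FallbackNamespace`, `fast`, and `reference` are hypothetical stand-ins, with strings in place of real layer classes.

```python
from types import SimpleNamespace


class FallbackNamespace:
    """Resolve attributes from `primary` first, then from `fallback`."""

    def __init__(self, primary, fallback):
        self._primary = primary
        self._fallback = fallback

    def __getattr__(self, name):
        # __getattr__ only runs when normal attribute lookup has failed,
        # so _primary and _fallback themselves never reach this method.
        if hasattr(self._primary, name):
            return getattr(self._primary, name)
        # Raises AttributeError if neither namespace provides the name.
        return getattr(self._fallback, name)


# Hypothetical stand-ins: strings take the place of layer classes.
fast = SimpleNamespace(Linear="fast.Linear")
reference = SimpleNamespace(
    Linear="reference.Linear",
    AdaptiveAvgPool2d="reference.AdaptiveAvgPool2d",
)

nn = FallbackNamespace(fast, reference)
print(nn.Linear)             # provided by the primary namespace
print(nn.AdaptiveAvgPool2d)  # only available in the fallback namespace
```

With this pattern, model code written against a single `nn` object does not need to change as more layers move into the primary namespace.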

Tests

Each module can be tested against its PyTorch counterpart to ensure correctness. These tests are included under tests/ and can be executed using pytest. It should be noted that some might fail owing to numerical precision issues, but in most practical use cases, that should not be a problem.
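
The precision caveat is easiest to see with a small example. The sketch below is not taken from tests/; it uses plain Python floats to show why two correct implementations of the same reduction can disagree bit for bit, and why comparisons are usually made within a tolerance. The helper `all_close` and its default tolerances are illustrative assumptions.

```python
import math


def all_close(actual, expected, rel_tol=1e-5, abs_tol=1e-8):
    """Element-wise comparison within a relative and absolute tolerance."""
    return len(actual) == len(expected) and all(
        math.isclose(a, e, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, e in zip(actual, expected)
    )


def test_sum_matches_reference_within_tolerance():
    values = [0.1] * 10
    candidate = [sum(values)]        # plain left-to-right accumulation
    reference = [math.fsum(values)]  # accurately rounded summation
    # The two results are not bit-identical ...
    assert candidate != reference
    # ... yet they agree within a small tolerance.
    assert all_close(candidate, reference)


test_sum_matches_reference_within_tolerance()
```

A test that fails only by a margin of this kind usually points to a tolerance that is too tight for the data type rather than to a wrong kernel.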
