
This project is dedicated to the implementation and research of Kolmogorov-Arnold convolutional networks. The repository includes implementations of 1D, 2D, and 3D convolutions with different kernels, ResNet-like and DenseNet-like models, training code based on accelerate/PyTorch, as well as scripts for experiments with CIFAR-10 and Tiny ImageNet.

License: MIT License

computer-vision convolutional-neural-networks kolmogorov-arnold-networks

torch-conv-kan's Introduction


TorchConv KAN: A Convolutional Kolmogorov-Arnold Networks Collection

This project introduces and demonstrates the training, validation, and quantization of Convolutional KAN models using PyTorch with CUDA acceleration. The torch-conv-kan codebase evaluates performance on the MNIST, CIFAR, Tiny ImageNet, and Imagenet1k datasets.

Project Status: Under Development

Updates

  • ✅ [2024/05/13] Convolutional KALN layers are available

  • ✅ [2024/05/14] Convolutional KAN and Fast KAN layers are available

  • ✅ [2024/05/15] Convolutional ChebyKAN is available now. MNIST, CIFAR10, and CIFAR100 benchmarks are added.

  • ✅ [2024/05/19] ResNet-like, U-net-like and MoE-based (don't ask why=)) models released with accelerate-based training code.

  • ✅ [2024/05/21] VGG-like and DenseNet-like models released! Gram KAN convolutional layers added.

  • ✅ [2024/05/23] WavKAN convolutional layers added. Fixed a bug with the output hook in trainer.py.

  • ✅ [2024/05/25] U2-net like models added. Fixed a memory leak in trainer.py.

  • ✅ [2024/05/27] Updated implementation of WavKAN - much faster now. Added VGG-WavKAN.

  • ✅ [2024/05/31] Fixed KACN Conv instability issue, added Lion optimizer, updated baseline models and benchmarks, and 🔥🔥🔥pretrained weights on Imagenet1k are released🔥🔥🔥, as well as Imagenet1k training scripts.

  • ✅ [2024/06/03] JacobiKAN Convs are available now.

  • ✅ [2024/06/05] BernsteinKANs and BernsteinKAN Convs are available now.

  • ✅ [2024/06/15] Introducing Bottleneck KAN Convs (with Gram polynomials as basis functions for now). Added LBFGS optimizer support (it's not well-tested, please raise an issue if you face any problems with it). Regularization benchmarks on CIFAR100 are published. Hyperparameter tuning with Ray Tune is released.

  • ✅ [2024/06/18] ReLU KAN Convs are available now.

  • ✅ [2024/06/20] 🔥🔥🔥A new pretrained checkpoint on Imagenet1k is released🔥🔥🔥: VGG11-style with Bottleneck Gram convolutions. The model achieves 68.5% Top1 accuracy on the Imagenet1k validation set with only 7.25M parameters.

  • ✅ [2024/07/02] 🔥🔥🔥 We released our paper 🔥🔥🔥 Kolmogorov-Arnold Convolutions: Design Principles and Empirical Studies

  • ✅ [2024/07/09] PEFT code for KAGN models is released, as well as new RDNet-like models, a better implementation of Kolmogorov-Arnold-Gram convolutions, and medical image segmentation scripts.

TODO list and next steps

  • A VGG19-like model is currently training on Imagenet1k
  • A ResNet50-like model is currently training on Imagenet1k
  • Fine-tuning experiments on other benchmarks are in progress, as well as exploration of PEFT methods
  • I'm working on pruning and visualization methods as well


Introducing Convolutional KAN layers

Kolmogorov-Arnold networks rely on Kolmogorov-Arnold representation theorem:

$$f(x_1, \dots, x_n) = \sum_{q=1}^{2n+1} \Phi_q \left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)$$

From this formula, the authors of KAN: Kolmogorov-Arnold Networks derived a new architecture: learnable activations on edges and summation on nodes. An MLP, in contrast, applies fixed non-linearities on nodes and learnable linear projections on edges.

KAN vs MLP

In a convolutional layer, a filter or kernel "slides" over the 2D input data, performing an elementwise multiplication. The results are summed up into a single output pixel. The kernel performs the same operation for every location it slides over, transforming a 2D (1D or 3D) matrix of features into a different one. Although 1D and 3D convolutions share the same concept, their filters, input data, and output data have different dimensionality. We'll focus on 2D for simplicity.

Typically, after a convolutional layer, a normalization layer (like BatchNorm, InstanceNorm, etc.) and non-linear activations (ReLU, LeakyReLU, SiLU, and many more) are applied.

More formally: suppose we have an input image y of size N x N. We omit the channel axis for simplicity; it only adds another summation sign. First, we convolve the image with a kernel W of size m x m:

$$(W * y)_{i,j} = \sum_{k=1}^{m} \sum_{l=1}^{m} W_{k,l} \cdot y_{i+k,\, j+l}$$

Then we apply batch norm and a non-linearity, for example ReLU:

$$z_{i,j} = \mathrm{ReLU}\left(\mathrm{BatchNorm}\left((W * y)_{i,j}\right)\right)$$
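
The same pattern in code, as a standard PyTorch block chaining convolution, normalization, and activation (the channel sizes here are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

# classical convolutional block: convolution -> normalization -> non-linearity
block = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

x = torch.randn(1, 3, 32, 32)  # dummy NCHW input
print(block(x).shape)          # torch.Size([1, 16, 32, 32])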

Kolmogorov-Arnold convolutions work slightly differently: the kernel consists of a set of univariate non-linear functions. This kernel "slides" over the 2D input data, applying the kernel's functions element-wise. The results are then summed up into a single output pixel. More formally: suppose we have an input image y (again) of size N x N, omitting the channel axis as before. The KAN-based convolution is then defined as:

$$(\Phi * y)_{i,j} = \sum_{k=1}^{m} \sum_{l=1}^{m} \phi_{k,l}\left(y_{i+k,\, j+l}\right)$$

Each phi is a learnable univariate non-linear function. In the original paper, the authors propose the following form for these functions:

$$\phi(x) = w_b\, b(x) + w_s\, \mathrm{spline}(x)$$

The authors propose SiLU as the b(x) activation:

$$b(x) = \mathrm{SiLU}(x) = \frac{x}{1 + e^{-x}}$$

To sum up, a "traditional" convolution is a matrix of weights, while a Kolmogorov-Arnold convolution is a set of functions. That's the primary difference. The key question is how to construct these univariate non-linear functions, and the answer is the same as for KANs: B-splines, polynomials, RBFs, wavelets, etc.
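
To make the difference concrete, below is a deliberately naive sketch of a single-input-channel KAN-style 2D convolution. It uses a plain polynomial basis as a stand-in for B-splines and is written for clarity rather than speed; it is not how the layers in this repository are implemented.

import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveKANConv2d(nn.Module):
    """Illustrative (slow) KAN-style convolution for a single input channel.

    Each kernel position (k, l) carries its own univariate function
    phi(x) = w_b * SiLU(x) + sum_d w_d * x**d, with a polynomial basis
    standing in for B-splines purely for illustration.
    """

    def __init__(self, out_channels: int, kernel_size: int = 3, degree: int = 3):
        super().__init__()
        k2 = kernel_size * kernel_size
        self.kernel_size = kernel_size
        # residual branch weights w_b: one per (output channel, kernel position)
        self.w_base = nn.Parameter(torch.randn(out_channels, k2) * 0.1)
        # polynomial coefficients w_d: one set per (output channel, kernel position)
        self.w_poly = nn.Parameter(torch.randn(out_channels, k2, degree) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (N, 1, H, W)
        n, _, h, w = x.shape
        # extract m x m patches: (N, k2, L) with L = H * W under "same" padding
        patches = F.unfold(x, self.kernel_size, padding=self.kernel_size // 2)
        base = F.silu(patches)  # b(x) = SiLU(x), applied element-wise
        degree = self.w_poly.shape[-1]
        # powers x^1 .. x^degree: (N, k2, L, degree)
        powers = torch.stack([patches ** (d + 1) for d in range(degree)], dim=-1)
        # apply each position's univariate function, then sum over kernel positions
        out = torch.einsum('ok,nkl->nol', self.w_base, base) \
            + torch.einsum('okd,nkld->nol', self.w_poly, powers)
        return out.view(n, -1, h, w)


y = torch.randn(2, 1, 28, 28)
print(NaiveKANConv2d(out_channels=8)(y).shape)  # torch.Size([2, 8, 28, 28])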

In this repository, the implementation of the following layers is presented:

  • The KANConv1DLayer, KANConv2DLayer, KANConv3DLayer classes represent convolutional layers based on the Kolmogorov Arnold Network, introduced in [1]. Baseline model implemented in models/baselines/conv_kan_baseline.py.

  • The KALNConv1DLayer, KALNConv2DLayer, KALNConv3DLayer classes represent convolutional layers based on the Kolmogorov Arnold Legendre Network, introduced in [2]. Baseline model implemented in models/baselines/conv_kaln_baseline.py.

  • The FastKANConv1DLayer, FastKANConv2DLayer, FastKANConv3DLayer classes represent convolutional layers based on the Fast Kolmogorov Arnold Network, introduced in [3]. Baseline model implemented in models/baselines/fast_conv_kan_baseline.py.

  • The KACNConv1DLayer, KACNConv2DLayer, KACNConv3DLayer classes represent convolutional layers based on Kolmogorov Arnold Network with Chebyshev polynomials instead of B-splines, introduced in [4]. Baseline model implemented in models/baselines/conv_kacn_baseline.py.

  • The KAGNConv1DLayer, KAGNConv2DLayer, KAGNConv3DLayer classes represent convolutional layers based on Kolmogorov Arnold Network with Gram polynomials instead of B-splines, introduced in [5]. Baseline model implemented in models/baselines/conv_kagn_baseline.py.

  • The WavKANConv1DLayer, WavKANConv2DLayer, WavKANConv3DLayer classes represent convolutional layers based on Wavelet Kolmogorov Arnold Network, introduced in [6]. Baseline model implemented in models/baselines/conv_wavkan_baseline.py.

  • The KAJNConv1DLayer, KAJNConv2DLayer, KAJNConv3DLayer classes represent convolutional layers based on Jacobi Kolmogorov Arnold Network, introduced in [7] with minor modifications.

  • We introduce the KABNConv1DLayer, KABNConv2DLayer, KABNConv3DLayer classes, representing convolutional layers based on the Bernstein Kolmogorov Arnold Network.

  • The ReLUKANConv1DLayer, ReLUKANConv2DLayer, ReLUKANConv3DLayer classes represent convolutional layers based on ReLU KAN, introduced in [8].
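
All of these layers follow the constructor pattern shown in the Usage section below: input channels, output channels, and the basis hyperparameter (spline order, polynomial degree, and so on), followed by the usual convolution arguments. A minimal sketch, assuming the repository root is on the PYTHONPATH (keyword names may vary slightly between layer classes):

import torch
from kan_convs import KANConv2DLayer

# in_channels=3, out_channels=16, spline_order=3, plus standard conv arguments
conv = KANConv2DLayer(3, 16, 3, kernel_size=3, groups=1, padding=1, stride=1, dilation=1)

x = torch.randn(2, 3, 32, 32)
print(conv(x).shape)  # torch.Size([2, 16, 32, 32])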

Introducing Bottleneck Convolutional KAN layers

As we previously discussed, a phi function consists of two blocks: residual activation functions (left part of the diagrams below) and a learnable non-linearity (splines, polynomials, wavelets, etc.; right part of the diagrams below).

ConvKANvsBNConvKAN

The main problem lies in the right part: the more input channels we have, the more learnable parameters we introduce into the model. So, as with bottleneck layers in ResNets, we can use a simple trick: apply a 1x1 squeezing convolution to the input, apply the splines in this lower-dimensional space, and then apply a 1x1 unsqueezing convolution.

Let's assume we have an input x with 512 channels and want to perform a ConvKAN with 512 filters. First, a 1x1 conv projects x to y with, for example, 128 channels. We then apply the learned non-linearity to y, and a final 1x1 conv transforms y into t with 512 channels again. Now we can sum t with the residual activations.
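
A minimal sketch of this idea, with nn.SiLU standing in for the learnable basis functions (the repository's actual bottleneck layers are more involved):

import torch
import torch.nn as nn


class BottleneckKANConvSketch(nn.Module):
    """Sketch of the bottleneck trick; not the repository's implementation."""

    def __init__(self, channels: int = 512, squeeze_channels: int = 128):
        super().__init__()
        self.squeeze = nn.Conv2d(channels, squeeze_channels, kernel_size=1)    # x -> y
        self.inner = nn.SiLU()  # stand-in for the learnable spline/polynomial branch
        self.unsqueeze = nn.Conv2d(squeeze_channels, channels, kernel_size=1)  # y -> t
        self.residual_act = nn.SiLU()  # residual activation branch

    def forward(self, x):
        t = self.unsqueeze(self.inner(self.squeeze(x)))
        return t + self.residual_act(x)  # sum t with the residual activations


x = torch.randn(1, 512, 8, 8)
print(BottleneckKANConvSketch()(x).shape)  # torch.Size([1, 512, 8, 8])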

In this repository, the implementation of the following bottleneck layers is presented:

  • The BottleNeckKAGNConv1DLayer, BottleNeckKAGNConv2DLayer, BottleNeckKAGNConv3DLayer classes represent bottleneck convolutional layers based on Kolmogorov Arnold Network with Gram polynomials instead of B-splines.

Model Zoo

ResKANets

We introduce ResKANets - ResNet-like models with KAN convolutions instead of regular ones. The main class, ResKANet, can be found in models/reskanet.py. Our implementation supports blocks with KAN, Fast KAN, KALN, KAGN, and KACN convolutional layers.

After 75 training epochs on CIFAR10, ResKANet 18 with Kolmogorov-Arnold Legendre convolutions achieved 84.17% accuracy and 0.985 AUC (OVO).

After 75 training epochs on Tiny ImageNet, ResKANet 18 with Kolmogorov-Arnold Legendre convolutions achieved 28.62% accuracy, 55.49% top-5 accuracy, and 0.932 AUC (OVO).

Please take into account that these are preliminary results; more experiments are in progress right now.

DenseKANets

We introduce DenseKANets - DenseNet-like models with KAN convolutions instead of regular ones. The main class, DenseKANet, can be found in models/densekanet.py. Our implementation supports blocks with KAN, Fast KAN, KALN, KAGN, and KACN convolutional layers.

After 250 training epochs on Tiny ImageNet, DenseNet 121 with Kolmogorov-Arnold Gram convolutions achieved 40.61% accuracy, 65.08% top-5 accuracy, and 0.957 AUC (OVO).

Please take into account that these are preliminary results; more experiments are in progress right now.

VGGKAN

We introduce VGGKANs - VGG-like models with KAN convolutions instead of regular ones. The main class, VGG, can be found in models/vggkan.py. The model supports all types of KAN convolutional layers.

Checkpoints pretrained on Imagenet1k:

Model             Accuracy, top1   Accuracy, top5   AUC (ovo)   AUC (ovr)
VGG KAGN 11v2     59.1             82.29            99.43       99.43
VGG KAGN 11v4     61.17            83.26            99.42       99.43
VGG KAGN BN 11v4  68.5             88.46            99.61       99.61

More checkpoints are coming, stay tuned. The computational resources available to me are pretty limited, so it takes some time to train and evaluate all the models.

UKANet and U2KANet

We introduce UKANets and U2KANets: U-net-like models with KAN convolutions instead of regular ones, built from ResNet-style blocks, and U2-net models with KAN convolutions instead of regular ones. The main class, UKANet, can be found in models/ukanet.py. Our implementation supports Basic and Bottleneck blocks with KAN, Fast KAN, KALN, KAGN, and KACN convolutional layers.

Performance Metrics

Baseline models on MNIST and CIFAR10/100. TL;DR: an 8-layer SimpleKAGNConv achieves 99.68% accuracy on MNIST, 84.32% on CIFAR10, and 59.27% on CIFAR100. It is the best model on all datasets except CIFAR10, where the 8-layer SimpleWavKANConv achieves 85.37% accuracy.

VGG-like on Imagenet1k

Regularization, scaling, and hyperparameter optimization study. TL;DR: an 8-layer SimpleKAGNConv reaches 74.87% accuracy on CIFAR100 with an optimal set of hyperparameters.

Discussion

First and foremost, it should be noted that the results obtained are preliminary. The model architecture has not been thoroughly explored and represents only two of many possible design variants.

Nevertheless, the experiments indicate that Kolmogorov-Arnold convolutional networks outperform the classical convolutional architecture on the MNIST dataset, but significantly underperform on CIFAR-10 and CIFAR-100 in terms of quality. The ChebyKAN-based convolution encounters stability issues during training, necessitating further investigation.

As a next step, I plan to search for a suitable architecture for KAN convolutions that can achieve acceptable quality on CIFAR-10/100 and attempt to scale these models to more complex datasets.

Prerequisites

Ensure you have the following installed on your system:

  • Python (version 3.9 or higher)
  • CUDA Toolkit (corresponding to your PyTorch installation's CUDA version)
  • cuDNN (compatible with your installed CUDA Toolkit)

Usage

Below is an example of a simple model based on KAN convolutions:

import torch
import torch.nn as nn

from kan_convs import KANConv2DLayer


class SimpleConvKAN(nn.Module):
    def __init__(
            self,
            layer_sizes,
            num_classes: int = 10,
            input_channels: int = 1,
            spline_order: int = 3,
            groups: int = 1):
        super(SimpleConvKAN, self).__init__()

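        # four KAN convolutional blocks (the strided ones downsample), then global average pooling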
        self.layers = nn.Sequential(
            KANConv2DLayer(input_channels, layer_sizes[0], spline_order, kernel_size=3, groups=1, padding=1, stride=1,
                           dilation=1),
            KANConv2DLayer(layer_sizes[0], layer_sizes[1], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=2, dilation=1),
            KANConv2DLayer(layer_sizes[1], layer_sizes[2], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=2, dilation=1),
            KANConv2DLayer(layer_sizes[2], layer_sizes[3], spline_order, kernel_size=3, groups=groups, padding=1,
                           stride=1, dilation=1),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.output = nn.Linear(layer_sizes[3], num_classes)

        self.drop = nn.Dropout(p=0.25)

    def forward(self, x):
        x = self.layers(x)
        x = torch.flatten(x, 1)
        x = self.drop(x)
        x = self.output(x)
        return x
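
For instance, the model can be instantiated and sanity-checked on a dummy MNIST-sized batch as follows (the channel widths are arbitrary):

# arbitrary channel widths; MNIST-sized input (1 channel, 28x28)
model = SimpleConvKAN([32, 64, 128, 256], num_classes=10, input_channels=1)
logits = model(torch.randn(8, 1, 28, 28))
print(logits.shape)  # torch.Size([8, 10])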

To run the training and testing of the baseline models on the MNIST, CIFAR-10, and CIFAR-100 datasets, execute the following line of code:

python mnist_conv.py

This script will train baseline models on MNIST, CIFAR10, or CIFAR100, validate them, quantize them, and log performance metrics.
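
For reference, post-training dynamic quantization in PyTorch looks like the sketch below; this is a generic example, not necessarily what mnist_conv.py does internally:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # stand-in for a trained model
# quantize the model's nn.Linear layers to int8 weights (dynamic quantization)
quantized_model = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)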

Accelerate-based training

We introduce training code with Accelerate, Hydra configs and Wandb logging.

1. Clone the Repository

Clone the torch-conv-kan repository and set up the project environment:

git clone https://github.com/IvanDrokin/torch-conv-kan.git
cd torch-conv-kan
pip install -r requirements.txt

2. Configure Weights & Biases (wandb)

To monitor experiments and model performance with wandb:

  1. Set Up wandb Account:
  • Sign up or log in at Weights & Biases.
  • Locate your API key in your account settings.
  2. Initialize wandb in Your Project:

Before running the training script, initialize wandb:

wandb login

Enter your API key when prompted to link your script executions to your wandb account.

  3. Adjust the Entity Name in configs/cifar10-reskanet.yaml or configs/tiny-imagenet-reskanet.yaml to Your Username or Team Name

Run

Update any parameters in configs and run

accelerate launch cifar.py

This script trains the model, validates it, and logs performance metrics using wandb on the CIFAR10 dataset.

accelerate launch tiny_imagenet.py

This script trains the model, validates it, and logs performance metrics using wandb on the Tiny Imagenet dataset.

Using your own dataset or model

If you would like to use your own dataset, please follow these steps:

  1. Copy tiny_imagenet.py and modify the get_data() method. If the basic implementation of the classification dataset is not suitable for your data, extend it or write your own (see the sketch after this list).
  2. Replace model = reskalnet_18x64p(...) with your own model if necessary.
  3. Create a config yaml in the configs folder, following the provided templates.
  4. Run accelerate launch your_script.py
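
As mentioned in step 1, a minimal classification dataset might look like the hypothetical sketch below; get_data() in your copied script would construct and return objects like these for the train and validation splits:

from torch.utils.data import Dataset


class MyImageDataset(Dataset):
    """Hypothetical minimal classification dataset."""

    def __init__(self, images, labels, transform=None):
        self.images = images        # e.g. a list of PIL images or tensors
        self.labels = labels        # integer class ids
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]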

Cite this Project

If you use this project in your research or wish to refer to the baseline results, please use the following BibTeX entry.

@misc{drokin2024kolmogorovarnoldconvolutionsdesignprinciples,
      title={Kolmogorov-Arnold Convolutions: Design Principles and Empirical Studies}, 
      author={Ivan Drokin},
      year={2024},
      eprint={2407.01092},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2407.01092}, 
}

Contributions

Contributions are welcome. Please raise issues as necessary.

Acknowledgements

This repository is based on TorchKAN, FastKAN, ChebyKAN, GRAMKAN, WavKAN, JacobiKAN, and ReLU KAN. We would like to thank the authors for their open research and exploration.

References



torch-conv-kan's Issues

a small problem

The LBFGS parameter optimizer does not appear to be in the train folder. My English is not very good; thank you for the update.

problem with splitting x and groups with kan_conv

This code, split_x = torch.split(x, self.inputdim // self.groups, dim=1), will split x into several chunks.
However, when groups=1,

for group_ind, _x in enumerate(split_x):
    y = self.forward_kan(_x, group_ind)

the above code will run into an error, since

self.base_conv = nn.ModuleList([conv_class(input_dim // groups,
                                                   output_dim // groups,
                                                   kernel_size,
                                                   stride,
                                                   padding,
                                                   dilation,
                                                   groups=1,
                                                   bias=False) for _ in range(groups)])

base_conv will only create a ModuleList of size 1.

output_hook is not storing tensors

Hi. thanks for your awesome source code.

I'm training the model with your example code with L1 & L2 regularization.
I tried to check the gradient flow of the tensors during the backward pass.
However, I noticed that in train/trainer.py, output_hook was not storing the gradient tensors.
It seems the hook registration function is not working in train/trainer.py, line 186.

Should I fix the code from

for module in model.named_modules():

to

for name, module in model.named_modules():

in train/trainer.py, line 184? Or did I miss something?

My torch version is 2.1

Thanks.
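
For reference, model.named_modules() in PyTorch does yield (name, module) tuples, which a quick check confirms:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4))
for item in model.named_modules():
    print(type(item))  # <class 'tuple'>: (name, module), so the loop variable must be unpacked
for name, module in model.named_modules():
    print(repr(name), module.__class__.__name__)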

Symbolic visualization

Hello, I found this repo really fascinating. Can I extract the parameters from each layer and somehow visualize them like in the KAN paper, or do you plan to work on this later?

About loading checkpoint weights

Thank you for putting forward this great project! When training models with your framework, I found some satisfying results (checkpoint folders like Fig. 1) and would like to fine-tune those models by loading weights from the checkpoint folders. However, it seems that trainer.py doesn't include a part for loading a pretrained model, and the weight files are confusing too (I'm not sure whether the file in Fig. 2 with no suffix can be loaded).
It would be much appreciated if you could give some insight into how the checkpoint weights are organized and how to load them in your trainer.py code :) Best Regards.
(Fig. 1 and Fig. 2: screenshots of the checkpoint folders)

Pruning and interpretability

Hi, this is incredible work. May I know if there are plans to add automatic/manual pruning, and an interpretable output component?

About Visualize experiments

Hello, and thank you for your very meaningful work! I've been exploring how to do visualization experiments with KAN lately; do you have any good suggestions? Looking forward to hearing from you. Best regards.

Hi, some problems with FastKANConvND

Hi, thank you very much for your work. Could you please tell me if FastKANConvNDLayer is an improvement you made on FastKAN? Is there any related paper on FastKANConvND?
