romanngg commented on May 26, 2024

Aggregate only does (weighted) sum-pooling, so the output will indeed have the same shape (B, N, d), and it has no trainable parameters. To change the channel size, add a stax.Dense(d1) layer afterwards. Note that it works with inputs of any shape/rank (it is equivalent to a 1x1[...x1] convolution); you just need to specify the channel_axis (https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Dense.html). Alternatively, you can use the stax.Conv or stax.ConvLocal layers. Lmk if this helps!
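To make the shapes concrete, here is a minimal NumPy sketch (not the neural_tangents implementation) of a weighted sum-pooling over the node axis followed by a dense layer on the channel axis. It assumes a dense (B, N, N) adjacency-like pattern where pattern[b, i, j] weights node j's features in node i's output; the helper `aggregate`, the random pattern, and the weight matrix `W` are hypothetical illustrations.

```python
import numpy as np

def aggregate(x, pattern):
    """Hypothetical sketch of weighted sum-pooling over nodes.

    x:       (B, N, d) node features, channels on the last axis.
    pattern: (B, N, N) weights; pattern[b, i, j] weights node j for node i.
    No trainable parameters: the output shape equals the input shape.
    """
    # out[b, i, c] = sum_j pattern[b, i, j] * x[b, j, c]
    return np.einsum('bij,bjc->bic', pattern, x)

rng = np.random.default_rng(0)
B, N, d, d1 = 2, 5, 3, 8
x = rng.normal(size=(B, N, d))
# Hypothetical random 0/1 adjacency pattern.
pattern = (rng.random((B, N, N)) < 0.4).astype(x.dtype)
# Hypothetical dense weights mapping d channels to d1 channels.
W = rng.normal(size=(d, d1)) / np.sqrt(d)

h = aggregate(x, pattern)
print(h.shape)  # (2, 5, 3)

# A dense layer on the last (channel) axis: every node is transformed
# independently, like a 1x1 convolution, changing d -> d1.
out = np.einsum('bnc,ck->bnk', h, W)
print(out.shape)  # (2, 5, 8)
```

The aggregation step only mixes information across nodes, while the dense step only mixes across channels, which is why changing the channel count needs the extra layer.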

from neural-tangents.

SchenbergZY commented on May 26, 2024

Thank you for your reply!
But I still have questions about the channel axis: why does it exist? In the original paper I could not find anything about a channel axis. Compared to the original code by the paper's author, how can I ignore this channel_axis parameter?

romanngg commented on May 26, 2024

Which paper/code are you referring to? The channel axis is just the axis that contains your channels / hidden units / features; in your example it's the last axis (-1, or equivalently 2), of size d or d1. Channel axes are ubiquitous in standard deep learning layers. It is commonly the last axis (-1), but you have the flexibility to put it elsewhere by setting the channel_axis parameter.
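As an illustration of what setting channel_axis means, here is a NumPy sketch (not the library code) of a dense layer that contracts only the chosen axis. The helper `dense` and the weights `W` are hypothetical; the point is that the same data with channels on axis -1 or on axis 1 gives the same numbers, only laid out differently.

```python
import numpy as np

def dense(x, W, channel_axis=-1):
    """Hypothetical dense layer: apply W of shape (d, d1) along channel_axis.

    All other axes (batch, nodes, ...) are treated identically, which is why
    such a layer works for inputs of any rank (a 1x1[...x1] convolution).
    """
    x = np.moveaxis(x, channel_axis, -1)     # move channels to the end
    y = x @ W                                # contract channels: d -> d1
    return np.moveaxis(y, -1, channel_axis)  # restore the original layout

rng = np.random.default_rng(0)
B, N, d, d1 = 2, 5, 3, 8
W = rng.normal(size=(d, d1))

x_last = rng.normal(size=(B, N, d))          # channels last: (B, N, d)
x_mid = np.transpose(x_last, (0, 2, 1))      # same data, channels on axis 1: (B, d, N)

y_last = dense(x_last, W, channel_axis=-1)   # (B, N, d1)
y_mid = dense(x_mid, W, channel_axis=1)      # (B, d1, N)

# Same values; only the position of the channel axis differs.
print(np.allclose(y_last, np.transpose(y_mid, (0, 2, 1))))  # True
```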

SchenbergZY commented on May 26, 2024

In the paper arXiv:1905.13192 (which you cite in your Aggregate documentation), the example datasets contain no feature dimension: only scalar nodes and their neighbours.
In arXiv:2103.03113, the GCNTK is specified, but no public code is provided, so I need to build the GCNTK myself.
Maybe I can set channel_axis=1 for the GCNTK construction, and axis=-1 for my d and d1. I'm not sure if this is correct.
Thank you again for answering.

romanngg commented on May 26, 2024

Thanks for clarifying! If your input has no channel axis, I'd suggest adding a singleton channel axis of size d=1, something like x = jnp.expand_dims(x, channel_axis), so that it has shape (B, N, 1) (and then you can use the default channel_axis=-1). Lmk if this works for you!
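A small sketch of the singleton-channel trick, written with NumPy (np.expand_dims has the same semantics as jnp.expand_dims here). It assumes scalar node features of shape (B, N) and a hypothetical path-graph pattern with self-loops; the einsum stands in for a weighted sum-pooling and is not the library's code.

```python
import numpy as np

rng = np.random.default_rng(0)
B, N = 2, 5
channel_axis = -1

x = rng.normal(size=(B, N))  # one scalar feature per node: (B, N)

# Hypothetical pattern: path graph with self-loops (node i sees i-1, i, i+1).
pattern = np.broadcast_to(
    np.eye(N) + np.eye(N, k=1) + np.eye(N, k=-1), (B, N, N))

x = np.expand_dims(x, channel_axis)  # add singleton channel axis: (B, N, 1)
print(x.shape)  # (2, 5, 1)

# Weighted sum over neighbours; the channel axis stays last with size 1.
h = np.einsum('bij,bjc->bic', pattern, x)
print(h.shape)  # (2, 5, 1)
```

With the features in (B, N, 1) form, downstream layers can keep the default channel_axis=-1, and a dense layer can then widen the single channel to d1.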

yCobanoglu commented on May 26, 2024

Check out my repo on transductive node classification/regression using Graph Neural Network Gaussian Processes and the Graph Neural Tangent Kernel, built with the Neural Tangents library.