Comments (5)
For future reference, in Einstein notation 2D convolution of matrix X with filter W may be expressed as:
Y[i,j] = W[m,n] * X[i+m-1, j+n-1]
The derivative dY[i,j] / dW[m,n] is trivially inferred to be:
dY[i,j] / dW[m,n] = X[i+m-1, j+n-1]
And for a scalar cost C the derivative follows from the chain rule:
dC / dW[m,n] = dC/dY[i,j] * X[i+m-1, j+n-1]
Given that according to Einstein notation indices i and j are summed out, we essentially get a convolution of X with dC/dY as the "filter" (which is ones(size(Y)) when C = sum(Y)), and the result has size(W). If I haven't forgotten any important details, adding derivatives w.r.t. W should be trivial.
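The derivative w.r.t. W can be sanity-checked with a small plain-Python sketch. This is only an illustration, not code from xdiff.jl: the function names `conv2d_valid` and `grad_w` are hypothetical, indices are 0-based (so the `-1` offsets disappear), no padding is used, and the cost is assumed to be C = sum(Y), so the upstream gradient dC/dY is all ones.

```python
def conv2d_valid(X, W):
    """Y[i][j] = sum over m, n of W[m][n] * X[i+m][j+n] (0-based, no padding)."""
    kh, kw = len(W), len(W[0])
    oh, ow = len(X) - kh + 1, len(X[0]) - kw + 1
    return [[sum(W[m][n] * X[i + m][j + n]
                 for m in range(kh) for n in range(kw))
             for j in range(ow)] for i in range(oh)]


def grad_w(X, dY, kh, kw):
    """dC/dW[m][n] = sum over i, j of dY[i][j] * X[i+m][j+n]."""
    oh, ow = len(dY), len(dY[0])
    return [[sum(dY[i][j] * X[i + m][j + n]
                 for i in range(oh) for j in range(ow))
             for n in range(kw)] for m in range(kh)]


X = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0],
     [7.0, 8.0, 9.0]]
W = [[1.0, 0.0],
     [0.0, -1.0]]

Y = conv2d_valid(X, W)
# Assumption: C = sum(Y), so dC/dY is a matrix of ones with the shape of Y.
ones = [[1.0] * len(Y[0]) for _ in Y]
dW = grad_w(X, ones, len(W), len(W[0]))
print(dW)
```

Note that `grad_w` has the same sliding-window structure as `conv2d_valid`, with the upstream gradient playing the role of the filter, so a forward convolution routine could in principle be reused for it.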
Derivative w.r.t. X may be more tricky, although I haven't thought about it yet.
from xdiff.jl.
Seems like derivative w.r.t. X is also trivial - if we seek a derivative dY[i,j] / dX[p, q], we can substitute:
- p = i + m - 1
- q = j + n - 1
and thus express m and n as:
- m = p - i + 1
- n = q - j + 1
which gives us an alternative form for convolution:
Y[i,j] = W[p-i+1, q-j+1] * X[p, q]
and derivatives:
dY[i,j] / dX[p, q] = W[p-i+1, q-j+1]
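dC / dX[p, q] = dC/dY[i,j] * W[p-i+1, q-j+1]
And if I'm not mistaken, the last one is a convolution of dC/dY (which is ones(size(Y)) when C = sum(Y)), zero-padded so that the result has size(X), with filter flip(W). Terms where p-i+1 or q-j+1 falls outside W contribute nothing.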
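The substitution above can be checked with a small plain-Python sketch. It is an illustration under stated assumptions, not xdiff.jl code: the name `grad_x` is hypothetical, indices are 0-based (so m = p - i and n = q - j), no padding or stride is used in the forward pass, and the cost is assumed to be C = sum(Y), so dC/dY is all ones.

```python
def grad_x(W, dY, xh, xw):
    """dC/dX[p][q] = sum over i, j of dY[i][j] * W[p-i][q-j].

    Index pairs where (p-i, q-j) falls outside W are skipped, which is
    equivalent to zero-padding dY before sliding the flipped filter over it.
    """
    kh, kw = len(W), len(W[0])
    oh, ow = len(dY), len(dY[0])
    dX = [[0.0] * xw for _ in range(xh)]
    for p in range(xh):
        for q in range(xw):
            for i in range(oh):
                for j in range(ow):
                    m, n = p - i, q - j
                    if 0 <= m < kh and 0 <= n < kw:
                        dX[p][q] += dY[i][j] * W[m][n]
    return dX


W = [[1.0, 2.0],
     [3.0, 4.0]]
# Assumption: X is 3x3, so Y is 2x2, and C = sum(Y) gives dC/dY = ones.
dY = [[1.0, 1.0],
      [1.0, 1.0]]
dX = grad_x(W, dY, 3, 3)
print(dX)
```

The centre element of `dX` receives a contribution from every filter weight, while the corners each see only one weight, which matches the intuition that interior pixels of X take part in more output windows than border pixels.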
Also, if we add stride s, convolution becomes:
Y[i,j] = W[m,n] * X[s*(i-1)+m, s*(j-1)+n]
(with 1-based indices; for s = 1 this reduces to the form above). This has been described in [1] along with the equation for pooling operations. E.g. max pooling may be described as:
Y[i,j] = max(W[m,n], X[s*(i-1)+m, s*(j-1)+n])
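A minimal plain-Python sketch of the strided forms may make the indexing concrete. The names `conv2d_strided` and `max_pool2d` are hypothetical and not part of xdiff.jl. Indices are 0-based here, so the window for output position i starts at row `s*i`, and no padding is applied. As a simplifying assumption, the pooling function takes the maximum over the window of X only and uses the window size in place of a filter W.

```python
def conv2d_strided(X, W, s):
    """Y[i][j] = sum over m, n of W[m][n] * X[s*i+m][s*j+n] (0-based)."""
    kh, kw = len(W), len(W[0])
    oh = (len(X) - kh) // s + 1
    ow = (len(X[0]) - kw) // s + 1
    return [[sum(W[m][n] * X[s * i + m][s * j + n]
                 for m in range(kh) for n in range(kw))
             for j in range(ow)] for i in range(oh)]


def max_pool2d(X, kh, kw, s):
    """Y[i][j] = max over m, n of X[s*i+m][s*j+n] (0-based)."""
    oh = (len(X) - kh) // s + 1
    ow = (len(X[0]) - kw) // s + 1
    return [[max(X[s * i + m][s * j + n]
                 for m in range(kh) for n in range(kw))
             for j in range(ow)] for i in range(oh)]


X = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0],
     [9.0, 10.0, 11.0, 12.0],
     [13.0, 14.0, 15.0, 16.0]]
W = [[1.0, 1.0],
     [1.0, 1.0]]

Y = conv2d_strided(X, W, 2)   # non-overlapping 2x2 window sums
P = max_pool2d(X, 2, 2, 2)    # non-overlapping 2x2 window maxima
print(Y, P)
```

Both functions share the same index expression `s*i + m`, and only the reduction (sum of products versus max) differs, which is what allows convolution and pooling to be written in the same notation.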
Basic differentiation of convolutions is now in conv branch.
Done for conv2 and VectorCodeGen, although we need a more sophisticated version for serious convolution work (e.g. supporting strides the way TensorFlow does).