The fast wavelet transform is an important signal processing algorithm. Yet a differentiable implementation in JAX has been missing so far, so I have open-sourced my implementation. It supports the one- and two-dimensional analysis and synthesis transforms, as well as the forward wavelet packet transform. The plot below shows an analysis of a linear chirp signal using a Daubechies wavelet.
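To illustrate what "differentiable fast wavelet transform" means, here is a minimal sketch in plain JAX. It is not the jaxlets API: the function names (`haar_analysis`, `haar_synthesis`, `wavedec`, `waverec`, `packet_level`) are hypothetical, and the two-tap Haar (Daubechies-1) filter stands in for the longer Daubechies filters used in the plot. Analysis repeatedly splits the low-pass band, synthesis inverts those splits, the packet variant splits every band, and `jax.grad` differentiates straight through the transform.

```python
import jax
import jax.numpy as jnp

# Haar (Daubechies-1) filter scale; a simple stand-in for longer Daubechies filters.
S = 1.0 / jnp.sqrt(2.0)


def haar_analysis(x):
    """One level of the Haar fast wavelet transform (even-length input)."""
    even, odd = x[0::2], x[1::2]
    approx = (even + odd) * S  # low-pass (approximation) coefficients
    detail = (even - odd) * S  # high-pass (detail) coefficients
    return approx, detail


def haar_synthesis(approx, detail):
    """Inverse of haar_analysis: rebuild and interleave even/odd samples."""
    even = (approx + detail) * S
    odd = (approx - detail) * S
    return jnp.stack([even, odd], axis=-1).reshape(-1)


def wavedec(x, levels):
    """Multi-level analysis: keep splitting the approximation band."""
    details = []
    for _ in range(levels):
        x, d = haar_analysis(x)
        details.append(d)
    return x, details[::-1]  # coarsest detail band first


def waverec(approx, details):
    """Multi-level synthesis, starting from the coarsest detail band."""
    for d in details:
        approx = haar_synthesis(approx, d)
    return approx


def packet_level(x, level):
    """Forward wavelet packet transform: split every band, in natural order."""
    nodes = [x]
    for _ in range(level):
        nodes = [band for node in nodes for band in haar_analysis(node)]
    return nodes


def detail_energy(signal):
    """A toy differentiable loss: total energy of all detail coefficients."""
    _, ds = wavedec(signal, 3)
    return sum(jnp.sum(d ** 2) for d in ds)


x = jnp.linspace(0.0, 1.0, 16) ** 2
a, ds = wavedec(x, 3)
x_rec = waverec(a, ds)                # perfect reconstruction (up to float error)
grad_x = jax.grad(detail_energy)(x)   # gradient w.r.t. the input signal
```

Because every step is built from slicing, stacking and arithmetic on `jax.numpy` arrays, the same code also works under `jax.jit` and `jax.vmap`.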
As the chirp's frequency increases, the wavelet coefficients in the finer, higher-frequency scales grow as well.
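The effect in the plot can be checked numerically with a small sketch. The sampling rate `FS`, the start and end frequencies `F0` and `F1`, and the choice of three levels are assumptions for illustration, and the Haar filter again stands in for the Daubechies wavelet. The code splits each detail band into four time segments and prints the coefficient energy per segment; in the finest band (level 1), the energy should grow from the first to the last segment as the chirp sweeps upward.

```python
import jax.numpy as jnp

# Haar filter scale; a stand-in for the Daubechies wavelet used in the plot.
S = 1.0 / jnp.sqrt(2.0)


def haar_analysis(x):
    """One level of the Haar fast wavelet transform (even-length input)."""
    even, odd = x[0::2], x[1::2]
    return (even + odd) * S, (even - odd) * S


# Linear chirp: instantaneous frequency rises linearly from F0 to F1 Hz over one second.
FS, F0, F1 = 1024, 1.0, 400.0
t = jnp.arange(FS) / FS
chirp = jnp.sin(2 * jnp.pi * (F0 * t + 0.5 * (F1 - F0) * t ** 2))

# Three analysis levels; details[0] is the finest (highest-frequency) band.
approx, details = chirp, []
for _ in range(3):
    approx, d = haar_analysis(approx)
    details.append(d)

# Split each band into four equal time segments and sum the squared coefficients.
for level, d in enumerate(details, start=1):
    energy = jnp.sum(d.reshape(4, -1) ** 2, axis=1)
    print(f"level {level}:", energy)
```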
Source code is available at https://github.com/v0lta/jaxlets .