
jaxampler's Introduction

Hey! I'm @MeesumQ (Meesum Qazalbash) 👋

I recently graduated (June 2024 🎓) from Habib University, Pakistan 🇵🇰, with a major in Computer Science and a minor in Mathematics.

I'm passionate about numerical/accelerated computing ⚡, machine/deep learning 🤖, and physics-informed machine learning 🌩️. During my undergrad, I explored various projects (available on my GitHub).

Currently, I'm working on Bayesian inference for binary black holes with members of the LIGO community 🚀, developing GWKokab 📦, a device-agnostic Python package built with JAX.

Have an idea? Let's chat! 💭

jaxampler's People

Contributors

mahausmani, qazalbash, zeeshan5885

Forkers

mahausmani

jaxampler's Issues

JAX operations are for `ArrayLike` objects but we are passing `Numeric` typed objects

Description

Numeric and ArrayLike are union types defined as,

Numeric = Union[Array, np.ndarray, np.bool_, np.number, bool, int, float]
ArrayLike = Union[
  Array,  # JAX array type
  np.ndarray,  # NumPy array type
  np.bool_, np.number,  # NumPy scalar types
  bool, int, float, complex,  # Python scalar types
]

The only difference is that Numeric lacks complex; otherwise the two types are identical. We need a way to allow Numeric-typed objects in JAX operations.

/media/gradf/Academic/project/jaxampler/jaxampler/_rvs/pareto.py:62:68 - error: Operator "-" not supported for type "Numeric" when expected type is "ArrayLike" (reportGeneralTypeIssues)
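One possible workaround, sketched below, is to convert every Numeric input to a jax.Array with jnp.asarray at the function boundary, so that operators only ever see ArrayLike operands. The helper names as_array and minus are hypothetical, and the float32 cast is an assumption about the desired dtype.

```python
from typing import Union

import jax.numpy as jnp
import numpy as np
from jax import Array

# Real-valued union, mirroring the Numeric definition above (no complex).
Numeric = Union[Array, np.ndarray, np.bool_, np.number, bool, int, float]


def as_array(x: Numeric) -> Array:
    """Hypothetical helper: convert any Numeric value to a float32 jax.Array."""
    return jnp.asarray(x, dtype=jnp.float32)


def minus(x: Numeric, y: Numeric) -> Array:
    # Both operands are jax.Array after conversion, so "-" acts on ArrayLike values.
    return as_array(x) - as_array(y)
```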

What jaxampler version are you using?

nightly

Which accelerator(s) are you using?

CPU

Additional system info?

Linux

NVIDIA GPU info

No response

Accept Reject sampler not working for multivariate pdfs

For this piece of code,

from jaxampler.sampler import AcceptRejectSampler
from jaxtro.models import Wysocki2019MassModel
from matplotlib import pyplot as plt

model = Wysocki2019MassModel(alpha=0.8, k=0, mmin=5.0, mmax=40.0, Mmax=80.0, name="Wysocki2019MassModel")
sampler = AcceptRejectSampler()

samples = sampler.sample(target_rv=model, proposal_rv=model, scale=1.05, N=1000)

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(samples[:, 0], samples[:, 1], samples[:, 2])
plt.show()

we get this error,

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/gradf/Desktop/test/test_sampler.py", line 8, in <module>
    samples = sampler.sample(target_rv=model, proposal_rv=model, scale=1.05, N=1000)
  File "/home/gradf/.local/lib/python3.10/site-packages/jaxampler/sampler/arsampler.py", line 43, in sample
    pdf_ratio = target_rv.pdf(V)
  File "/home/gradf/.local/lib/python3.10/site-packages/jaxampler/rvs/crvs/crvs.py", line 38, in pdf
    return jnp.exp(self.logpdf(*x))
TypeError: Wysocki2019MassModel.logpdf() missing 1 required positional argument: 'm2'
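In the traceback, pdf forwards its arguments as `*x`, so target_rv.pdf(V) with a single (N, d) array appears to reach logpdf as one argument, hence the missing m2. A minimal sketch, assuming the target logpdf takes one argument per dimension, is to unpack the proposal columns with `*V.T`. The triangular target, the uniform box proposal, and the function names are illustrative stand-ins, not jaxampler's or jaxtro's API.

```python
import jax
import jax.numpy as jnp


def target_logpdf(x1, x2):
    # Illustrative unnormalized 2-D target taking one argument per dimension,
    # like a logpdf(m1, m2) signature: uniform on the triangle 0 <= x2 <= x1 <= 1.
    inside = (x2 >= 0.0) & (x2 <= x1) & (x1 <= 1.0)
    return jnp.where(inside, 0.0, -jnp.inf)


def accept_reject(logpdf, low, high, log_scale, n, key):
    """Accept-reject with a uniform proposal on the box [low, high]^d."""
    low, high = jnp.asarray(low), jnp.asarray(high)
    key_v, key_u = jax.random.split(key)
    d = low.shape[0]
    V = jax.random.uniform(key_v, (n, d), minval=low, maxval=high)  # proposals, shape (n, d)
    log_q = -jnp.sum(jnp.log(high - low))  # log density of the uniform proposal
    # Unpack the columns so each dimension becomes one positional argument.
    log_ratio = logpdf(*V.T) - (log_scale + log_q)
    U = jax.random.uniform(key_u, (n,))
    accepted = jnp.log(U) < log_ratio
    return V[accepted]  # boolean indexing: runs eagerly, not inside jit
```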

Statistical measures like mean, median, and higher-order measures for different distributions

Description

Statistical measures are crucial for understanding a distribution and its behavior. For a distribution or random variable XYZ, we can define them as properties of the class.

Code

class XYZ(JObj):
    # __init__, xxf_x methods

    @property
    def mean(self):
        pass

    @property
    def std(self):
        pass

    @property
    def var(self):
        pass

    @property
    def median(self):
        pass

    @property
    def mode(self):
        pass

    @property
    def skewness(self):
        pass

    @property
    def kurtosis(self):
        pass

    @property
    def entropy(self):
        pass

Output

rv = XYZ(...)
print(f"mean of {rv} is {rv.mean}")
...
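As an illustration, these properties can often be filled in with closed-form expressions. The class below is a hypothetical stand-in for an exponential random variable (not jaxampler's Exponential); kurtosis is reported as excess kurtosis, which is a convention choice.

```python
import jax.numpy as jnp


class Exponential:
    """Hypothetical exponential RV: density exp(-(x - loc) / scale) / scale for x >= loc."""

    def __init__(self, loc=0.0, scale=1.0) -> None:
        self._loc = loc
        self._scale = scale

    def __repr__(self) -> str:
        return f"Exponential(loc={self._loc}, scale={self._scale})"

    @property
    def mean(self):
        return self._loc + self._scale

    @property
    def var(self):
        return self._scale**2

    @property
    def std(self):
        return jnp.sqrt(self.var)

    @property
    def median(self):
        return self._loc + self._scale * jnp.log(2.0)

    @property
    def mode(self):
        return self._loc

    @property
    def skewness(self):
        return 2.0

    @property
    def kurtosis(self):
        return 6.0  # excess kurtosis (kurtosis minus 3)

    @property
    def entropy(self):
        return 1.0 + jnp.log(self._scale)  # differential entropy in nats
```

With rv = Exponential(loc=1.0, scale=2.0), the print statement above outputs `mean of Exponential(loc=1.0, scale=2.0) is 3.0`.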

Incorrect implementation of `rvs.Triangular.logcdf_x`

Description

The logcdf_x method of the triangular distribution computes the CDF incorrectly when x == self._mode:

@partial(jit, static_argnums=(0,))
def logcdf_x(self, x: Numeric) -> Numeric:
    conditions = [
        x < self._low,
        (self._low <= x) & (x < self._mode),
        x == self._mode,
        (self._mode < x) & (x < self._high),
        x >= self._high,
    ]
    choices = [
        -jnp.inf,
        2 * jnp.log(x - self._low) - jnp.log(self._high - self._low) - jnp.log(self._mode - self._low),
        jnp.log(0.5),
        jnp.log(1 - ((self._high - x) ** 2 / ((self._high - self._low) * (self._high - self._mode)))),
        jnp.log(1),
    ]
    return jnp.select(conditions, choices)

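At x == self._mode the CDF equals (mode - low)/(high - low), which is 0.5 only for a symmetric triangle. The third choice should instead be the same expression as the second branch: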

2 * jnp.log(x - self._low) - jnp.log(self._high - self._low) - jnp.log(self._mode - self._low),
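A corrected sketch as a standalone function (the name triangular_logcdf is hypothetical) folds x == mode into the second branch by using <=, so no special case is needed:

```python
import jax.numpy as jnp
from jax import jit


@jit
def triangular_logcdf(x, low, mode, high):
    x = jnp.asarray(x, dtype=jnp.float32)
    conditions = [
        x < low,
        (low <= x) & (x <= mode),  # includes x == mode: CDF(mode) = (mode - low) / (high - low)
        (mode < x) & (x < high),
        x >= high,
    ]
    choices = [
        -jnp.inf,
        2 * jnp.log(x - low) - jnp.log(high - low) - jnp.log(mode - low),
        jnp.log1p(-((high - x) ** 2 / ((high - low) * (high - mode)))),
        0.0,  # log(1)
    ]
    # The conditions are mutually exclusive; unselected branches may hold NaN but are discarded.
    return jnp.select(conditions, choices)
```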

What jaxampler version are you using?

nightly

Which accelerator(s) are you using?

CPU

Additional system info?

Linux

NVIDIA GPU info

No response

Algorithm Classification and User Interaction Improvement

Description

This project encompasses a variety of algorithms, many of which belong to broader categories such as Monte Carlo algorithms. These algorithms share numerous aspects, including concepts and terminologies. Presently, the project's components are somewhat dispersed, necessitating a concerted effort to organize them into a cohesive structure. This particular issue emphasizes the user interaction aspect – how users engage with the system. It necessitates gathering user feedback, encompassing their preferences and criticisms. Your input and suggestions, as a user, would be greatly appreciated and valuable to this endeavor.

Add `verbose`

Description

Many frameworks provide a verbose mode to show what is going on. We should add an optional keyword argument verbose=False to specific functions.

Code

>>> from functools import partial
>>> from typing import Optional
>>>
>>> from jax import jit
>>> from jax.scipy.stats import norm
>>> from jaxampler.rvs import Normal, ContinuousRV
>>> from jaxampler.typing import Numeric
>>> from jaxampler.sampler import MetropolisHastingSampler
>>>
>>>
>>> class TwoPeakNormal(ContinuousRV):
...     def __init__(self, name: Optional[str] = None) -> None:
...         super().__init__(name)
...
...     @partial(jit, static_argnums=(0,))
...     def pdf_x(self, x: Numeric) -> Numeric:
...         return 0.5 * (norm.pdf(x, loc=-2.0, scale=1.0) + norm.pdf(x, loc=2.0, scale=1.0))
>>>
>>>
>>> sampler = MetropolisHastingSampler(name="forTwoPeakNormal")
>>> p = TwoPeakNormal(name="TwoPeakNormal")
>>> q = lambda x: Normal(mu=x, sigma=0.4, name="Normal")
>>>
>>> samples = sampler.sample(
...     p=p,
...     q=q,
...     N=1000,
...     burn_in=1000,
...     n_chains=3,
...     key=None,
...     hasting_ratio=True,
...     x0=q(0.0).rvs(shape=(3,)),
...     verbose=True,
... )

Output

Burn-in        : 100%|#####################################################################################################################| 1.00k/1.00k [00:00<00:00, 2.69ksamples/s]
chain      0   :  19%|######################6                                                                                                 | 189/1.00k [00:08<00:32, 24.8samples/s]
chain      1   :  19%|######################6                                                                                                 | 189/1.00k [00:08<00:36, 22.2samples/s]
chain      2   :  19%|#######################                                                                                                 | 192/1.00k [00:08<00:39, 20.4samples/s]
Total          :  19%|######################8                                                                                                 | 571/3.00k [00:08<00:36, 66.1samples/s]
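A minimal sketch of how a verbose flag could be threaded through a sampler, assuming a plain Python loop around a jitted random-walk Metropolis step that prints a progress line every report_every samples instead of drawing a progress bar. The function names and report_every are hypothetical; because the Gaussian proposal is symmetric, the Hastings correction cancels.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def two_peak_logpdf(x):
    # log of 0.5 * (N(x; -2, 1) + N(x; 2, 1)), the TwoPeakNormal density above
    return jnp.log(0.5 * (norm.pdf(x, loc=-2.0, scale=1.0) + norm.pdf(x, loc=2.0, scale=1.0)))


@jax.jit
def mh_step(key, x, step_size):
    # One random-walk Metropolis step with a symmetric Gaussian proposal.
    key_prop, key_acc = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_prop)
    log_alpha = two_peak_logpdf(proposal) - two_peak_logpdf(x)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_alpha
    return jnp.where(accept, proposal, x)


def sample(n_samples, key, x0=0.0, step_size=0.4, verbose=False, report_every=250):
    x = jnp.asarray(x0, dtype=jnp.float32)
    samples = []
    for i in range(n_samples):
        key, subkey = jax.random.split(key)
        x = mh_step(subkey, x, step_size)
        samples.append(x)
        if verbose and (i + 1) % report_every == 0:
            print(f"sample {i + 1}/{n_samples}")
    return jnp.stack(samples)
```

Calling sample(1000, jax.random.PRNGKey(0), verbose=True) prints a progress line after every 250 samples; with verbose=False it stays silent.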

broadcast numerical types to `jnp.array` for flexible shapes

Description

rvs = Beta(alpha=1.1, beta=5.5).rvs(10) # fine
rvs = Beta(alpha=jnp.array([1.1, 1.0]), beta=jnp.array([5.5, 3.0])).rvs(10) # gives error.

error is,

ValueError: beta parameter shapes must be broadcast-compatible with shape argument, and the result of broadcasting the shapes must equal the shape argument, but got result (10, 2) for shape argument (10, 1).

Proposal

It would be great to write a generic function that converts every parameter to an array and broadcasts them to a common shape, and to define it in a shared module such as utils.
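A minimal sketch of such a utility, assuming parameters should be broadcast to a common batch shape that is appended after the sample dimension. broadcast_params and beta_rvs are hypothetical names; jax.random.beta does the sampling.

```python
import jax
import jax.numpy as jnp


def broadcast_params(*params):
    """Hypothetical utils helper: convert every parameter to an array and broadcast to a common shape."""
    arrays = [jnp.asarray(p, dtype=jnp.float32) for p in params]
    return jnp.broadcast_arrays(*arrays)


def beta_rvs(alpha, beta, n, key):
    alpha, beta = broadcast_params(alpha, beta)
    shape = (n,) + alpha.shape  # sample dimension first, then the batch shape
    return jax.random.beta(key, alpha, beta, shape=shape)
```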

Multivariate realizations result

Description

While investigating #9, I came across another question: how we return random samples. For 10 random realizations from the Weibull distribution we get,

[3.4688995  0.854046   3.4074368  1.6077589  0.23251958 0.39575407
 3.4916189  7.738534   0.8578036  0.7367411 ]

This raises the question of how we should return results for multivariate pdfs. There are two options,

Column stack

[[35.49771  34.423687]
 [34.608665 33.28079 ]
 [36.43148  31.096733]
 [24.897526 17.414808]
 [39.04919  37.102623]
 [38.8019   37.70799 ]
 [36.393475 13.284719]
 [35.584866 10.177847]
 [39.162357 38.78801 ]
 [29.09693  19.118551]]

Row stack

[[31.55647  23.775223 26.185276 34.463047 32.427452 23.445316 38.49894
  21.07456  23.254461 29.685108]
 [30.232841 23.443392 22.841656 25.321205 28.777534 14.195529 36.87625
  13.814087 11.815313 23.798061]]
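The two layouts differ only by a transpose. The sketch below builds both from two hypothetical marginal sample vectors. Column stacking, shape (N, d), puts the sample axis first, which is also how jax.random.multivariate_normal lays out its output.

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
n = 10
# Two hypothetical marginal sample vectors, each of shape (n,).
x1 = jax.random.uniform(key_a, (n,), minval=10.0, maxval=40.0)
x2 = jax.random.uniform(key_b, (n,), minval=10.0, maxval=40.0)

column_stacked = jnp.column_stack([x1, x2])  # shape (n, 2): one row per realization
row_stacked = jnp.vstack([x1, x2])  # shape (2, n): one row per dimension
```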

Parameter shapes do not influence the `rvs` of a random variable (distribution)

Description

0-dim parameters

It works for the following code.

from jaxampler.rvs import Uniform

U = Uniform(low=0.0, high=1.0)

samples = U.rvs(shape=(5, 5), key=None)
print(samples)

n-dim parameters (n > 0)

It does not work for the following code.

from jaxampler.rvs import Uniform

U = Uniform(low=0.0, high=[1.0, 2.0])

samples = U.rvs(shape=(5, 5), key=None)
print(samples)
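One way to make parameter shapes count, sketched here under the assumption that the parameter (batch) shape is appended to the requested sample shape, mirroring `new_shape = shape + self._shape` in Exponential.rvs. uniform_rvs is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def uniform_rvs(low, high, shape, key):
    """Hypothetical rvs: the parameter (batch) shape is appended to the requested sample shape."""
    low, high = jnp.broadcast_arrays(
        jnp.asarray(low, dtype=jnp.float32), jnp.asarray(high, dtype=jnp.float32)
    )
    full_shape = tuple(shape) + low.shape
    # minval/maxval broadcast against full_shape, one column per parameter entry.
    return jax.random.uniform(key, full_shape, minval=low, maxval=high)
```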

What jaxampler version are you using?

nightly

Which accelerator(s) are you using?

CPU

Additional system info?

Linux

NVIDIA GPU info

None

Point-valued and vectorized numerical functions

Description

This issue extends #15 and #18: we need different functions depending on the type of input and how the workload is distributed.

Suppose we have a function that computes the logpdf of a distribution. I propose two separate functions:

  1. logpdf_x takes a single value and returns a single value.
  2. logpdf_v takes a vector and returns a vector, using vmap.

The same can be done for other functions such as pdf, cdf, logcdf, logppf, and ppf. This will give users a lot of flexibility.
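A minimal sketch of the proposed convention on a hypothetical NormalRV class. jax.scipy.stats.norm.logpdf already broadcasts, so vmap matters most when logpdf_x is written strictly for scalars (for example with jnp.select branches).

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


class NormalRV:
    """Hypothetical normal RV illustrating the proposed _x / _v naming convention."""

    def __init__(self, mu: float, sigma: float) -> None:
        self._mu = mu
        self._sigma = sigma

    def logpdf_x(self, x):
        # Point-valued: a single value in, a single value out.
        return norm.logpdf(x, loc=self._mu, scale=self._sigma)

    def logpdf_v(self, xs):
        # Vectorized: map logpdf_x over the leading axis of xs with vmap.
        return jax.vmap(self.logpdf_x)(xs)
```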

Generalized Monte Carlo Integration

Description

#23 implements a box integral, which means it can only integrate over hypercubes (in the more general, higher-dimensional cases). The real power of Monte Carlo integration, however, lies in sampling from a distribution whose shape resembles the integrand. To meet this requirement, a new and more general Monte Carlo integration method is needed.
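A minimal importance-sampling sketch: $\int f(x)\,dx = \mathbb{E}_{x \sim q}[f(x)/q(x)]$ is estimated by averaging $f(x)/q(x)$ over samples from $q$, and the estimator's variance shrinks as $q$ resembles the integrand. Here $f(x) = e^{-x^2/2}$, whose exact integral over the real line is $\sqrt{2\pi}$, and $q$ is a normal density with $\sigma = 1.5$; the function name and these choices are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def importance_sampling_integral(f, sample_q, logpdf_q, n, key):
    """Estimate the integral of f as the sample mean of f(x) / q(x) with x ~ q."""
    x = sample_q(key, n)
    weights = f(x) / jnp.exp(logpdf_q(x))
    return jnp.mean(weights)


sigma = 1.5
estimate = importance_sampling_integral(
    f=lambda x: jnp.exp(-0.5 * x**2),
    sample_q=lambda key, n: sigma * jax.random.normal(key, (n,)),
    logpdf_q=lambda x: norm.logpdf(x, loc=0.0, scale=sigma),
    n=100_000,
    key=jax.random.PRNGKey(0),
)
```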

Lack of standard nomenclature among similar classes

Description

At the time this issue was created, there were three major classes in Jaxampler: Integration, GenericRV, and Sampler. GenericRV follows a strict nomenclature for function signatures, overriding, and overloading. Its functions are classified as point-valued or vectorized and are identified by the suffixes _x and _v, respectively. This convention is not enforced in the Sampler class. For example, in Sampler:

_ = AdaptiveAcceptRejectSampler().sample(target_rv, proposal_rv, scale, N, key)
_ = MetropolisHastingSampler().sample(p, q, burn_in, n_chains, x0, N, key, hasting_ratio)

The problem arises because different algorithms require different parameters, especially the extra, algorithm-specific ones. My initial suggestions are:

  • Adopt a notation discussed in a book and continue with that.
  • Community members respond to this issue and share their ideas.

Optimize using `lax`

Many functions are not yet optimized with jax.lax. It should be used wherever possible, for example to replace Python loops with lax.fori_loop or lax.scan.
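For example, a Python loop inside a jitted function is unrolled at trace time, whereas lax.fori_loop stays a single loop primitive. The affine recurrence x <- 0.5 x + 1 below is an illustrative workload, not a jaxampler function.

```python
import jax
import jax.numpy as jnp
from jax import lax


def recurrence_python(x0, n_steps):
    # Python loop: under jit this would be unrolled into n_steps copies of the body.
    x = jnp.asarray(x0, dtype=jnp.float32)
    for _ in range(n_steps):
        x = 0.5 * x + 1.0
    return x


@jax.jit
def recurrence_lax(x0, n_steps):
    # lax.fori_loop compiles to one loop primitive, even when n_steps is traced.
    x0 = jnp.asarray(x0, dtype=jnp.float32)
    return lax.fori_loop(0, n_steps, lambda i, x: 0.5 * x + 1.0, x0)
```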

Integer arguments causing `OverflowError: cannot convert float infinity to integer`

Description

For a given test,

import jax.numpy as jnp

from jaxampler.rvs import Weibull


class TestWeibull:
    def test_negative_x(self):
        assert jnp.allclose(Weibull(lmbda=1, k=1).pdf_x(-1), 0)

I get the following error: OverflowError: cannot convert float infinity to integer. It is caused by passing integer arguments to the pdf_x method.

================================================================================ test session starts =================================================================================
platform linux -- Python 3.11.7, pytest-7.4.4, pluggy-1.3.0
rootdir: /media/gradf/Academic/project/jaxampler
plugins: jaxtyping-0.2.25, typeguard-2.13.3
collected 1 item                                                                                                                                                                     

tests/weibull_test.py F                                                                                                                                                        [100%]

====================================================================================== FAILURES ======================================================================================
____________________________________________________________________________ TestWeibull.test_negative_x _____________________________________________________________________________

self = <weibull_test.TestWeibull object at 0x7f4b5eea6e90>

    def test_negative_x(self):
>       assert jnp.allclose(Weibull(lmbda=1, k=1).pdf_x(-1), 0)

tests/weibull_test.py:33: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = Weibull(lambda=1, k=1, name=), x = (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>,)

    @partial(jit, static_argnums=(0,))
    def pdf_x(self, *x: Numeric) -> Numeric:
>       return jnp.exp(self.logpdf_x(*x))
E       OverflowError: cannot convert float infinity to integer
E       --------------------
E       For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

jaxampler/_src/rvs/crvs.py:40: OverflowError
============================================================================== short test summary info ===============================================================================
FAILED tests/weibull_test.py::TestWeibull::test_negative_x - OverflowError: cannot convert float infinity to integer
================================================================================= 1 failed in 0.52s ==================================================================================
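A minimal sketch of one possible fix, assuming the Weibull density $f(x) = \frac{k}{\lambda}(x/\lambda)^{k-1}e^{-(x/\lambda)^k}$ for $x \ge 0$: promote every argument to floating point before computing the log-density, so that -inf is representable even when all inputs are integers. xlogy keeps the $(k-1)\log(x/\lambda)$ term finite at x = 0 when k = 1. weibull_logpdf and weibull_pdf are hypothetical names.

```python
import jax.numpy as jnp
from jax import jit
from jax.scipy.special import xlogy


@jit
def weibull_logpdf(x, lmbda, k):
    # Promote every argument to floating point first, so -inf can be represented.
    x, lmbda, k = (jnp.asarray(v, dtype=jnp.float32) for v in (x, lmbda, k))
    z = x / lmbda
    logpdf = jnp.log(k / lmbda) + xlogy(k - 1.0, z) - z**k
    return jnp.where(x >= 0, logpdf, -jnp.inf)


def weibull_pdf(x, lmbda, k):
    return jnp.exp(weibull_logpdf(x, lmbda, k))
```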

What jaxampler version are you using?

nightly

Which accelerator(s) are you using?

CPU

Additional system info?

No response

NVIDIA GPU info

No response

Incompatible types in static type analysis using `pyright`

Description

There are many incompatible type annotations that need attention. Since we are unlikely to use complex numbers, we should define our Numeric type as a union of real-number types only.

What jaxampler version are you using?

nightly

Which accelerator(s) are you using?

CPU

Additional system info?

Linux

NVIDIA GPU info

None

Technical Diagram

Description

Make technical diagrams of the project to assist learning and understanding.

`rvs.Binomial.cdf_x` is not working as expected for shaped `n`

Description

Calling cdf_x for the binomial distribution gives the following error:

E     jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
E     It arose in the jnp.arange argument 'stop'
E     The error occurred while tracing the function cdf_x at /home/maha/Desktop/semester8/KAVISHII/jaxampler/../jaxampler/jaxampler/_src/rvs/binomial.py:67 for jit. This value became a tracer due to JAX operations on these lines:
E     
E       operation a:i32[] = add b c
E         from line /home/maha/Desktop/semester8/KAVISHII/jaxampler/../jaxampler/jaxampler/_src/rvs/binomial.py:69 (cdf_x)
E     
E     See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:2456: ConcretizationTypeError
=============================================================================== short test summary info ================================================================================
FAILED tests/drvs_test.py::TestBinomial::test_cdf_x - jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
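The error comes from jnp.arange needing a concrete stop value inside jit. A sketch that avoids arange entirely uses the identity $P(X \le k) = I_{1-p}(n-k,\,k+1)$ for $0 \le k < n$, where $I$ is the regularized incomplete beta function (jax.scipy.special.betainc). This works for shaped n under jit; binomial_cdf is a hypothetical name.

```python
import jax.numpy as jnp
from jax import jit
from jax.scipy.special import betainc


@jit
def binomial_cdf(k, n, p):
    k = jnp.floor(jnp.asarray(k, dtype=jnp.float32))
    n = jnp.asarray(n, dtype=jnp.float32)
    p = jnp.asarray(p, dtype=jnp.float32)
    # betainc needs positive shape parameters; clamp so the unused branches stay finite.
    a = jnp.maximum(n - k, 1.0)
    b = jnp.maximum(k + 1.0, 1.0)
    cdf = betainc(a, b, 1.0 - p)
    return jnp.where(k < 0, 0.0, jnp.where(k >= n, 1.0, cdf))
```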

What jaxampler version are you using?

nightly

Which accelerator(s) are you using?

CPU

Additional system info?

Linux

NVIDIA GPU info

No response

Parallelize over multiple devices

Description

Jaxampler is built and tested on CPU. Although many of its functions also work on GPU, they need to be specifically accelerated for GPUs and parallelized across multiple devices.
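As a starting point for multi-device execution, the sketch below runs one independent computation per local device with jax.pmap, splitting the PRNG key so each device gets its own stream. chain_means is a placeholder for a real per-device workload; JAX's sharding APIs with jax.jit are an alternative to pmap.

```python
import jax
import jax.numpy as jnp


def chain_means(key, n_samples):
    """Placeholder workload: draw n_samples standard normal values on one device and return their mean."""
    return jnp.mean(jax.random.normal(key, (n_samples,)))


n_devices = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(0), n_devices)  # one key per device
# pmap maps the leading axis of keys onto devices; n_samples is a static argument.
parallel_means = jax.pmap(chain_means, static_broadcasted_argnums=(1,))(keys, 1000)
```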

Use pre-commit hooks for CI

Description

For smooth CI and clean code, use pre-commit hooks.

What jaxampler version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info?

No response

NVIDIA GPU info

No response

Proposal to vectorize numerical function using `vmap`

Description

As functions grow more complex and accept higher-dimensional inputs, it becomes hard to evaluate single values, or more specifically shaped arrays, against shaped parameter arrays. jnp.broadcast_arrays is not very promising here. Instead, we can use the power of vmap in JAX. An example is given below.

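# Method of a normal random variable class (imports of partial, jit, vmap, ArrayLike and jax_norm omitted).
@partial(jit, static_argnums=(0,))
def pdf(self, x: ArrayLike) -> ArrayLike:
    return vmap(lambda x_: jax_norm.pdf(x_, self._mu, self._sigma))(x)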

Incompatible method overriding

Description

Many classes override methods incompatibly because the overriding methods take a different number of arguments than the base methods.

jaxampler/jaxampler/_sampler/mhsampler.py:31:9 - error: Method "sample" overrides class "Sampler" in an incompatible manner
    Positional parameter count mismatch; base method has 3, but override has 11 (reportIncompatibleMethodOverride)

The code is working, but the error arises in static type checking by pyright.
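One possible remedy, sketched under the assumption that algorithm-specific options can move to the constructor: every subclass then overrides sample(N, key) with the same signature as the base class. Sampler is modeled on the issue, while ExponentialSampler (an inverse-CDF sampler) is a hypothetical example.

```python
from abc import ABC, abstractmethod
from typing import Optional

import jax
import jax.numpy as jnp
from jax import Array


class Sampler(ABC):
    # Shared signature: only the number of samples and the PRNG key.
    @abstractmethod
    def sample(self, N: int, key: Optional[Array] = None) -> Array: ...


class ExponentialSampler(Sampler):
    """Hypothetical inverse-CDF sampler; its algorithm-specific option (scale) is set in __init__."""

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def sample(self, N: int, key: Optional[Array] = None) -> Array:
        # Same parameters as Sampler.sample, so the override is compatible.
        if key is None:
            key = jax.random.PRNGKey(0)
        U = jax.random.uniform(key, (N,))
        return -self.scale * jnp.log1p(-U)  # inverse CDF of an exponential with this scale
```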

What jaxampler version are you using?

nightly

Which accelerator(s) are you using?

CPU

Additional system info?

Linux

NVIDIA GPU info

No response

Numerical functions for point values

Description

In aeb7c65, vmap was used to make the numerical functions more flexible, but this removed the ability to pass single values.

Suggestion

It would be good to have a function that decides which variant (vectorized or non-vectorized) to call based on the provided input. If this is not possible, then simply provide separate functions.
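A minimal sketch of such a dispatcher, assuming dispatch on the number of dimensions: ndim is known at trace time, so the branch is resolved in Python, and nested vmap handles inputs of any rank. logpdf_x and logpdf are hypothetical stand-ins using a normal density.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def logpdf_x(x, mu=0.0, sigma=1.0):
    # Point-valued implementation written for a single scalar.
    return norm.logpdf(x, loc=mu, scale=sigma)


def logpdf(x, mu=0.0, sigma=1.0):
    """Hypothetical dispatcher: call logpdf_x for scalars, vmap it over the leading axis otherwise."""
    x = jnp.asarray(x)
    if x.ndim == 0:  # ndim is static, so this branch is resolved at trace time
        return logpdf_x(x, mu, sigma)
    return jax.vmap(lambda xi: logpdf(xi, mu, sigma))(x)
```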

Broadcasting arrays making objects heavy

Description

Broadcasting the parameters of random variables makes the objects very heavy, and this often shows up when sampling from them. Instead, we should only check that the parameters can be broadcast together, keeping their shapes as they are.
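A minimal sketch of the check-only approach: jnp.broadcast_shapes validates compatibility (raising an error otherwise) and returns the batch shape without allocating broadcast copies. LightNormal is a hypothetical class.

```python
import jax.numpy as jnp


class LightNormal:
    """Hypothetical RV that validates parameter shapes without materializing broadcasts."""

    def __init__(self, mu, sigma) -> None:
        self._mu = jnp.asarray(mu)
        self._sigma = jnp.asarray(sigma)
        # Fails for incompatible shapes; otherwise stores only the resulting batch shape.
        self._shape = jnp.broadcast_shapes(self._mu.shape, self._sigma.shape)
```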

`rvs` function contains unnecessary details which can be abstracted away!

Description

The rvs(...) method in GenericRV and its subclasses contains details that can be hidden in the GenericRV class. For example, jaxampler.rvs.Exponential.rvs(...) is implemented as,

def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
    if key is None:
        key = self.get_key()
    new_shape = shape + self._shape
    U = jax.random.uniform(key, shape=new_shape)
    rvs_val = self._loc - self._scale * jnp.log(U)
    return rvs_val

The first three lines are repeated in every GenericRV.rvs method, and it is unpleasant for users to handle the key and shape each time. Instead, they should implement a method like,

def _rvs(self, shape: tuple[int, ...], key: Array) -> Array: ...

And this method should be called inside the GenericRV.rvs method, like,

def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
    if key is None:
        key = self.get_key()
    new_shape = shape + self._shape
    return self._rvs(shape=new_shape, key=key)

This design makes it easier for users to implement new random variables.
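The proposed split can be sketched as a template method: GenericRV.rvs handles the key and shape, and a subclass only implements _rvs. The seed-based get_key below is a hypothetical stand-in for jaxampler's key management.

```python
from typing import Optional

import jax
import jax.numpy as jnp
from jax import Array


class GenericRV:
    """Base class: handles the key and shape once, then delegates to _rvs."""

    def __init__(self, shape: tuple[int, ...] = (), seed: int = 0) -> None:
        self._shape = shape
        self._key = jax.random.PRNGKey(seed)

    def get_key(self) -> Array:
        # Hypothetical key management: split off a fresh subkey on every call.
        self._key, subkey = jax.random.split(self._key)
        return subkey

    def rvs(self, shape: tuple[int, ...], key: Optional[Array] = None) -> Array:
        if key is None:
            key = self.get_key()
        new_shape = shape + self._shape
        return self._rvs(shape=new_shape, key=key)

    def _rvs(self, shape: tuple[int, ...], key: Array) -> Array:
        raise NotImplementedError


class Exponential(GenericRV):
    def __init__(self, loc=0.0, scale=1.0) -> None:
        super().__init__(shape=jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale)))
        self._loc = loc
        self._scale = scale

    def _rvs(self, shape: tuple[int, ...], key: Array) -> Array:
        # Only the distribution-specific transform remains in the subclass.
        U = jax.random.uniform(key, shape=shape)
        return self._loc - self._scale * jnp.log(U)
```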

Migrate from `setup.py` to standard python package builder.

Description

While building the package with the command,

python3 setup.py sdist bdist_wheel

the following warning appears:

!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!

We should act on this warning and move to a standards-based build process (for example, pypa/build via `python -m build`). This article will be very helpful in the process.

To generate a model by using various distributions

Description

If we have $X_1, X_2, X_3, \cdots, X_n$, we should be able to combine them into a new model, such as $Y = f(X_1, X_2, X_3, \cdots, X_n)$, where $f$ is any computable function. The generated random numbers would be i.i.d.
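A minimal sketch, assuming the $X_i$ are independent and each can be sampled on its own: draw n samples from every $X_i$ with separate keys and apply f elementwise, which yields n i.i.d. samples of Y. The names and the choice $Y = X_1 + X_2$ with $X_1 \sim N(0, 1)$ and $X_2 \sim \mathrm{Exp}(1)$ are illustrative.

```python
import jax
import jax.numpy as jnp


def sample_composite(f, samplers, n, key):
    """Draw n i.i.d. samples of Y = f(X_1, ..., X_k) by sampling each X_i independently."""
    keys = jax.random.split(key, len(samplers))
    xs = [sampler(k, n) for sampler, k in zip(samplers, keys)]
    return f(*xs)


normal = lambda key, n: jax.random.normal(key, (n,))  # X_1 ~ N(0, 1)
exponential = lambda key, n: jax.random.exponential(key, (n,))  # X_2 ~ Exp(1)

y = sample_composite(lambda x1, x2: x1 + x2, [normal, exponential], 100_000, jax.random.PRNGKey(0))
```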

Unit tests of distributions

Description

Write unit tests for random variables using pytest.

  • bernoulli
  • beta
  • binomial
  • boltzmann
  • cauchy
  • chi2
  • exponential
  • gamma
  • geometric
  • logistic
  • lognormal
  • normal
  • pareto
  • poisson
  • rayleigh
  • studentt
  • triangular
  • truncnormal
  • truncpowerlaw
  • uniform
  • weibull

Unmarked rvs are not tested.
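A sketch of the pytest style such tests could follow, comparing against jax.scipy.stats as a reference implementation. exponential_logpdf is a hypothetical stand-in for the distribution under test.

```python
import jax.numpy as jnp
from jax.scipy.stats import expon


def exponential_logpdf(x, scale=1.0):
    """Stand-in for the distribution under test (hypothetical, not jaxampler's implementation)."""
    x = jnp.asarray(x, dtype=jnp.float32)
    return jnp.where(x >= 0, -jnp.log(scale) - x / scale, -jnp.inf)


class TestExponential:
    def test_logpdf_matches_reference(self):
        x = jnp.linspace(0.0, 5.0, 11)
        assert jnp.allclose(exponential_logpdf(x, scale=2.0), expon.logpdf(x, scale=2.0), atol=1e-6)

    def test_negative_x(self):
        assert jnp.all(jnp.isneginf(exponential_logpdf(jnp.array([-1.0, -0.5]))))

    def test_integer_arguments(self):
        # Integer inputs should not raise.
        assert jnp.isfinite(exponential_logpdf(1, scale=1))
```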
