Welcome to gWOT’s documentation!

Simulations

class gwot.sim.Simulation(V, dV, N, T, d, D, t_final, ic_func, pool, birth_death=False, birth=None, death=None)

Bases: gwot.ts.TimeSeries

Diffusion-drift SDE simulations using the Euler-Maruyama method.

Parameters
  • V – potential function \((x, t) \mapsto V(x, t)\)

  • dV – potential gradient \((x, t) \mapsto \nabla V(x, t)\)

  • N – number of initial particles to use, where \(N_i\) corresponds to timepoint \(t_i\)

  • T – number of timepoints at which to capture snapshots

  • d – dimension \(d\) of simulation

  • D – diffusivity \(D\)

  • t_final – final time \(t_\mathrm{final}\) (initial time is always 0)

  • ic_func – function accepting arguments (N, d) and returning an array X of dimensions (N, d) where X[i, :] corresponds to the `i`th initial particle position

  • pool – ProcessingPool to use for parallel computation (or None)

  • birth_death – whether to incorporate birth-death process

  • birth – if birth_death == True, a function accepting arguments (X, t) returning a vector of birth rates \(\beta\) for each row in X

  • death – if birth_death == True, a function accepting arguments (X, t) returning a vector of death rates \(\delta\) for each row in X

sample(steps_scale=1, trunc=None)
Sample time-series from Simulation. Simulates independently evolving particles using the Euler-Maruyama method.

Parameters
  • steps_scale – number of Euler-Maruyama steps to take between timepoints.

  • trunc – if provided, subsample all snapshots to have trunc particles.
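The Euler-Maruyama update that sample relies on can be sketched in plain NumPy. This is only an illustration: it assumes the convention \(dX_t = -\nabla V(X_t, t)\,dt + \sqrt{2D}\,dW_t\), which is consistent with the \(2D\Delta t_i\) value quoted for eps elsewhere on this page but is not confirmed here. The function name euler_maruyama and the quadratic potential are hypothetical and not part of gwot.

```python
import numpy as np

def euler_maruyama(dV, D, x0, t_final, n_steps, rng):
    """Integrate dX = -dV(X, t) dt + sqrt(2 D) dW (assumed convention)."""
    x = np.array(x0, dtype=float)
    dt = t_final / n_steps
    for k in range(n_steps):
        t = k * dt
        # Wiener increments: independent Gaussians with variance dt
        noise = np.sqrt(dt) * rng.normal(size=x.shape)
        x = x - dV(x, t) * dt + np.sqrt(2.0 * D) * noise
    return x

# Hypothetical example: V(x, t) = |x|^2 / 2, so the gradient is dV(x, t) = x.
rng = np.random.default_rng(0)
x0 = np.ones((500, 2))  # 500 particles in d = 2, all starting at (1, 1)
x_final = euler_maruyama(lambda x, t: x, D=0.5, x0=x0,
                         t_final=1.0, n_steps=100, rng=rng)
```

With this potential the particle mean should decay towards the origin while the noise spreads the cloud.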

sample_trajectory(steps_scale=1, N=1)

Sample trajectory from simulation

Parameters
  • steps_scale – number of Euler-Maruyama steps to take between timepoints.

  • N – number of trajectories to sample

Returns

np.array of dimensions

Core implementation

class gwot.models.OTModel(ts, lamda_reg, D=None, w=None, lamda=None, eps=None, eps_df=None, c_scale=None, c_scale_df=None, m=None, g=None, kappa=None, growth_constraint='exact', u_hat=None, v_hat=None, pi_0='stationary', device=None, use_keops=True)

Bases: torch.nn.modules.module.Module

Core gWOT model class for inference and manipulating outputs. Forms and solves the optimisation problem with general form \(\min_{\mathbf{R}_{t_1}, \ldots, \mathbf{R}_{t_T}} \lambda \mathrm{Reg}(\mathbf{R}_{t_1}, \ldots, \mathbf{R}_{t_T}) + \mathrm{Fit}(\mathbf{R}_{t_1}, \ldots, \mathbf{R}_{t_T})\), where \(\mathbf{R}_{t_1}, \ldots, \mathbf{R}_{t_T}\) are the reconstructed marginals at times \(t_1, \ldots, t_T\).

Parameters
  • ts – TimeSeries object containing input data.

  • lamda_reg – regularisation strength parameter \(\lambda\).

  • D – diffusivity \(D\).

  • w – torch.Tensor of weights for data-fitting term at each timepoint. If None, then we take \(w_i = N_i/\sum_j N_j\) where \(N_i\) is the number of particles at timepoint i.

  • lamda – torch.Tensor of weights \(\lambda_i\) controlling tradeoff of cross-entropy vs OT in data-fitting term at each timepoint.

  • eps – torch.Tensor of entropic regularisation parameters to use in the regularising functional OT term. If None, then we take the entries to be \(2D\Delta t_i\) (the theoretically correct value).

  • eps_df – torch.Tensor of entropic regularisation parameters \(\varepsilon_i\) to use in the OT component of the data-fitting functional.

  • c_scale – torch.Tensor of cost matrix scalings \(\overline{C}_i\) to use in the regularising functional. That is, the cost matrix for the pair of timepoints \((t_i, t_{i+1})\) will be \(C^{(i)}_{jk} = C_{jk}/\overline{C}_i\).

  • c_scale_df – torch.Tensor of cost matrix scalings to use in the OT component of the data-fitting functional. Defined in the same way as c_scale.

  • m – torch.Tensor of estimates of the total mass \(m_i\) at each timepoint \(t_i\).

  • g – torch.Tensor of growth rates, where \(g_{ij}\) denotes the growth rate at time \(t_i\) and spatial location \(x_j\).

  • kappa – (Only used for soft branching constraint) torch.Tensor of penalty weights \(\kappa_i\) corresponding to \(G_{i}(\overline{\mathbf{R}}_{t_i}, \mathbf{R}_{t_i})\).

  • growth_constraint – “exact” for exact branching constraint, and “KL” for soft branching constraint.

  • u_hat – torch.Tensor of initial values for dual variables \(\{\hat{u}_i\}_{i=1}^T\). If None, we initialise with zeros.

  • v_hat – torch.Tensor of initial values for \(\{\hat{v}_i\}_{i=1}^T\). If None, we initialise with zeros.

  • pi_0 – torch.Tensor of initial distribution \(\pi_0\) to use, or else a choice of “uniform” (uniform on the space \(\overline{\mathcal{X}}\)) or “stationary” (stationary distribution of heat kernel on \(\overline{\mathcal{X}}\))

  • device – Device to use with PyTorch (e.g. torch.device("cuda:0") in the case of GPU). If None, we use torch.device("cpu").

  • use_keops – True to use KeOps for on-the-fly kernel reductions. Otherwise all kernels are precomputed and stored in memory.
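The documented defaults for w and eps are simple enough to reproduce by hand. The sketch below uses NumPy rather than torch, with hypothetical particle counts, timepoints, and diffusivity, and assumes \(\Delta t_i = t_{i+1} - t_i\) (this page does not define \(\Delta t_i\) explicitly).

```python
import numpy as np

N = np.array([50, 100, 50])    # hypothetical particle counts N_i
t = np.array([0.0, 0.5, 1.5])  # hypothetical timepoints t_i
D = 0.25                       # hypothetical diffusivity

# Default data-fitting weights: w_i = N_i / sum_j N_j
w = N / N.sum()

# Default entropic regularisation: eps_i = 2 D Delta t_i,
# assuming Delta t_i = t_{i+1} - t_i (one entry per consecutive pair)
dt = np.diff(t)
eps = 2.0 * D * dt
```

Here w has one entry per timepoint, while eps has one entry per pair of consecutive timepoints.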

compute_phi(i=None, out_arr=None)

Compute auxiliary dual variable \(\phi_i\)

Parameters
  • i – index of which \(\phi_i\) variable we want to compute. Set to be None if we want to compute all of them, but then out_arr must be specified.

  • out_arr – preallocated torch.Tensor in which to output phi.

Returns

if out_arr == None, returns the value of \(\phi_i\) as a torch.Tensor. Otherwise returns None and the result is written to out_arr.

crossent_star(u, i)
Legendre transform \(u \mapsto \mathrm{KL}^*(\rho_{t_i} | u)\) of generalised cross-entropy in its second argument \(x \mapsto \mathrm{KL}(\rho_{t_i} | x)\), where the first argument is the observed sample \(\rho_{t_i}\) at time i.

Parameters
  • u – dual variable

  • i – timepoint index

dual_obj()

Evaluate dual objective (see Eq. 3.16 in manuscript)

forward()

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_K(i)

Get Gibbs kernel as a torch.Tensor for regulariser OT term from timepoint i to i+1. N.B. the main reason this exists is because there is not an easy way to convert between LazyTensor and torch.Tensor (standard dense array). Instead, we need to recompute from scratch.

get_K_df(i)

Get Gibbs kernel as a torch.Tensor for data-fitting OT term from timepoint i to i+1. N.B. the main reason this exists is because there is not an easy way to convert between LazyTensor and torch.Tensor (standard dense array). Instead, we need to recompute from scratch.
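For orientation, a dense Gibbs kernel of the kind returned by get_K and get_K_df can be sketched as \(K_{jk} = \exp(-C_{jk}/(\overline{C}\,\varepsilon))\). The sketch assumes a squared Euclidean cost without any extra constant factor, which this page does not state; gibbs_kernel is a hypothetical helper written in NumPy, not the library's implementation.

```python
import numpy as np

def gibbs_kernel(x, y, eps, c_scale=1.0):
    """Dense Gibbs kernel exp(-C / eps) between point clouds x and y."""
    # Pairwise squared Euclidean cost C_jk = |x_j - y_k|^2 (assumed cost)
    diff = x[:, None, :] - y[None, :, :]
    C = np.sum(diff ** 2, axis=-1) / c_scale
    return np.exp(-C / eps)

x = np.array([[0.0, 0.0], [1.0, 0.0]])
y = np.array([[0.0, 0.0], [0.0, 2.0], [1.0, 0.0]])
K = gibbs_kernel(x, y, eps=0.5)  # shape (2, 3)
```

Storing K densely costs memory proportional to the product of the two support sizes, which is the trade-off the use_keops option addresses.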

get_R(i=None)

Get reconstructed marginal \(\mathbf{R}_{t_i}\) at timepoint i.

get_R_bar(i=None, phi_all=None)

Get intermediate growth marginal \(\mathbf{\overline{R}}_{t_i}\) at timepoint i.

get_R_hat(i=None)

Get intermediate marginal \(\mathbf{\hat{R}}_{t_i}\) at timepoint i.

get_coupling_df(i)
Get the OT coupling \(\gamma\) for the `i`th data-fitting OT term \(\mathrm{OT}_{\varepsilon_i}(\mathbf{R}_{t_i}, \mathbf{\hat{R}}_{t_i})\).

Returns

\(\gamma\) as a torch.Tensor.

get_coupling_reg(i, K=None)
Get the OT coupling \(\gamma\) for the `i`th regularisation OT term \(\mathrm{OT}_{\sigma^2 \Delta t_i}(\mathbf{R}_{t_i}, \mathbf{\overline{R}}_{t_{i+1}})\).

Parameters

K – kernel matrix to use. If K == None, then use the kernels computed by kernel_init.

Returns

\(\gamma\) as either a LazyTensor or torch.Tensor, depending on the type of K.

interp(i, coord_orig=None, R=None, R_bar=None, N=100, interp_frac=0.5, method='geo')
Compute displacement interpolation at time (1-interp_frac)*t[i] + interp_frac*t[i+1], with growth.

Parameters
  • i – index of timepoints to interpolate

  • coord_orig – if not None, a torch.Tensor of alternate coordinates with row-wise correspondence to x in which to compute the interpolation.

  • R – torch.Tensor of precomputed marginals \(\mathbf{R}_{t_i}\).

  • R_bar – torch.Tensor of precomputed intermediate growth marginals \(\overline{\mathbf{R}}_{t_i}\).

kernel_init()

Initialise kernel matrices (as either torch.Tensor or LazyTensor) for use later.

Parameters

use_keops – If True, then all kernel matrices are initialised as `LazyTensor`s.

static load(path)

(Experimental) load OTModel using dill.

logKexp(K, x)

Compute kernel reduction of the form \(\log(K\exp(x))\)

logsumexp_weight(w, x, dim=1)

Compute kernel reduction of the form \(\log(\langle w, \exp(x) \rangle)\)
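Both reductions are usually evaluated in the log domain to avoid overflow in \(\exp(x)\). The NumPy sketch below shows the standard max-shift trick for each form. The function names and argument layouts are hypothetical, and the library's torch/KeOps implementation may differ.

```python
import numpy as np

def log_K_exp(K, x):
    """Stable evaluation of log(K @ exp(x)) for a dense kernel K."""
    m = np.max(x)  # shift so the largest exponent is 0
    return m + np.log(K @ np.exp(x - m))

def log_weighted_sum_exp(w, x, axis=1):
    """Stable evaluation of log(<w, exp(x)>) along the given axis."""
    m = np.max(x, axis=axis, keepdims=True)
    s = np.sum(w * np.exp(x - m), axis=axis)
    return np.squeeze(m, axis=axis) + np.log(s)

K = np.array([[1.0, 1.0], [0.5, 0.0]])
x = np.array([1000.0, 1001.0])  # exp(x) would overflow without the shift
out_K = log_K_exp(K, x)

w = np.array([0.5, 0.5])
X = np.array([[0.0, 0.0], [1000.0, 1000.0]])
out_w = log_weighted_sum_exp(w, X, axis=1)
```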

primal_obj(terms=False)

Evaluate primal objective (see Eq. C.3 in manuscript)

save(path)

(Experimental) save OTModel using dill.

solve_lbfgs(max_iter=50, steps=10, lr=0.01, max_eval=None, history_size=100, line_search_fn='strong_wolfe', factor=1, retry_max=1, tol=0.005)

Solve using LBFGS (works in the general case)

Parameters
  • max_iter – max LBFGS iterations per step (passed to torch.optim.LBFGS)

  • steps – number of steps

  • lr – learning rate (passed to torch.optim.LBFGS)

  • max_eval – maximum function evals (passed to torch.optim.LBFGS)

  • history_size – history size (passed to torch.optim.LBFGS)

  • line_search_fn – line search function to use (passed to torch.optim.LBFGS)

  • factor – if NaN encountered, decrease lr by factor and retry

  • retry_max – maximum number of restarts

  • tol – primal-dual tolerance for convergence.

solve_sinkhorn(steps=1000, tol=0.005, precompute_K=False, print_interval=25)

Solve using Sinkhorn-like scheme (only for case without branching/growth, i.e. g = 1)

Parameters
  • steps – number of Sinkhorn steps to use

  • tol – primal-dual tolerance for convergence

  • precompute_K – if True, store kernel matrices in dense form (as torch.Tensor)

  • print_interval – iteration interval at which to print iteration info.
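For readers unfamiliar with Sinkhorn-type schemes, the classical two-marginal iteration is sketched below in NumPy. It is not the scheme used by solve_sinkhorn, which operates on the full time-series problem, and its stopping rule is a simple marginal error rather than the primal-dual gap that tol refers to. All names here are hypothetical.

```python
import numpy as np

def sinkhorn(mu, nu, K, n_iter=1000, tol=1e-9):
    """Classical Sinkhorn: find scalings a, b so that
    diag(a) K diag(b) has row sums mu and column sums nu."""
    a = np.ones_like(mu)
    b = np.ones_like(nu)
    for _ in range(n_iter):
        a = mu / (K @ b)    # match row marginals
        b = nu / (K.T @ a)  # match column marginals
        # Row-marginal error after the column update
        err = np.abs(a * (K @ b) - mu).sum()
        if err < tol:
            break
    return a[:, None] * K * b[None, :]

C = np.array([[0.0, 1.0], [1.0, 0.0]])  # hypothetical cost matrix
K = np.exp(-C / 1.0)                    # Gibbs kernel with eps = 1
mu = np.array([0.5, 0.5])
nu = np.array([0.25, 0.75])
gamma = sinkhorn(mu, nu, K)
```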

training: bool
uv_init(u_hat=None, v_hat=None)

Initialise dual variables \(\{\hat{u}_i, \hat{v}_i\}_i\) for model.

Parameters
  • u_hat – initial value of u_hat to use. If None, then initialise with zeros.

  • v_hat – initial value of v_hat to use. If None, then initialise with zeros.

class gwot.models.OTModel_kl(ts, lamda_reg, D=None, w=None, eps=None, c_scale=None, u0=None, pi_0=None, device=None, use_keops=True)

Bases: gwot.models.OTModel

Alternative OTModel for the case where the data-fitting term is pure cross-entropy

Parameters
  • ts – TimeSeries object containing input data.

  • lamda_reg – regularisation strength parameter \(\lambda\).

  • D – diffusivity \(D\).

  • w – torch.Tensor of weights for data-fitting term at each timepoint. If None, then we take \(w_i = N_i/\sum_j N_j\) where \(N_i\) is the number of particles at timepoint i.

  • eps – torch.Tensor of entropic regularisation parameters to use in the regularising functional OT term. If None, then we take the entries to be \(2D\Delta t_i\)

  • c_scale – torch.Tensor of cost matrix scalings \(\overline{C}_i\) to use in the regularising functional. That is, the cost matrix for the pair of timepoints \((t_i, t_{i+1})\) will be \(C^{(i)}_{jk} = C_{jk}/\overline{C}_i\).

  • u0 – torch.Tensor of initial values for dual variables

  • pi_0 – torch.Tensor of initial distribution \(\pi_0\) to use, or else a choice of “uniform” (uniform on the space \(\overline{\mathcal{X}}\)) or “stationary” (stationary distribution of heat kernel on \(\overline{\mathcal{X}}\))

  • device – Device to use with PyTorch (e.g. torch.device("cuda:0") in the case of GPU). If None, we use torch.device("cpu").

  • use_keops – True to use KeOps for on-the-fly kernel reductions. Otherwise all kernels are precomputed and stored in memory.

F_star(u, i)

Legendre dual \(u \mapsto F^*(u)\) of cross-entropy data-fitting term \(x \mapsto -\sum_{k \in \mathrm{supp}(\hat{\rho}_{t_i})} \log(x_k)\)
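Under the standard convex-conjugate definition \(F^*(u) = \sup_{x > 0} \langle u, x \rangle - F(x)\), the dual of \(x \mapsto -\sum_k \log(x_k)\) has a closed form for \(u < 0\): each coordinate is maximised at \(x_k = -1/u_k\), giving \(F^*(u) = \sum_k (-1 - \log(-u_k))\). The sketch below checks this numerically. Whether the library uses exactly this sign convention or includes extra weights is an assumption, and F_star_closed_form is a hypothetical name.

```python
import numpy as np

def F_star_closed_form(u):
    """Convex conjugate of x -> -sum_k log(x_k), valid for u < 0."""
    return np.sum(-1.0 - np.log(-u))

u = np.array([-0.5, -2.0])
closed = F_star_closed_form(u)

# Brute-force check: maximise u_k * x + log(x) over a fine grid, per coordinate
grid = np.linspace(1e-3, 10.0, 200001)
brute = sum(np.max(uk * grid + np.log(grid)) for uk in u)
```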

Z(log=False)

Compute normalising constant Z (TODO)

dual_obj()

Compute dual objective

get_R(i=None)

Get reconstructed marginal \(\mathbf{R}_{t_i}\) at timepoint i.

get_gamma_branch(i)

get_gamma_spine(i, K=None)

primal_obj()

Compute primal objective

solve_lbfgs(max_iter=50, steps=10, lr=0.01, max_eval=None, history_size=100, line_search_fn='strong_wolfe', factor=1, retry_max=1, tol=0.005)

Solve using LBFGS

Parameters
  • max_iter – max LBFGS iterations per step (passed to torch.optim.LBFGS)

  • steps – number of steps

  • lr – learning rate (passed to torch.optim.LBFGS)

  • max_eval – maximum function evals (passed to torch.optim.LBFGS)

  • history_size – history size (passed to torch.optim.LBFGS)

  • line_search_fn – line search function to use (passed to torch.optim.LBFGS)

  • factor – if NaN encountered, decrease lr by factor and retry

  • retry_max – maximum number of restarts

  • tol – primal-dual tolerance for convergence.

training: bool
u_init(u0=None)

Initialise dual variable u

Parameters

u0 – initial value of u. If None, then all ones.

v1(i, log=False)

Compute intermediate variable v_1 (TODO)

v2(i, log=False)

Compute intermediate variable v_2 (TODO)

class gwot.models.OTModel_ot(*args, **kwargs)

Bases: gwot.models.OTModel

crossent_star(u, i)
Legendre transform \(u \mapsto \mathrm{KL}^*(\rho_{t_i} | u)\) of generalised cross-entropy in its second argument \(x \mapsto \mathrm{KL}(\rho_{t_i} | x)\), where the first argument is the observed sample \(\rho_{t_i}\) at time i.

Parameters
  • u – dual variable

  • i – timepoint index

training: bool

Utility functions

gwot.util.dW(dt, sz)

Wiener process increments of size sz

gwot.util.density_to_grid(d, x, n=(100, 100), box=array([[-2, -2], [2, 2]]))

Discretize a 2D distribution with weights d supported on x onto a regular grid with n = (n_x, n_y) grid elements, corresponding to box.

gwot.util.density_to_grid_1d(d, x, n=100, box=array([-2, 2]))

Discretize a 1D distribution with weights d supported on x onto a regular grid with n grid elements, corresponding to box.
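One way to picture the 1D discretisation is as a weighted histogram. The sketch below uses np.histogram with hypothetical weights and support points. It is an illustration of the idea only: the library function may bin or normalise differently.

```python
import numpy as np

def weights_to_grid_1d(d, x, n=100, box=(-2.0, 2.0)):
    """Accumulate weights d at positions x into n equal bins over box."""
    hist, edges = np.histogram(x, bins=n, range=box, weights=d)
    return hist, edges

d = np.array([0.2, 0.3, 0.5])     # hypothetical weights
x = np.array([-1.9, 0.0, 0.01])   # hypothetical support points
hist, edges = weights_to_grid_1d(d, x, n=4)
```

With four bins over [-2, 2], the first point falls in the first bin and the other two share the third bin, so the total mass is preserved.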

gwot.util.empirical_dist(mu_spt, nu_spt, max_iter=1000000)

Compute \(W_2\) distance between two empirical distributions \(\mu, \nu\).

Parameters
  • mu_spt – support of measure \(\mu\)

  • nu_spt – support of measure \(\nu\)

  • max_iter – passed to ot.emd2

gwot.util.empirical_entropic_coupling(mu_spt, nu_spt, eps, max_iter=5000, method='sinkhorn')

Compute entropy-regularised OT coupling between two empirical distributions

Parameters
  • mu_spt – support of measure \(\mu\)

  • nu_spt – support of measure \(\nu\)

  • eps – regularisation parameter to use

  • max_iter – passed to ot.sinkhorn

  • method – passed to ot.sinkhorn

gwot.util.ker_smooth(m, h)

Kernel smoothing in time domain

Parameters
  • m – OTModel object

  • h – bandwidth of kernel-in-time

gwot.util.pi0_kde(ts, bw_method=None, num_times=1)

Compute initial distribution \(\pi_0\) as KDE estimate

Parameters
  • ts – TimeSeries object

  • bw_method – KDE bandwidth estimation method (passed to scipy.stats.gaussian_kde)

  • num_times – compute KDE of the first num_times timepoints.

gwot.util.prod_to_grid(gamma, mu_spt, nu_spt, n=(20, 20), box=array([[-2, -2], [2, 2]]))

Discretize a joint distribution gamma on the product space (i.e. supported on mu_spt x nu_spt) onto a grid of n = (n_x, n_y) elements, corresponding to box.

gwot.util.sde_integrate(dV, nu, x0, t, steps, birth_death=False, b=None, d=None, g_max=250, snaps=None)

Integrate SDE using Euler-Maruyama method (with birth-death)

Parameters
  • dV – function dV(x, t) specifying the drift field

  • nu – diffusivity

  • x0 – initial particle positions at time t = 0

  • steps – time steps to use in Euler-Maruyama method

  • birth_death – True if simulation needs birth-death

  • b – if birth_death == True, birth rate b(x, t)

  • d – if birth_death == True, death rate d(x, t)

  • g_max – if birth_death == True, we store g_max*x0.shape[0] particles and error if exceeded.

  • snaps – np.array of step indices at which to record particle snapshot.

gwot.util.to_grid_coord_1d(x, n=100, box=array([-2, 2]))

Convert 1D coordinates x to grid indices on a 1D grid of size n, corresponding to box.

gwot.util.velocity_from_coupling(gamma, mu_spt, nu_spt, dt)

Estimate velocity field from coupling

Parameters
  • gamma – coupling

  • mu_spt – support of source measure

  • nu_spt – support of target measure

  • dt – time interval
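A common way to turn a coupling into a velocity estimate is the barycentric projection: each source point is sent to the \(\gamma\)-weighted mean of its targets, and the displacement is divided by dt. The sketch below implements that idea in NumPy. It is an assumption that the library function follows the same formula, and barycentric_velocity is a hypothetical name.

```python
import numpy as np

def barycentric_velocity(gamma, mu_spt, nu_spt, dt):
    """Velocity at each source point from the barycentric projection of gamma."""
    row_mass = gamma.sum(axis=1, keepdims=True)
    # Conditional mean of the target position given each source point
    target_mean = (gamma @ nu_spt) / row_mass
    return (target_mean - mu_spt) / dt

gamma = np.array([[0.5, 0.0], [0.25, 0.25]])  # hypothetical coupling
mu_spt = np.array([[0.0, 0.0], [1.0, 0.0]])
nu_spt = np.array([[1.0, 0.0], [1.0, 2.0]])
v = barycentric_velocity(gamma, mu_spt, nu_spt, dt=0.5)
```

Source points with zero row mass would need separate handling, since the conditional mean is undefined there.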

Bridge sampling

gwot.bridgesampling.sample_brownian_bridge(t0, x0, t1, x1, sigma, N)

Sample Brownian bridge between (t0, x0) and (t1, x1) with diffusivity \(\sigma^2\). Uses a recursive scheme, partitioning the interval (t0, t1) into \(2^N\) steps.
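The recursive scheme can be sketched as midpoint refinement: conditional on both endpoints, the bridge midpoint is Gaussian with mean \((x_0 + x_1)/2\) and variance \(\sigma^2 (t_1 - t_0)/4\) per coordinate, and the two halves are then refined independently. The NumPy sketch below follows that idea. The return format and the explicit rng argument are assumptions, not the library's interface.

```python
import numpy as np

def brownian_bridge(t0, x0, t1, x1, sigma, N, rng):
    """Sample a Brownian bridge on 2**N intervals by recursive midpoint refinement."""
    x0 = np.asarray(x0, dtype=float)
    x1 = np.asarray(x1, dtype=float)
    if N == 0:
        return np.array([t0, t1]), np.stack([x0, x1])
    tm = 0.5 * (t0 + t1)
    # Midpoint law of a bridge with diffusivity sigma^2
    std = sigma * np.sqrt((t1 - t0) / 4.0)
    xm = 0.5 * (x0 + x1) + std * rng.normal(size=x0.shape)
    t_left, x_left = brownian_bridge(t0, x0, tm, xm, sigma, N - 1, rng)
    t_right, x_right = brownian_bridge(tm, xm, t1, x1, sigma, N - 1, rng)
    # Drop the duplicated midpoint when joining the two halves
    return (np.concatenate([t_left, t_right[1:]]),
            np.concatenate([x_left, x_right[1:]]))

rng = np.random.default_rng(0)
times, path = brownian_bridge(0.0, [0.0, 0.0], 1.0, [1.0, -1.0],
                              sigma=0.3, N=3, rng=rng)
```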

gwot.bridgesampling.sample_coupling(gamma, N=1, norm=True)

Sample from coupling

Parameters
  • gamma – coupling to sample from

  • N – number of pairs to sample

  • norm – if True, re-normalise gamma.
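Sampling index pairs from a coupling amounts to drawing from the categorical distribution given by its flattened entries. The following NumPy sketch illustrates this. The function name, the rng argument, and the returned pair of index arrays are assumptions made for the example.

```python
import numpy as np

def sample_pairs(gamma, N, rng, norm=True):
    """Draw N index pairs (i, j) with probability proportional to gamma[i, j]."""
    p = np.asarray(gamma, dtype=float).ravel()
    if norm:
        p = p / p.sum()  # re-normalise so the entries sum to 1
    flat = rng.choice(p.size, size=N, p=p)
    i, j = np.unravel_index(flat, np.shape(gamma))
    return i, j

rng = np.random.default_rng(0)
gamma = np.array([[2.0, 0.0], [0.0, 2.0]])  # unnormalised, diagonal coupling
i, j = sample_pairs(gamma, N=100, rng=rng)
```

Because this example coupling is diagonal, every sampled pair has matching source and target indices.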

gwot.bridgesampling.sample_paths(gamma_all, N=1, coord=False, x_all=None, get_gamma_fn=None, num_couplings=None)

Sample paths from sequence of couplings

Parameters
  • gamma_all – sequence of T-1 couplings to sample from. If None, then get_gamma_fn is used.

  • N – number of paths to sample

  • coord – if True, return samples in coordinates (rather than indices)

  • x_all – supports of T marginals corresponding to the T-1 provided couplings

  • get_gamma_fn – get_gamma_fn(i) should return the `i`th coupling (of T-1 couplings)

  • num_couplings – if gamma_all == None, need to specify number of total couplings (T-1)
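One plausible reading of path sampling is a Markov chain through the couplings: draw the first pair from the first coupling, then at each later step draw the next index from the row of the next coupling conditioned on the current index. The NumPy sketch below implements that reading for index-valued paths. It is an assumption that sample_paths works this way, and it requires every visited row to carry positive mass.

```python
import numpy as np

def sample_index_paths(gamma_all, N, rng):
    """Sample N index paths through a list of T-1 couplings."""
    T = len(gamma_all) + 1
    paths = np.zeros((N, T), dtype=int)
    # First pair: joint draw from the first coupling
    g0 = np.asarray(gamma_all[0], dtype=float)
    flat = rng.choice(g0.size, size=N, p=g0.ravel() / g0.sum())
    paths[:, 0], paths[:, 1] = np.unravel_index(flat, g0.shape)
    # Later steps: draw from the conditional row of the next coupling
    for k in range(1, T - 1):
        g = np.asarray(gamma_all[k], dtype=float)
        for n in range(N):
            row = g[paths[n, k]]
            paths[n, k + 1] = rng.choice(row.size, p=row / row.sum())
    return paths

rng = np.random.default_rng(0)
identity = np.array([[0.5, 0.0], [0.0, 0.5]])  # keeps the index
swap = np.array([[0.0, 0.5], [0.5, 0.0]])      # flips the index
paths = sample_index_paths([identity, swap], N=50, rng=rng)
```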

gwot.bridgesampling.sample_schrodinger_bridge(t0, t1, gamma, mu_spt, nu_spt, sigma, N, M)

Sample Schrodinger bridge from t0 to t1.

Parameters
  • t0 – initial time

  • t1 – final time

  • gamma – coupling to sample from

  • mu_spt – support of initial marginal

  • nu_spt – support of final marginal

  • sigma – square root of diffusivity (diffusivity is \(\sigma^2\))

  • N – passed to sample_brownian_bridge

  • M – number of paths to sample
