gwot package
Submodules
gwot.altsolver module
- gwot.altsolver.solve_adam(self, steps=250, print_steps=10, lr=0.001, factor=1, retry_max=1, tol=0.005)
gwot.anndata_utils module
gwot.bridgesampling module
- gwot.bridgesampling.sample_brownian_bridge(t0, x0, t1, x1, sigma, N)
Sample Brownian bridge between (t0, x0) and (t1, x1) with diffusivity \(\sigma^2\). Uses a recursive scheme, partitioning the interval (t0, t1) into \(2^N\) steps.
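The recursive scheme described above can be sketched as midpoint refinement. This is a NumPy illustration under stated assumptions, not the package's implementation: the name `bridge_midpoint_sketch` and the `rng` argument are hypothetical. It uses the standard fact that, for a process with diffusivity \(\sigma^2\), the midpoint of a bridge over an interval of length \(\Delta\) is Gaussian with mean equal to the average of the endpoints and variance \(\sigma^2 \Delta / 4\).

```python
import numpy as np

def bridge_midpoint_sketch(t0, x0, t1, x1, sigma, N, rng=None):
    """Hypothetical helper: sample a bridge on 2**N equal steps by midpoint refinement."""
    rng = np.random.default_rng() if rng is None else rng
    x0 = np.atleast_1d(np.asarray(x0, dtype=float))
    x1 = np.atleast_1d(np.asarray(x1, dtype=float))
    n = 2 ** N
    t = np.linspace(t0, t1, n + 1)
    x = np.empty((n + 1, x0.shape[0]))
    x[0], x[-1] = x0, x1
    step = n
    while step > 1:
        half = step // 2
        for left in range(0, n, step):
            right = left + step
            # Conditional law of the midpoint given both interval endpoints.
            mean = 0.5 * (x[left] + x[right])
            var = sigma ** 2 * (t[right] - t[left]) / 4.0
            x[left + half] = mean + np.sqrt(var) * rng.standard_normal(x0.shape[0])
        step = half
    return t, x
```

With `N = 0` the sketch returns only the two endpoints, and with `sigma = 0` it reduces to linear interpolation between them.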
- gwot.bridgesampling.sample_coupling(gamma, N=1, norm=True)
Sample from coupling
- Parameters
gamma – coupling to sample from
N – number of pairs to sample
norm – if True, re-normalise gamma.
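As an illustration of what sampling from a coupling involves, the following NumPy sketch treats gamma as a joint probability table over index pairs and draws N pairs from it. The function name `sample_pairs_sketch` and the `rng` argument are assumptions made for this example; the package's own return format may differ.

```python
import numpy as np

def sample_pairs_sketch(gamma, N=1, norm=True, rng=None):
    """Hypothetical helper: draw N index pairs (i, j) with probability proportional to gamma[i, j]."""
    rng = np.random.default_rng() if rng is None else rng
    gamma = np.asarray(gamma, dtype=float)
    p = gamma.ravel()
    if norm:
        # Re-normalise so that the entries sum to one.
        p = p / p.sum()
    flat = rng.choice(p.size, size=N, p=p)
    i, j = np.unravel_index(flat, gamma.shape)
    return np.stack([i, j], axis=1)
```

If `norm` is False, the entries of gamma must already sum to one, since `Generator.choice` rejects probability vectors that do not.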
- gwot.bridgesampling.sample_paths(gamma_all, N=1, coord=False, x_all=None, get_gamma_fn=None, num_couplings=None)
Sample paths from sequence of couplings
- Parameters
gamma_all – sequence of T-1 couplings to sample from. If None, then get_gamma_fn is used.
N – number of paths to sample
coord – if True, return samples in coordinates (rather than indices)
x_all – supports of T marginals corresponding to the T-1 provided couplings
get_gamma_fn – get_gamma_fn(i) should return the i-th coupling (of the T-1 couplings)
num_couplings – if gamma_all == None, need to specify number of total couplings (T-1)
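One way to read "sampling paths from a sequence of couplings" is as sampling a Markov chain over indices: the first two states come from the first coupling, and each later state is drawn from the row of the next coupling selected by the current state. The sketch below follows that reading. The name `sample_paths_sketch` is hypothetical, only index output is shown, and it assumes consecutive couplings have compatible marginals so that every visited row has positive mass.

```python
import numpy as np

def sample_paths_sketch(gamma_all, N=1, rng=None):
    """Hypothetical helper: sample N index paths of length T from T-1 couplings."""
    rng = np.random.default_rng() if rng is None else rng
    T = len(gamma_all) + 1
    paths = np.empty((N, T), dtype=int)
    # First two states: a joint draw from the first coupling.
    g0 = np.asarray(gamma_all[0], dtype=float)
    flat = rng.choice(g0.size, size=N, p=g0.ravel() / g0.sum())
    paths[:, 0], paths[:, 1] = np.unravel_index(flat, g0.shape)
    # Later states: conditional draw from the row indexed by the current state.
    for k in range(1, T - 1):
        g = np.asarray(gamma_all[k], dtype=float)
        for n in range(N):
            row = g[paths[n, k]]
            paths[n, k + 1] = rng.choice(row.size, p=row / row.sum())
    return paths
```

To obtain coordinates instead of indices, one could index the support of each marginal with the corresponding column of the returned array.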
- gwot.bridgesampling.sample_schrodinger_bridge(t0, t1, gamma, mu_spt, nu_spt, sigma, N, M)
Sample Schrodinger bridge from t0 to t1.
- Parameters
t0 – initial time
t1 – final time
gamma – coupling to sample from
mu_spt – support of initial marginal
nu_spt – support of final marginal
sigma – square root of diffusivity (diffusivity is \(\sigma^2\))
N – passed to sample_brownian_bridge
M – number of paths to sample
gwot.lambertw module
- gwot.lambertw.evalpoly(coeff, degree, z)
- gwot.lambertw.lambertw(z0, tol=1e-05)
gwot.models module
- class gwot.models.OTModel(ts, lamda_reg, D=None, w=None, lamda=None, eps=None, eps_df=None, c_scale=None, c_scale_df=None, m=None, g=None, kappa=None, growth_constraint='exact', u_hat=None, v_hat=None, pi_0='stationary', device=None, use_keops=True)
Bases: torch.nn.modules.module.Module
Core gWOT model class for inference and manipulating outputs. Forms and solves the optimisation problem with general form \(\min_{\mathbf{R}_{t_1}, \ldots, \mathbf{R}_{t_T}} \lambda \mathrm{Reg}(\mathbf{R}_{t_1}, \ldots, \mathbf{R}_{t_T}) + \mathrm{Fit}(\mathbf{R}_{t_1}, \ldots, \mathbf{R}_{t_T})\), where \(\mathbf{R}_{t_1}, \ldots, \mathbf{R}_{t_T}\) are the reconstructed marginals at times \(t_1, \ldots, t_T\).
- Parameters
ts – TimeSeries object containing input data.
lamda_reg – regularisation strength parameter \(\lambda\).
D – diffusivity \(D\).
w – torch.Tensor of weights for data-fitting term at each timepoint. If None, then we take \(w_i = N_i/\sum_j N_j\) where \(N_i\) is the number of particles at timepoint i.
lamda – torch.Tensor of weights \(\lambda_i\) controlling tradeoff of cross-entropy vs OT in data-fitting term at each timepoint.
eps – torch.Tensor of entropic regularisation parameters to use in the regularising functional OT term. If None, then we take the entries to be \(2D\Delta t_i\) (the theoretically correct value).
eps_df – torch.Tensor of entropic regularisation parameters \(\varepsilon_i\) to use in the OT component of the data-fitting functional.
c_scale – torch.Tensor of cost matrix scalings \(\overline{C}_i\) to use in the regularising functional. That is, the cost matrix for the pair of timepoints \((t_i, t_{i+1})\) will be \(C^{(i)}_{jk} = C_{jk}/\overline{C}_i\).
c_scale_df – torch.Tensor of cost matrix scalings to use in the OT component of the data-fitting functional. Defined in the same way as c_scale.
m – torch.Tensor of estimates of the total mass \(m_i\) at each timepoint \(t_i\).
g – torch.Tensor of growth rates, where \(g_{ij}\) denotes the growth rate at time \(t_i\) and spatial location \(x_j\).
kappa – (Only used for soft branching constraint) torch.Tensor of penalty weights \(\kappa_i\) corresponding to \(G_{i}(\overline{\mathbf{R}}_{t_i}, \mathbf{R}_{t_i})\).
growth_constraint – “exact” for exact branching constraint, and “KL” for soft branching constraint.
u_hat – torch.Tensor of initial values for dual variables \(\{\hat{u}_i\}_{i=1}^T\). If None, we initialise with zeros.
v_hat – torch.Tensor of initial values for \(\{\hat{v}_i\}_{i=1}^T\). If None, we initialise with zeros.
pi_0 – torch.Tensor of initial distribution \(\pi_0\) to use, or else a choice of “uniform” (uniform on the space \(\overline{\mathcal{X}}\)) or “stationary” (stationary distribution of heat kernel on \(\overline{\mathcal{X}}\))
device – Device to use with PyTorch (e.g. torch.device(“cuda:0”) in the case of GPU). If None, we use torch.device(“cpu”).
use_keops – True to use KeOps for on-the-fly kernel reductions. Otherwise all kernels are precomputed and stored in memory.
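The defaults stated for w and eps can be written out explicitly. The sketch below computes \(w_i = N_i/\sum_j N_j\) from per-datapoint time indices and \(\varepsilon_i = 2D\Delta t_i\) from the time increments. It uses NumPy arrays for brevity, whereas the class itself works with torch.Tensor objects, and the function name `default_weights_and_eps` is hypothetical.

```python
import numpy as np

def default_weights_and_eps(t_idx, dt, D):
    """Hypothetical helper: default data-fitting weights and entropic parameters."""
    # N_i = number of particles observed at timepoint i.
    counts = np.bincount(np.asarray(t_idx))
    w = counts / counts.sum()
    # eps_i = 2 * D * (t[i+1] - t[i]) for each consecutive pair of timepoints.
    eps = 2.0 * D * np.asarray(dt, dtype=float)
    return w, eps
```

For example, three timepoints with 3, 2 and 5 particles give weights 0.3, 0.2 and 0.5.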
- compute_phi(i=None, out_arr=None)
Compute auxiliary dual variable \(\phi_i\)
- Parameters
i – index of which \(\phi_i\) variable we want to compute. Set to be None if we want to compute all of them, but then out_arr must be specified.
out_arr – preallocated torch.Tensor in which to output phi.
- Returns
if out_arr == None, returns the value of \(\phi_i\) as a torch.Tensor. Otherwise returns None and the result is written to out_arr.
- crossent_star(u, i)
Legendre transform \(u \mapsto \mathrm{KL}^*(\rho_{t_i} | u)\) of generalised cross-entropy in its second argument \(x \mapsto \mathrm{KL}(\rho_{t_i} | x)\), where the first argument is the observed sample \(\rho_{t_i}\) at time i.
- Parameters
u – dual variable
i – timepoint index
- dual_obj()
Evaluate dual objective (see Eq. 3.16 in manuscript)
- forward()
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- get_K(i)
Get Gibbs kernel as a torch.Tensor for regulariser OT term from timepoint i to i+1. N.B. the main reason this exists is because there is not an easy way to convert between LazyTensor and torch.Tensor (standard dense array). Instead, we need to recompute from scratch.
- get_K_df(i)
Get Gibbs kernel as a torch.Tensor for data-fitting OT term from timepoint i to i+1. N.B. the main reason this exists is because there is not an easy way to convert between LazyTensor and torch.Tensor (standard dense array). Instead, we need to recompute from scratch.
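For orientation, a dense Gibbs kernel has the form \(K_{jk} = \exp(-C_{jk}/(\overline{C}\varepsilon))\). The sketch below builds one with NumPy. The cost convention \(C_{jk} = \tfrac{1}{2}\|x_j - y_k\|^2\) is an assumption made for this example (chosen so that \(\varepsilon = 2D\Delta t\) reproduces a Gaussian heat kernel up to normalisation), and `gibbs_kernel_sketch` is a hypothetical name; the package's own cost convention is not stated in this reference.

```python
import numpy as np

def gibbs_kernel_sketch(x, y, eps, c_scale=1.0):
    """Hypothetical helper: dense Gibbs kernel exp(-C / (c_scale * eps))."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Assumed cost: half the squared Euclidean distance between supports.
    diff = x[:, None, :] - y[None, :, :]
    C = 0.5 * np.sum(diff ** 2, axis=2)
    return np.exp(-C / (c_scale * eps))
```

A dense kernel of this kind needs memory proportional to the product of the two support sizes, which is the cost that on-the-fly reductions avoid.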
- get_R(i=None)
Get reconstructed marginal \(\mathbf{R}_{t_i}\) at timepoint i.
- get_R_bar(i=None, phi_all=None)
Get intermediate growth marginal \(\mathbf{\overline{R}}_{t_i}\) at timepoint i.
- get_R_hat(i=None)
Get intermediate marginal \(\mathbf{\hat{R}}_{t_i}\) at timepoint i.
- get_coupling_df(i)
Get the OT coupling \(\gamma\) for the i-th data-fitting OT term \(\mathrm{OT}_{\varepsilon_i}(\mathbf{R}_{t_i}, \mathbf{\hat{R}}_{t_i})\).
- Returns
\(\gamma\) as a torch.Tensor.
- get_coupling_reg(i, K=None)
Get the OT coupling \(\gamma\) for the i-th regularisation OT term \(\mathrm{OT}_{\sigma^2 \Delta t_i}(\mathbf{R}_{t_i}, \mathbf{\overline{R}}_{t_{i+1}})\).
- Parameters
K – kernel matrix to use. If K == None, then use the kernels computed by kernel_init.
- Returns
\(\gamma\) as either a LazyTensor or torch.Tensor, depending on the type of K.
- interp(i, coord_orig=None, R=None, R_bar=None, N=100, interp_frac=0.5, method='geo')
Compute displacement interpolation, with growth, at time (1-interp_frac)*t[i] + interp_frac*t[i+1].
- Parameters
i – index of timepoints to interpolate
coord_orig – if not None, a torch.Tensor of alternate coordinates with row-wise correspondence to x in which to compute the interpolation.
R – torch.Tensor of precomputed marginals \(\mathbf{R}_{t_i}\).
R_bar – torch.Tensor of precomputed intermediate growth marginals \(\overline{\mathbf{R}}_{t_i}\).
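A simplified picture of displacement interpolation, ignoring growth, is to sample matched pairs from a coupling and place each sample a fraction interp_frac of the way along the straight line between its two endpoints. The NumPy sketch below does exactly that; it is not the method implemented by interp, and the name `displacement_interp_sketch` and its arguments `gamma`, `x`, `y` and `rng` are assumptions for illustration.

```python
import numpy as np

def displacement_interp_sketch(gamma, x, y, N=100, interp_frac=0.5, rng=None):
    """Hypothetical helper: N interpolated points between supports x and y (no growth)."""
    rng = np.random.default_rng() if rng is None else rng
    gamma = np.asarray(gamma, dtype=float)
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Draw matched index pairs (i, j) with probability proportional to gamma[i, j].
    flat = rng.choice(gamma.size, size=N, p=gamma.ravel() / gamma.sum())
    i, j = np.unravel_index(flat, gamma.shape)
    # Move each sampled source point part of the way towards its matched target.
    return (1.0 - interp_frac) * x[i] + interp_frac * y[j]
```

Setting `interp_frac` to 0 or 1 returns samples from the source or target support respectively.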
- kernel_init()
Initialise kernel matrices (as either torch.Tensor or LazyTensor) for use later.
- Parameters
use_keops – If True, then all kernel matrices are initialised as LazyTensor objects.
- static load(path)
(Experimental) load OTModel using dill.
- logKexp(K, x)
Compute kernel reduction of the form \(\log(K\exp(x))\)
- logsumexp_weight(w, x, dim=1)
Compute kernel reduction of the form \(\log(\langle w, \exp(x) \rangle)\)
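Both reductions are usually evaluated after subtracting the maximum of x, so that the exponentials cannot overflow. The following NumPy sketch shows that standard trick on dense arrays. It is not the package's implementation (which may operate on LazyTensor kernels), the function names are hypothetical, and it assumes K and w have non-negative entries.

```python
import numpy as np

def log_K_exp_sketch(K, x):
    """Hypothetical helper: log(K @ exp(x)), computed with a max shift for stability."""
    m = np.max(x)
    # log(K exp(x)) = m + log(K exp(x - m)), and exp(x - m) <= 1.
    return m + np.log(K @ np.exp(x - m))

def log_weighted_sum_exp_sketch(w, x, axis=1):
    """Hypothetical helper: log(sum(w * exp(x))) along one axis, with a max shift."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(w * np.exp(x - m), axis=axis))
```

The shifted form agrees with the direct formula whenever the latter is finite, and stays finite for inputs such as `x = [1000, 999]` where `exp(x)` overflows in double precision.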
- primal_obj(terms=False)
Evaluate primal objective (see Eq. C.3 in manuscript)
- save(path)
(Experimental) save OTModel using dill.
- solve_lbfgs(max_iter=50, steps=10, lr=0.01, max_eval=None, history_size=100, line_search_fn='strong_wolfe', factor=1, retry_max=1, tol=0.005)
Solve using LBFGS (works in the general case)
- Parameters
max_iter – max LBFGS iterations per step (passed to torch.optim.LBFGS)
steps – number of steps
lr – learning rate (passed to torch.optim.LBFGS)
max_eval – maximum function evals (passed to torch.optim.LBFGS)
history_size – history size (passed to torch.optim.LBFGS)
line_search_fn – line search function to use (passed to torch.optim.LBFGS)
factor – if NaN encountered, decrease lr by factor and retry
retry_max – maximum number of restarts
tol – primal-dual tolerance for convergence.
- solve_sinkhorn(steps=1000, tol=0.005, precompute_K=False, print_interval=25)
Solve using Sinkhorn-like scheme (only for case without branching/growth, i.e. g = 1)
- Parameters
steps – number of Sinkhorn steps to use
tol – primal-dual tolerance for convergence
precompute_K – if True, store kernel matrices in dense form (as torch.Tensor)
print_interval – iteration interval at which to print iteration info.
- training: bool
- uv_init(u_hat=None, v_hat=None)
Initialise dual variables \(\{\hat{u}_i, \hat{v}_i\}_i\) for model.
- Parameters
u_hat – initial value of u_hat to use. If None, then initialise with zeros.
v_hat – initial value of v_hat to use. If None, then initialise with zeros.
- class gwot.models.OTModel_kl(ts, lamda_reg, D=None, w=None, eps=None, c_scale=None, u0=None, pi_0=None, device=None, use_keops=True)
Bases: gwot.models.OTModel
Alternative OTModel for the case where the data-fitting term is pure cross-entropy
- Parameters
ts – TimeSeries object containing input data.
lamda_reg – regularisation strength parameter \(\lambda\).
D – diffusivity \(D\).
w – torch.Tensor of weights for data-fitting term at each timepoint. If None, then we take \(w_i = N_i/\sum_j N_j\) where \(N_i\) is the number of particles at timepoint i.
eps – torch.Tensor of entropic regularisation parameters to use in the regularising functional OT term. If None, then we take the entries to be \(2D\Delta t_i\)
c_scale – torch.Tensor of cost matrix scalings \(\overline{C}_i\) to use in the regularising functional. That is, the cost matrix for the pair of timepoints \((t_i, t_{i+1})\) will be \(C^{(i)}_{jk} = C_{jk}/\overline{C}_i\).
u0 – torch.Tensor of initial values for dual variables
pi_0 – torch.Tensor of initial distribution \(\pi_0\) to use, or else a choice of “uniform” (uniform on the space \(\overline{\mathcal{X}}\)) or “stationary” (stationary distribution of heat kernel on \(\overline{\mathcal{X}}\))
device – Device to use with PyTorch (e.g. torch.device(“cuda:0”) in the case of GPU). If None, we use torch.device(“cpu”).
use_keops – True to use KeOps for on-the-fly kernel reductions. Otherwise all kernels are precomputed and stored in memory.
- F_star(u, i)
Legendre dual \(u \mapsto F^*(u)\) of cross-entropy data-fitting term \(x \mapsto -\sum_{k \in \mathrm{supp}(\hat{\rho}_{t_i})} \log(x_k)\)
- Z(log=False)
Compute normalising constant Z (TODO)
- dual_obj()
Compute dual objective
- get_R(i=None)
Get reconstructed marginal \(\mathbf{R}_{t_i}\) at timepoint i.
- get_gamma_branch(i)
- get_gamma_spine(i, K=None)
- primal_obj()
Compute primal objective
- solve_lbfgs(max_iter=50, steps=10, lr=0.01, max_eval=None, history_size=100, line_search_fn='strong_wolfe', factor=1, retry_max=1, tol=0.005)
Solve using LBFGS
- Parameters
max_iter – max LBFGS iterations per step (passed to torch.optim.LBFGS)
steps – number of steps
lr – learning rate (passed to torch.optim.LBFGS)
max_eval – maximum function evals (passed to torch.optim.LBFGS)
history_size – history size (passed to torch.optim.LBFGS)
line_search_fn – line search function to use (passed to torch.optim.LBFGS)
factor – if NaN encountered, decrease lr by factor and retry
retry_max – maximum number of restarts
tol – primal-dual tolerance for convergence.
- training: bool
- u_init(u0=None)
Initialise dual variable u
- Parameters
u0 – initial value of u. If None, then all ones.
- v1(i, log=False)
Compute intermediate variable v_1 (TODO)
- v2(i, log=False)
Compute intermediate variable v_2 (TODO)
- class gwot.models.OTModel_ot(*args, **kwargs)
Bases: gwot.models.OTModel
- crossent_star(u, i)
Legendre transform \(u \mapsto \mathrm{KL}^*(\rho_{t_i} | u)\) of generalised cross-entropy in its second argument \(x \mapsto \mathrm{KL}(\rho_{t_i} | x)\), where the first argument is the observed sample \(\rho_{t_i}\) at time i.
- Parameters
u – dual variable
i – timepoint index
- training: bool
gwot.sim module
- class gwot.sim.Simulation(V, dV, N, T, d, D, t_final, ic_func, pool, birth_death=False, birth=None, death=None)
Bases: gwot.ts.TimeSeries
Diffusion-drift SDE simulations using the Euler-Maruyama method.
- Parameters
V – potential function \((x, t) \mapsto V(x, t)\)
dV – potential gradient \((x, t) \mapsto \nabla V(x, t)\)
N – number of initial particles to use, where \(N_i\) corresponds to timepoint \(t_i\)
T – number of timepoints at which to capture snapshots
d – dimension \(d\) of simulation
D – diffusivity \(D\)
t_final – final time \(t_\mathrm{final}\) (initial time is always 0)
ic_func – function accepting arguments (N, d) and returning an array X of dimensions (N, d) where X[i, :] corresponds to the i-th initial particle position
pool – ProcessingPool to use for parallel computation (or None)
birth_death – whether to incorporate birth-death process
birth – if birth_death == True, a function accepting arguments (X, t) returning a vector of birth rates \(\beta\) for each row in X
death – if birth_death == True, a function accepting arguments (X, t) returning a vector of death rates \(\delta\) for each row in X
- sample(steps_scale=1, trunc=None)
Sample time-series from Simulation. Simulates independently evolving particles using the Euler-Maruyama method.
- Parameters
steps_scale – number of Euler-Maruyama steps to take between timepoints.
trunc – if provided, subsample all snapshots to have trunc particles.
- sample_trajectory(steps_scale=1, N=1)
Sample trajectory from simulation
- Parameters
steps_scale – number of Euler-Maruyama steps to take between timepoints.
N – number of trajectories to sample
- Returns
np.array of dimensions
gwot.ts module
- class gwot.ts.TimeSeries(x, dt, t_idx, D=None)
Bases: object
Base class for time-series dataset.
- Parameters
x – np.array of observed datapoints.
dt – np.array of time increments t[i+1] - t[i].
t_idx – np.array of time indices for each datapoint in x.
D – diffusivity
gwot.util module
- gwot.util.dW(dt, sz)
Wiener process increments of size sz
- gwot.util.density_to_grid(d, x, n=(100, 100), box=array([[-2, -2], [2, 2]]))
Discretize a 2D distribution with weights d supported on x onto a regular grid with n = (n_x, n_y) grid elements, corresponding to box.
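A weighted 2D histogram is one straightforward way to perform this kind of discretisation. The sketch below uses `numpy.histogram2d` and assumes that the rows of box are the lower and upper corners of the grid region; that reading of box and the name `density_to_grid_sketch` are assumptions, not statements about the package's implementation.

```python
import numpy as np

def density_to_grid_sketch(d, x, n=(100, 100), box=((-2, -2), (2, 2))):
    """Hypothetical helper: bin weights d at points x onto an n[0] x n[1] grid."""
    x = np.asarray(x, dtype=float)
    box = np.asarray(box, dtype=float)
    # Assumed layout: box[0] is the lower corner, box[1] the upper corner.
    grid, _, _ = np.histogram2d(
        x[:, 0], x[:, 1],
        bins=n,
        range=[[box[0, 0], box[1, 0]], [box[0, 1], box[1, 1]]],
        weights=d,
    )
    return grid
```

Points outside the box are dropped by `histogram2d`, so the grid total equals the total weight only when every support point lies inside the box.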
- gwot.util.density_to_grid_1d(d, x, n=100, box=array([-2, 2]))
Discretize a 1D distribution with weights d supported on x onto a regular grid with n grid elements, corresponding to box.
- gwot.util.empirical_dist(mu_spt, nu_spt, max_iter=1000000)
Compute \(W_2\) distance between two empirical distributions \(\mu, \nu\).
- Parameters
mu_spt – support of measure \(\mu\)
nu_spt – support of measure \(\nu\)
max_iter – passed to ot.emd2
- gwot.util.empirical_entropic_coupling(mu_spt, nu_spt, eps, max_iter=5000, method='sinkhorn')
Compute entropy-regularised OT coupling between two empirical distributions
- Parameters
mu_spt – support of measure \(\mu\)
nu_spt – support of measure \(\nu\)
eps – regularisation parameter to use
max_iter – passed to ot.sinkhorn
method – passed to ot.sinkhorn
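To show what an entropy-regularised coupling between two empirical distributions looks like, here is a plain NumPy Sinkhorn iteration with uniform weights on each support. It is a stand-in for illustration only and does not call ot.sinkhorn; the squared Euclidean cost, the stopping rule, the `tol` argument and the name `sinkhorn_coupling_sketch` are all assumptions.

```python
import numpy as np

def sinkhorn_coupling_sketch(mu_spt, nu_spt, eps, max_iter=5000, tol=1e-9):
    """Hypothetical helper: entropic OT coupling between uniform empirical measures."""
    mu_spt = np.asarray(mu_spt, dtype=float)
    nu_spt = np.asarray(nu_spt, dtype=float)
    a = np.full(mu_spt.shape[0], 1.0 / mu_spt.shape[0])
    b = np.full(nu_spt.shape[0], 1.0 / nu_spt.shape[0])
    # Assumed cost: squared Euclidean distance between support points.
    C = np.sum((mu_spt[:, None, :] - nu_spt[None, :, :]) ** 2, axis=2)
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(max_iter):
        # Alternately rescale rows and columns to match the two marginals.
        u = a / (K @ v)
        v = b / (K.T @ u)
        # Stop once the row marginal of the current coupling is close to a.
        if np.max(np.abs(u * (K @ v) - a)) < tol:
            break
    return u[:, None] * K * v[None, :]
```

For small eps the entries of K underflow and a log-domain variant is needed; this sketch does not handle that case.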
- gwot.util.ker_smooth(m, h)
Kernel smoothing in time domain
- Parameters
m – OTModel object
h – bandwidth of kernel-in-time
- gwot.util.pi0_kde(ts, bw_method=None, num_times=1)
Compute initial distribution \(\pi_0\) as KDE estimate
- Parameters
ts – TimeSeries object
bw_method – KDE bandwidth estimation method (passed to scipy.stats.gaussian_kde)
num_times – compute KDE of the first num_times timepoints.
- gwot.util.prod_to_grid(gamma, mu_spt, nu_spt, n=(20, 20), box=array([[-2, -2], [2, 2]]))
Discretize a joint distribution gamma on the product space, i.e. supported on mu_spt x nu_spt onto a grid of n = (n_x, n_y), corresponding to box.
- gwot.util.sde_integrate(dV, nu, x0, t, steps, birth_death=False, b=None, d=None, g_max=250, snaps=None)
Integrate SDE using Euler-Maruyama method (with birth-death)
- Parameters
dV – function dV(x, t) specifying the drift field
nu – diffusivity
x0 – initial particle positions at time t = 0
steps – time steps to use in Euler-Maruyama method
birth_death – True if simulation needs birth-death
b – if birth_death == True, birth rate b(x, t)
d – if birth_death == True, death rate d(x, t)
g_max – if birth_death == True, we store g_max*x0.shape[0] particles and error if exceeded.
snaps – np.array of step indices at which to record particle snapshot.
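The core Euler-Maruyama update, without birth-death or snapshots, can be sketched in a few lines. The sketch assumes the convention \(\mathrm{d}X_t = -\nabla V(X_t, t)\,\mathrm{d}t + \sqrt{2D}\,\mathrm{d}W_t\); the sign of the drift and the relation between the noise amplitude and the diffusivity argument are assumptions here and may differ from the package's convention. The name `euler_maruyama_sketch` and the `t_final` and `rng` arguments are hypothetical.

```python
import numpy as np

def euler_maruyama_sketch(dV, D, x0, t_final, steps, rng=None):
    """Hypothetical helper: integrate dX = -dV(X, t) dt + sqrt(2 D) dW from time 0 to t_final."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x0, dtype=float)
    h = t_final / steps
    t = 0.0
    for _ in range(steps):
        # Wiener increments over a step of length h have standard deviation sqrt(h).
        dW = np.sqrt(h) * rng.standard_normal(x.shape)
        x = x - dV(x, t) * h + np.sqrt(2.0 * D) * dW
        t += h
    return x
```

With `D = 0` the scheme reduces to explicit Euler for the gradient flow, which gives a simple deterministic check of the drift term.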
- gwot.util.to_grid_coord_1d(x, n=100, box=array([-2, 2]))
Convert 1D coordinates x to grid indices on a 1D grid of size n, corresponding to box.
- gwot.util.velocity_from_coupling(gamma, mu_spt, nu_spt, dt)
Estimate velocity field from coupling
- Parameters
gamma – coupling
mu_spt – support of source measure
nu_spt – support of target measure
dt – time interval
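One common way to turn a coupling into a velocity estimate is barycentric projection: each source point is assigned the gamma-weighted mean of the target support, and the displacement is divided by dt. Whether velocity_from_coupling uses exactly this estimator is not stated in this reference, so the sketch below should be read as an assumption; the name `velocity_sketch` is hypothetical.

```python
import numpy as np

def velocity_sketch(gamma, mu_spt, nu_spt, dt):
    """Hypothetical helper: barycentric-projection velocity at each source point."""
    gamma = np.asarray(gamma, dtype=float)
    mu_spt = np.asarray(mu_spt, dtype=float)
    nu_spt = np.asarray(nu_spt, dtype=float)
    # Conditional mean of the target position given each source point.
    row_mass = gamma.sum(axis=1, keepdims=True)
    targets = (gamma @ nu_spt) / row_mass
    # Mean displacement per unit time.
    return (targets - mu_spt) / dt
```

Rows of gamma with zero mass give undefined velocities in this sketch, so source points that carry no mass should be masked out beforehand.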