import numpy as np
import matplotlib.pyplot as plt
from scipy import linalg
np.random.seed(7755)
def SIR(state, *args):
    """
    Parameters
    ----------
    state : array-like, shape (3,)
        Point of interest in three-dimensional space.
    *args : (beta, lambda) : float
        Parameters defining the SIR dynamics.

    Returns
    -------
    state_dot : array, shape (3,)
        Values of the derivatives at *state*.
    """
    beta = args[0]
    lambd = args[1]
    S, I, R = state     # Unpack the state vector
    f = np.zeros(3)     # Derivatives
    f[0] = -beta*S*I
    f[1] = beta*S*I - lambd*I
    f[2] = lambd*I
    return f
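As a quick sanity check (an illustrative addition, not in the original notebook), the derivative components returned by SIR should sum to zero, since the total population \(S+I+R\) is conserved:

print(SIR(np.array([0.99, 0.01, 0.0]), 4.0, 1.0))   # beta=4, lambda=1, as used below
# [-0.0396  0.0296  0.01  ]  -> components sum to zero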
def RK4(rhs, state, dt, *args):
    # One step of size dt of the classical fourth-order Runge-Kutta scheme
    k1 = rhs(state, *args)
    k2 = rhs(state + k1*dt/2, *args)
    k3 = rhs(state + k2*dt/2, *args)
    k4 = rhs(state + k3*dt, *args)
    new_state = state + (dt/6)*(k1 + 2*k2 + 2*k3 + k4)
    return new_state
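As an optional consistency check (a sketch, not part of the original), a few RK4 steps can be compared against scipy.integrate.solve_ivp with tight tolerances; for this smooth system the two should agree to many digits:

from scipy.integrate import solve_ivp

beta, lambd = 4.0, 1.0
u = np.array([0.99, 0.01, 0.0])
for _ in range(10):                        # ten RK4 steps of size 0.01
    u = RK4(SIR, u, 0.01, beta, lambd)
ref = solve_ivp(lambda t, s: SIR(s, beta, lambd), (0.0, 0.1),
                [0.99, 0.01, 0.0], rtol=1e-10, atol=1e-12)
print(np.abs(u - ref.y[:, -1]).max())      # tiny difference: integrators agree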
14 Example 3: SIR Model
In this example, we apply DA methods to the SIR system of ordinary differential equations. The SIR system is given by
\[\begin{align} \dfrac{dS}{dt} &= - \beta SI, \quad S(0) = S_0, \\ \dfrac{dI}{dt} &= \beta SI - \lambda I, \quad I(0) = I_0, \\ \dfrac{dR}{dt} &= \lambda I, \quad R(0) = R_0, \end{align}\]
where \(S=S(t)\) denotes susceptibles, \(I=I(t)\) infected and \(R=R(t)\) recovered individuals. The model parameters are \(\beta\) and \(\lambda\), and the basic reproduction number is then
\[ R_0 = \frac{\beta}{\lambda} S_0 .\]
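For the parameter values used in this example (\(\beta = 4\), \(\lambda = 1\), \(S_0 = 0.99\), defined in the next code cell), \(R_0\) is easily computed:

beta, lambd, S0 = 4.0, 1.0, 0.99
print(beta / lambd * S0)   # R0 = 3.96 > 1, so the infection initially spreads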
# parameters SIR
beta = 4.0
lambd = 1.0
dt = 0.1
tm = 5
nt = int(tm/dt)
t = np.linspace(0, tm, nt+1)
# initialize and solve
u0True = np.array([0.99, 0.01, 0])   # True initial conditions
# time integration
uTrue = np.zeros([nt+1, 3])
uTrue[0,:] = u0True
for k in range(nt):
    uTrue[k+1,:] = RK4(SIR, uTrue[k,:], dt, beta, lambd)
# Observational model. Lognormal likelihood.
yobs = np.random.lognormal(mean=np.log(uTrue[1::]), sigma=[0.02, 0.02, 0.])
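A brief aside on this observation model (an illustrative check, not in the original): a lognormal variable with mean parameter \(\log u\) has median \(u\), so the noisy observations scatter multiplicatively around the true trajectory and remain positive:

samples = np.random.lognormal(mean=np.log(0.5), sigma=0.02, size=100000)
print(np.median(samples))   # close to 0.5, the "true" value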
# plot results
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8,5))
ax.plot(t, uTrue[:,0], 'r', label='$S(t)$', linewidth=3)
ax.plot(t, uTrue[:,1], 'g', label='$I(t)$', linewidth=3)
ax.plot(t, uTrue[:,2], 'b', label='$R(t)$', linewidth=3)
ax.plot(t[1::], yobs[:,0:2], marker="o", linestyle="none")
ax.grid()
ax.legend()
ax.set_xlabel('t')
14.1 Ensemble KF for Data Assimilation
Here we will generalize the ensemble Kalman filter to take into account the possibility of sparse observations. This is usually the case in real-life systems, where observations are only available at fixed instants, and hence the filtering correction can only be applied at these times. Between observations, the system evolves freely (without correction) according to its underlying state equation.
Suppose we have \(N_m\) measurements/observations at an interval of \(\delta t_m.\) This gives measurements for times \(t_0 \le t \le t_m,\) where \(t_m = N_m \delta t_m.\) This can be considered as the assimilation window. The system then evolves freely for \(t > t_m\) until some final forecast time \(t_f.\) The state equation itself is simulated with a smaller step \(\delta t\) and for a larger number \(N_t\) of steps, giving \(t_f = N_t \delta t.\) Usually, for real-life systems, we will have
\[ \delta t_m \ge \delta t, \quad N_m \le N_t , \quad t_m \le t_f. \]
For code testing, we make the simplifying (unrealistic) academic assumption that
\[ \delta t_m = \delta t, \quad N_m = N_t, \quad t_f = t_m. \]
This implies the availability of measurements at each and every time step. Note that in many of the previous examples, this was indeed the case.
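To make the sparse-observation indexing concrete, the following small illustration (mirroring the lines inside enKF_SIR_setup below) shows the observation index set for the values used later, dt = 0.1, dt_m = 0.2, T_m = 2; the underscored names are illustrative copies, not from the original:

dt_, dtm_, Tm_ = 0.1, 0.2, 2.0
Nm_ = int(Tm_/dtm_)                      # 10 observations
ind = np.linspace(int(dtm_/dt_), int(Tm_/dt_), Nm_).astype(int)
print(ind)                               # [ 2  4  6  8 10 12 14 16 18 20]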
def enKF_SIR_setup(dt, T, dt_m, T_m, sig_w, sig_v):
    """
    Prepare input (true state and observations) for the stochastic
    ensemble filter of the SIR system.
    Parameters:
        dt: time step for state evolution
        T: time interval for state evolution
        dt_m: time interval between 2 measurements (can equal dt for dense observations)
        T_m: time interval for observations
        sig_w: state noise sd., cov. Q = sig_w**2 * np.eye(3)
        sig_v: measurement noise sd., cov. R = sig_v**2 * np.eye(3)
    """
    # parameters SIR
    beta = 4.0
    lambd = 1.0
    dim_x = 3
    dim_y = 3
    # noise covariances
    Q = sig_w**2 * np.eye(dim_x)
    R = sig_v**2 * np.eye(dim_y)
    # measurement operator (identity here)
    def H(u):
        w = u
        return w
    # Solve system and generate noisy observations
    Nt = int(T/dt)       # number of time steps
    Nm = int(T_m/dt_m)   # number of observations
    t = np.linspace(0, Nt, Nt+1) * dt                                  # time vector
    ind_m = (np.linspace(int(dt_m/dt), int(T_m/dt), Nm)).astype(int)   # obs. indices
    t_m = t[ind_m]                                                     # measurement time vector
    x0True = np.array([0.99, 0.01, 0])   # True initial conditions
    sqrt_Q = np.linalg.cholesky(Q)       # noise std dev.
    sqrt_R = np.linalg.cholesky(R)
    # initialize (correctly!)
    xTrue = np.zeros([Nt+1, dim_x])
    xTrue[0, :] = x0True
    y = np.zeros((Nm, dim_y))
    km = 0   # index for measurement times
    y[0,:] = H(xTrue[0,:]) + sig_v * np.random.randn(dim_y)
    for k in range(Nt):
        w_k = sqrt_Q @ np.random.randn(dim_x)
        xTrue[k+1,:] = RK4(SIR, xTrue[k,:], dt, beta, lambd)   # + w_k
        if (km < Nm) and (k+1 == ind_m[km]):
            v_k = sqrt_R @ np.random.randn(dim_y)
            y[km,:] = H(xTrue[k+1,:]) + v_k
            km = km + 1
    # plot state and measurements
    fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(10,8))
    ax = ax.flat
    for k in range(3):
        ax[k].plot(t, xTrue[:,k], label='True', linewidth=3)
        ax[k].plot(t[ind_m], y[:,k], 'o', fillstyle='none',
                   label='Observation', markersize=8, markeredgewidth=2)
        ax[k].set_xlabel('t')
        ax[k].axvspan(0, T_m, color='lightgray', alpha=0.4, lw=0)
    ax[0].legend(loc="center", bbox_to_anchor=(0.5, 1.25), ncol=4, fontsize=15)
    ax[0].set_ylabel('S(t)')
    ax[1].set_ylabel('I(t)')
    ax[2].set_ylabel('R(t)')
    fig.subplots_adjust(hspace=0.5)
    return Q, R, xTrue, y, ind_m, Nt, Nm

Q, R, xTrue, y, ind_m, Nt, Nm = enKF_SIR_setup(dt=0.1, T=5, dt_m=0.2, T_m=2, sig_w=0.001, sig_v=0.02)
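With these arguments, dt = 0.1 and T = 5 give Nt = 50 integration steps, while dt_m = 0.2 and T_m = 2 give Nm = 10 observations; a quick shape check (an optional addition, not in the original):

print(Nt, Nm)                              # 50 10
print(xTrue.shape, y.shape, ind_m.shape)   # (51, 3) (10, 3) (10,)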
def enKF_SIR_DA(x0, P0, Q, R, y, ind_m, Nt, Nm, Ne=10):
    """
    Run DA of the SIR system using the stochastic
    ensemble filter with sparse observations in the DA
    window, defined by time index set `ind_m`.
    Parameters:
        x0: initial state estimate (mean)
        P0: initial estimate covariance
        Q: state noise covariance
        R: measurement noise covariance
        y: observations, shape (Nm, dim_y)
        ind_m: time indices of the observations
        Nt: number of time steps
        Nm: number of observations
        Ne: ensemble size
    """
    # parameters SIR (dt is taken from the enclosing scope)
    beta = 4.0
    lambd = 1.0
    def Hx(u):
        w = u
        return w
    Nx = x0.shape[-1]
    Ny = y.shape[-1]
    enkf_m = np.empty((Nt+1, Nx))
    enkf_P = np.empty((Nt+1, Nx, Nx))
    X = np.empty((Nx, Ne))
    Xf = np.empty((Nx, Ne))
    HXf = np.empty((Ny, Ne))   # preallocate (overwritten below)

    X[:,:] = np.tile(x0, (Ne,1)).T + np.linalg.cholesky(P0) @ np.random.randn(Nx, Ne)   # initial ensemble state
    P = P0                 # initial state covariance
    enkf_m[0, :] = x0
    enkf_P[0, :, :] = P0

    i_m = 0                # index for measurement times

    for i in range(Nt):
        # ==== predict/forecast ====
        for e in range(Ne):
            w_i = np.linalg.cholesky(Q) @ np.random.randn(Nx)
            Xf[:,e] = RK4(SIR, X[:,e], dt, beta, lambd) + w_i   # predict state ensemble
        mX = np.mean(Xf, axis=1)       # state ensemble mean
        Xfp = Xf - mX[:, None]         # state forecast anomaly
        P = Xfp @ Xfp.T / (Ne - 1)     # predict covariance
        # ==== prepare analysis step ====
        if (i_m < Nm) and (i+1 == ind_m[i_m]):
            HXf = Hx(Xf)                          # nonlinear observation
            mY = np.mean(HXf, axis=1)             # observation ensemble mean
            HXp = HXf - mY[:, None]               # observation anomaly
            S = (HXp @ HXp.T)/(Ne - 1) + R        # observation covariance
            K = linalg.solve(S, HXp @ Xfp.T, assume_a="pos").T / (Ne - 1)   # Kalman gain
            # ==== perturb y and compute innovation ====
            ypert = y[i_m, :] + (np.linalg.cholesky(R) @ np.random.randn(Ny, Ne)).T
            d = ypert.T - HXf
            # ==== correct/analyze ====
            X[:,:] = Xf + K @ d                 # update state ensemble
            mX = np.mean(X[:,:], axis=1)        # state analysis ensemble mean
            Xap = X[:,:] - mX[:, None]          # state analysis anomaly
            P = Xap @ Xap.T / (Ne - 1)          # update covariance
            i_m = i_m + 1
        else:
            X[:,:] = Xf    # when there is no obs, then state = forecast
        # ==== save ====
        enkf_m[i+1] = mX   # save KF state estimate (mean)
        enkf_P[i+1] = P    # save KF error estimate (covariance)
    return enkf_m, enkf_P
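A note on the Kalman gain line above: rather than forming the cross covariance \(C_{xy} = X'_f (HX'_f)^T/(N_e-1)\) and inverting \(S\), the code solves a symmetric positive-definite system, which is algebraically equivalent and numerically preferable. A minimal self-contained check (the underscored names are illustrative, not from the original):

rng = np.random.default_rng(0)
Ne_, Nx_ = 10, 3
Xfp_ = rng.standard_normal((Nx_, Ne_))          # forecast anomalies
HXp_ = Xfp_.copy()                              # identity observation operator
S_ = HXp_ @ HXp_.T / (Ne_ - 1) + 0.02**2 * np.eye(Nx_)
K1 = linalg.solve(S_, HXp_ @ Xfp_.T, assume_a="pos").T / (Ne_ - 1)
K2 = (Xfp_ @ HXp_.T / (Ne_ - 1)) @ np.linalg.inv(S_)   # textbook K = C_xy S^{-1}
print(np.allclose(K1, K2))                      # True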
# Initialize and run the analysis
sig_w = 0.0015
sig_v = 0.02
Q = sig_w**2 * np.eye(3)   #* 1.e-6 # for comparison with DT
R = sig_v**2 * np.eye(3)

x0 = np.array([0.95, 0.05, 0])   # a little off [0.99, 0.01, 0]
sig_vv = 0.1
P0 = np.eye(3) * sig_vv**2       # Initial estimate covariance
Ne = 10
Xa, P = enKF_SIR_DA(x0, P0, Q, R, y, ind_m, Nt, Nm, Ne=10)
# Post-process and plot the results
# generate unfiltered state
Xb = np.empty((Nt+1, 3))
Xb[0,:] = x0
for i in range(Nt):
    Xb[i+1,:] = RK4(SIR, Xb[i,:], dt, beta, lambd)
# plot state and measurements
t = np.linspace(0, Nt, Nt+1) * dt   # time vector
T_m = 2.
fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(10,8))
ax = ax.flat

for k in range(3):
    ax[k].plot(t, xTrue[:,k], label='True')
    ax[k].plot(t, Xa[:,k], '--', label='EnKF analysis')
    ax[k].plot(t[ind_m], y[:,k], 'o', fillstyle='none', label='Observation')
    ax[k].plot(t, Xb[:,k], ':', label='Unfiltered')
    ax[k].set_xlabel('t')
    ax[k].axvspan(0, T_m, color='lightgray', alpha=0.4, lw=0)
ax[0].legend(loc="center", bbox_to_anchor=(0.5, 1.25), ncol=4, fontsize=15)
ax[0].set_ylabel('S(t)')
ax[1].set_ylabel('I(t)')
ax[2].set_ylabel('R(t)')
fig.subplots_adjust(hspace=0.5)
14.2 Conclusion
The ensemble Kalman filter, even with sparse observations and a nonlinear system, does an excellent job of
- tracking the true state within the DA window,
- forecasting well beyond the window, whereas the unfiltered (unassimilated), freely evolving system deviates considerably, as is to be expected for the nonlinear SIR system.