Inferring neural dynamics of motor cortex during a delayed reach task using a latent SDE
In this example, we show how to use a latent SDE model to infer the underlying neural dynamics from single-trial spiking recordings of neurons in the dorsal premotor (PMd) and primary motor (M1) cortices. The data is available for download here.
Dynamics in the motor cortex are known to be highly autonomous during simple stereotyped tasks, so they can be predicted from an "informative" initial condition even in the absence of stimulus information.
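Concretely, the model we fit below pairs latent SDE dynamics with a Poisson observation model. Schematically (this is a sketch of the setup, not the package's exact notation):

dxₜ = f(xₜ, c) dt + g(xₜ) ∘ dWₜ,    yₜ ~ Poisson(λ(xₜ))

where xₜ is the latent state, c is a context vector produced by the encoder, f is the drift, g is a state-dependent (multiplicative) diffusion driven by the Wiener process Wₜ, and λ is a decoder mapping latents to firing rates. The ∘ denotes Stratonovich integration, matching the EulerHeun solver used below.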
using Pkg, Revise, Lux, LuxCUDA, Random, DifferentialEquations, SciMLSensitivity, ComponentArrays, Plots, MLUtils, OptimizationOptimisers, LinearAlgebra, Statistics, Printf, PyCall, Distributions, BenchmarkTools, Zygote
using IterTools: ncycle
using NeuroDynamics
np = pyimport("numpy");
device = "cpu"
const dev = device == "gpu" ? gpu_device() : cpu_device()
1. Loading the data and creating the dataloaders
You can prepare the data yourself or use our preprocessed data straight away, which is available here.
file_path = "/Users/ahmed.elgazzar/Datasets/NLB/mc_maze.npy" # change this to the path of the dataset
data = np.load(file_path, allow_pickle=true)
Y = permutedims(get(data[1], "spikes"), [3, 2, 1]) |> Array{Float32} # neurons × time × trials
n_neurons, n_timepoints, n_trials = size(Y)
ts = range(0, 5.0, length=n_timepoints) |> Array{Float32}
Y_trainval, Y_test = splitobs(Y; at=0.8)
Y_train, Y_val = splitobs(Y_trainval; at=0.8);
train_loader = DataLoader((Y_train, Y_train), batchsize=32, shuffle=true)
val_loader = DataLoader((Y_val, Y_val), batchsize=16, shuffle=true)
test_loader = DataLoader((Y_test, Y_test), batchsize=16, shuffle=true);
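As a quick sanity check (the shapes here assume the preprocessed mc_maze array loaded above), you can confirm that batches come out as neurons × time × trials:

yb, _ = first(train_loader)
size(yb) # expected: (n_neurons, n_timepoints, 32)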
2. Defining the model
- We will use a "Recurrent_Encoder" to infer the initial hidden state from a portion of the observations.
- We will use a BlackBox (Neural) SDE with multiplicative noise to model the latent dynamics.
- We will use a decoder with a Poisson likelihood to model the spike counts.
hp = Dict("n_states" => 10, "hidden_dim" => 64, "context_dim" => 32, "t_init" => Int(0.9 * n_timepoints))
rng = Random.MersenneTwister(2)
obs_encoder = Recurrent_Encoder(n_neurons, hp["n_states"], hp["context_dim"], 32, hp["t_init"]) # infers the initial state and context from the first t_init timepoints
drift = ModernWilsonCowan(hp["n_states"], 0) # structured drift term
drift_aug = Chain(Dense(hp["n_states"] + hp["context_dim"], hp["hidden_dim"], softplus), Dense(hp["hidden_dim"], hp["n_states"], tanh)) # neural augmentation of the drift, conditioned on the encoder context
diffusion = Scale(hp["n_states"], sigmoid, init_weight=identity_init(gain=0.1)) # state-dependent (multiplicative) noise
dynamics = SDE(drift, drift_aug, diffusion, EulerHeun(), saveat=ts, dt=ts[2]-ts[1]) # alternatively: ODE(drift, Tsit5())
obs_decoder = MLP_Decoder(hp["n_states"], n_neurons, 64, 1, "Poisson")
ctrl_encoder, ctrl_decoder = NoOpLayer(), NoOpLayer()
model = LatentUDE(obs_encoder, ctrl_encoder, dynamics, obs_decoder, ctrl_decoder, dev)
p, st = Lux.setup(rng, model) .|> dev
p = p |> ComponentArray{Float32}
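Before training, it can help to run a single forward pass on one batch to catch shape or device issues early; a minimal smoke test using the same call signature as the training loop below:

yb, _ = first(train_loader)
ŷ, _, x̂₀, kl_path = model(yb |> dev, nothing, ts |> dev, p, st)
size(ŷ) # expected: (n_neurons, n_timepoints, batchsize)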
3. Training the model
We will train the model using the AdamW optimizer with a learning rate of 1e-3 for 100 epochs, printing validation metrics every 10 epochs.
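The loss below is an annealed negative ELBO; schematically, with β the annealing weight produced by frange_cycle_linear:

loss = −log p(y | x̂) / batch_size + β · ( KL(q(x₀) ‖ p(x₀)) + KL_path )

where the first term is the Poisson reconstruction loss and the KL terms regularize the inferred initial condition and the latent path.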
function train(model::LatentUDE, p, st, train_loader, val_loader, epochs, print_every)
    epoch = 0
    # Cyclical KL-annealing schedule for the ELBO
    L = frange_cycle_linear(epochs + 1, 0.5f0, 1.0f0, 1, 0.5)
    losses = []
    θ_best = nothing
    best_metric = -Inf
    @info "Training ..."

    function loss(p, y, _)
        y, ts_ = y |> dev, ts |> dev
        ŷ, _, x̂₀, kl_path = model(y, nothing, ts_, p, st)
        batch_size = size(y)[end]
        recon_loss = -poisson_loglikelihood(ŷ, y) / batch_size
        kl_init = kl_normal(x̂₀[1], x̂₀[2])
        kl_path = mean(kl_path[end, :])
        kl_loss = kl_path + kl_init
        l = recon_loss + L[epoch+1] * kl_loss
        return l, recon_loss, kl_loss
    end

    callback = function (opt_state, l, recon_loss, kl_loss)
        θ = opt_state.u
        push!(losses, l)
        if length(losses) % length(train_loader) == 0
            epoch += 1
        end
        if length(losses) % (length(train_loader) * print_every) == 0
            @printf("Current epoch: %d, Loss: %.2f, Recon: %.2f, KL: %.2f\n", epoch, losses[end], recon_loss, kl_loss)
            y, _ = first(val_loader)
            ŷ, _, _ = predict(model, y, nothing, ts, θ, st, 20)
            ŷₘ = dropdims(mean(ŷ, dims=4), dims=4)
            val_bps = bits_per_spike(ŷₘ, y)
            @printf("Validation bits/spike: %.2f\n", val_bps)
            if val_bps > best_metric
                best_metric = val_bps
                θ_best = copy(θ)
                @printf("Saving best model\n")
            end
        end
        return false
    end

    adtype = Optimization.AutoZygote()
    optf = OptimizationFunction((p, _, y, y_) -> loss(p, y, y_), adtype)
    optproblem = OptimizationProblem(optf, p)
    result = Optimization.solve(optproblem, ADAMW(1e-3), ncycle(train_loader, epochs); callback)
    return model, θ_best
end
model, θ_best = train(model, p, st, train_loader, val_loader, 100, 10);
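4. Evaluating the model
We evaluate the trained model on held-out test trials by sampling multiple latent paths, averaging the decoded rates, and comparing them to the recorded spikes.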
y, _ = first(test_loader)
sample = 24 # trial to visualize
ch = 4      # neuron (channel) to visualize
ŷ, _, x = predict(model, y, nothing, ts, θ_best, st, 20) # 20 posterior samples
ŷₘ = dropdims(mean(ŷ, dims=4), dims=4) # mean inferred rates across samples
ŷₛ = dropdims(std(ŷ, dims=4), dims=4)  # rate standard deviation across samples
dist = Poisson.(ŷₘ)
pred_spike = rand.(dist) # sample spikes from the inferred rates
xₘ = dropdims(mean(x, dims=4), dims=4) # mean latent trajectories
test_bps = bits_per_spike(ŷₘ, y)
p1 = plot(transpose(y[ch:ch, :, sample]), label="True spikes", lw=2)
p2 = plot(transpose(pred_spike[ch:ch, :, sample]), label="Predicted spikes", lw=2, color="red")
p3 = plot(transpose(ŷₘ[ch:ch, :, sample]), ribbon=transpose(ŷₛ[ch:ch, :, sample]), label="Inferred rates", lw=2, color="green", yticks=false)
plot(p1, p2, p3, layout=(3, 1), size=(800, 400), legend=:topright)
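The three panels show, for the selected neuron and trial, the recorded spikes, spikes sampled from the inferred rates, and the mean inferred rates with a ribbon of one standard deviation across the 20 posterior samples.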