Joint modeling of neural and behavioural dynamics during a delayed reach task
In this example, we show how to use a latent SDE model to jointly model neural observations (spiking activity of neurons in the dorsal premotor (PMd) and primary motor (M1) cortices) and behavioural observations (hand velocity) of a monkey performing a delayed reach task. The data is available for download here.
using Pkg, Revise, Lux, LuxCUDA, CUDA, Random, DifferentialEquations, SciMLSensitivity, ComponentArrays, Plots, MLUtils, OptimizationOptimisers, LinearAlgebra, Statistics, Printf, PyCall, Distributions
using IterTools: ncycle
using NeuroDynamics
np = pyimport("numpy")
device = "cpu"
const dev = device == "gpu" ? gpu_device() : cpu_device()
1. Loading the data and creating the dataloaders
You can prepare the data yourself or use our preprocessed data straight away, which is available here.
file_path = "/Users/ahmed.elgazzar/Datasets/NLB/mc_maze.npy" # Replace with your path to the dataset
data = np.load(file_path, allow_pickle=true)
Y_neural = permutedims(get(data[1], "spikes"), [3, 2, 1]) |> Array{Float32}
Y_behaviour = permutedims(get(data[1], "hand_vel"), [3, 2, 1]) |> Array{Float32}
n_neurons, n_timepoints, n_trials = size(Y_neural)
n_behaviour = size(Y_behaviour)[1]
ts = range(0, 4, length=n_timepoints);
ts_input = repeat(ts, 1, n_trials)
U = reshape(ts_input, (1, size(ts_input)...)) |> Array{Float32}   # time stamps are the only exogenous input to the dynamics
n_ctrl = size(U)[1]
(U_train, Yn_train, Yb_train) , (U_test, Yn_test, Yb_test) = splitobs((U, Y_neural, Y_behaviour); at=0.7)
train_loader = DataLoader((U_train, Yn_train, Yb_train), batchsize=28, shuffle=true)
val_loader = DataLoader((U_test, Yn_test, Yb_test), batchsize=10, shuffle=true);
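Throughout this example the arrays are laid out as (features, time, trials), and MLUtils batches along the last dimension. As an optional sanity check (our addition, not part of the original pipeline), you can inspect one training batch:
# Optional: inspect the shape of one training batch.
# Each array is (features × timepoints × batch), since MLUtils treats
# the last dimension as the observation (trial) dimension.
u_b, yn_b, yb_b = first(train_loader)
@show size(u_b) size(yn_b) size(yb_b)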
2. Defining the model
- We will use a "Recurrent_Encoder" to infer the initial hidden state from a portion of the observations.
- We will use a BlackBox (Neural) SDE with multiplicative noise to model the latent dynamics (sketched below).
- We will use a multi-headed decoder: a Poisson head for the neural observations and a Gaussian head that reads out hand velocity linearly from the inferred firing rates.
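Schematically, in our own notation (not a formula from the package), the latent state evolves as

$$dx_t = f_\theta(x_t, u_t, c)\,dt + g_\phi(x_t) \circ dW_t$$

where the drift network acts on the latent state alone, while the augmented drift additionally receives the control input uₜ and the encoder context c; g_φ is the diagonal Scale diffusion, giving state-dependent (multiplicative) noise, and Wₜ is a Wiener process integrated in the Stratonovich sense by the EulerHeun() solver.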
hp = Dict("n_states" => 16, "hidden_dim" => 64, "context_dim" => 32, "t_init" => Int(0.8 * n_timepoints))
rng = Random.MersenneTwister(1234)
obs_encoder = Recurrent_Encoder(n_neurons, hp["n_states"], hp["context_dim"], 32, hp["t_init"])
drift = Chain(Dense(hp["n_states"], hp["hidden_dim"], softplus), Dense(hp["hidden_dim"], hp["n_states"], tanh))
drift_aug = Chain(Dense(hp["n_states"] + hp["context_dim"] + n_ctrl, hp["hidden_dim"], softplus), Dense(hp["hidden_dim"], hp["n_states"], tanh))
diffusion = Scale(hp["n_states"], sigmoid, init_weight=identity_init(gain=0.1))
dynamics = SDE(drift, drift_aug, diffusion, EulerHeun(), saveat=ts, dt=ts[2]-ts[1])
obs_decoder = Chain(MLP_Decoder(hp["n_states"], n_neurons, 64, 1, "Poisson"), Lux.BranchLayer(NoOpLayer(), Linear_Decoder(n_neurons, n_behaviour, "Gaussian")))
ctrl_encoder, ctrl_decoder = NoOpLayer(), NoOpLayer()
model = LatentUDE(obs_encoder, ctrl_encoder, dynamics, obs_decoder, ctrl_decoder, dev)
p, st = Lux.setup(rng, model)
p = p |> ComponentArray{Float32};
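Before training, you can optionally run a single forward pass on one batch to check that the encoder, dynamics, and decoders fit together. This mirrors the call made inside the training loss below, so the return structure is taken from there; the `u_b`/`yn_b` names are just local to this check:
# Optional smoke test: one forward pass through the full model
# (output structure inferred from the loss function used during training).
u_b, yn_b, _ = first(train_loader) .|> dev
(ŷ_n, ŷ_b), _, x̂₀, kl_path = model(yn_b, u_b, ts, p, st)
@show size(ŷ_n) size(ŷ_b[1])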
3. Training the model
We will train the model with the AdamW optimizer (learning rate 5e-4) for 5000 epochs, printing validation metrics every 500 epochs. The weight on the KL term is annealed with a cyclical schedule (frange_cycle_linear), and we keep the parameters that achieve the best validation behaviour log-likelihood.
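Schematically (our own notation, summarising the loss function defined below), each batch minimises an annealed ELBO-style objective

$$\mathcal{L} = 0.1\,\big(\mathrm{NLL}_{\text{Poisson}} + \mathrm{NLL}_{\text{Gaussian}}\big) + 10\,\beta_{\text{epoch}}\,\big(\mathrm{KL}_{\text{path}} + \mathrm{KL}_{\text{init}}\big)$$

where the Poisson term is the spike negative log-likelihood (normalised by batch size), the Gaussian term is the hand-velocity negative log-likelihood, and β_epoch ∈ [0.5, 1] is the cyclical annealing weight from frange_cycle_linear.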
function train(model::LatentUDE, p, st, train_loader, val_loader, epochs, print_every)
    epoch = 0
    L = frange_cycle_linear(epochs + 1, 0.5f0, 1.0f0, 1, 0.3)   # cyclical KL-annealing schedule
    losses = []
    θ_best = nothing
    best_metric = -Inf
    println("Training ...")

    function loss(p, u, y_n, y_b)
        u, y_n, y_b = u |> dev, y_n |> dev, y_b |> dev
        (ŷ_n, ŷ_b), _, x̂₀, kl_path = model(y_n, u, ts, p, st)
        batch_size = size(y_n)[end]
        neural_loss = -poisson_loglikelihood(ŷ_n, y_n) / batch_size
        behavioural_loss = -normal_loglikelihood(ŷ_b..., y_b)
        obs_loss = neural_loss + behavioural_loss
        kl_init = kl_normal(x̂₀[1], x̂₀[2])
        kl_path = mean(kl_path[end, :])
        kl_loss = kl_path + kl_init
        l = 0.1 * obs_loss + 10 * L[epoch + 1] * kl_loss
        return l, obs_loss, kl_loss
    end

    callback = function (opt_state, l, obs_loss, kl_loss)
        θ = opt_state.u
        push!(losses, l)
        if length(losses) % length(train_loader) == 0
            epoch += 1
        end
        if length(losses) % (length(train_loader) * print_every) == 0
            @printf("Current epoch: %d, Loss: %.2f, Observations_loss: %.2f, KL: %.2f\n", epoch, losses[end], obs_loss, kl_loss)
            u, y_n, y_b = first(val_loader)
            (ŷ_n, ŷ_b), _, _ = predict(model, y_n, u, ts, θ, st, 20)
            ŷ_n = dropdims(mean(ŷ_n, dims=4), dims=4)
            ŷ_b_m, ŷ_b_s = dropdims(mean(ŷ_b[1], dims=4), dims=4), dropdims(mean(ŷ_b[2], dims=4), dims=4)
            val_bps = bits_per_spike(ŷ_n, y_n)
            val_ll = normal_loglikelihood(ŷ_b_m, ŷ_b_s, y_b)
            @printf("Validation bits/spike: %.2f\n", val_bps)
            @printf("Validation behaviour log-likelihood: %.2f\n", val_ll)
            if val_ll > best_metric
                best_metric = val_ll
                θ_best = copy(θ)   # keep the parameters with the best validation behaviour log-likelihood
                @printf("**** Saving best model ****\n")
            end
            d = plot_preds(y_b, ŷ_b[1])
            display(d)
        end
        return false
    end

    adtype = Optimization.AutoZygote()
    optf = OptimizationFunction((p, _, u, y_n, y_b) -> loss(p, u, y_n, y_b), adtype)
    optproblem = OptimizationProblem(optf, p)
    result = Optimization.solve(optproblem, ADAMW(5e-4), ncycle(train_loader, epochs); callback)
    return model, θ_best
end
model, θ_best = train(model, p, st, train_loader, val_loader, 5000, 500);
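Before looking at individual trials, it can be useful to score the whole validation set rather than a single batch. Below is a minimal sketch; `validation_bps` is a hypothetical helper (not part of NeuroDynamics), and it assumes `predict` and `bits_per_spike` behave as in the training callback above:
# Hypothetical helper: average bits/spike over every validation batch.
# Uses the global `ts` and the same predict/bits_per_spike calls as the callback.
function validation_bps(model, θ, st, loader; n_samples=20)
    scores = Float64[]
    for (u, y_n, y_b) in loader
        (ŷ_n, _), _, _ = predict(model, y_n, u, ts, θ, st, n_samples)
        ŷ_n = dropdims(mean(ŷ_n, dims=4), dims=4)   # average over sampled trajectories
        push!(scores, bits_per_spike(ŷ_n, y_n))
    end
    return mean(scores)
end
@printf("Mean validation bits/spike: %.2f\n", validation_bps(model, θ_best, st, val_loader))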
u, y_n, y_b = first(train_loader)
(ŷ_n, ŷ_b), _, x = predict(model, y_n, u, ts, θ_best, st, 20)
sample = 8
ch = 9
ŷₘ = dropmean(ŷ_n, dims=4)                  # mean rate across the 20 sampled trajectories
ŷₛ = dropdims(std(ŷ_n, dims=4), dims=4)     # spread across the sampled trajectories (used as the ribbon below)
dist = Poisson.(ŷₘ)
pred_spike = rand.(dist)
xₘ = dropmean(x, dims=4)
val_bps = bits_per_spike(ŷₘ, y_n)
p1 = plot(transpose(y_n[ch:ch, :, sample]), label="True Spike", lw=2)
p2 = plot(transpose(pred_spike[ch:ch, :, sample]), label="Predicted Spike", lw=2, color="red")
p3 = plot(transpose(ŷₘ[ch:ch, :, sample]), ribbon=transpose(ŷₛ[ch:ch, :, sample]), label="Inferred rates", lw=2, color="green")
plot(p1, p2, p3, layout=(3, 1), size=(800, 400), legend=:topleft)
savefig("spike_prediction.png")
s = 13
plot_samples(ŷ_b[1], s)
plot!(transpose(y_b[:,:,s]), label=["True" nothing], lw=2, color="red", legend=:topleft)