From 1924b31b4bce2194ab0763cd8d96b665c91195fe Mon Sep 17 00:00:00 2001
From: Lloyd Chapman
Date: Fri, 21 May 2021 13:00:39 +0100
Subject: [PATCH 1/3] Add code for setting-specific contact matrices

---
 .gitignore | 4 +
 ...culate_setting_specific_contact_matrices.R | 113 ++++++++++
 r/functions/calc_cm.R | 30 +--
 r/functions/fitting_functions.R | 35 ++-
 r/functions/functions.R | 169 ++++++++++++--
 r/presentation/plot_comparison_cms.R | 8 +-
 .../setting_specific_contact_matrix_plots.R | 206 ++++++++++++++++++
 7 files changed, 524 insertions(+), 41 deletions(-)
 create mode 100644 r/analyses/run_calculate_setting_specific_contact_matrices.R
 create mode 100644 r/presentation/setting_specific_contact_matrix_plots.R

diff --git a/.gitignore b/.gitignore
index 328235b..b438df9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,10 @@
+.Rproj
 .Rproj.user
 .Rhistory
 .RData
 .Ruserdata
 *.DS_Store
 *SS/*
+r/rough/
+outputs/setting_specific/
+all_mats_*
\ No newline at end of file
diff --git a/r/analyses/run_calculate_setting_specific_contact_matrices.R b/r/analyses/run_calculate_setting_specific_contact_matrices.R
new file mode 100644
index 0000000..a4f8675
--- /dev/null
+++ b/r/analyses/run_calculate_setting_specific_contact_matrices.R
@@ -0,0 +1,113 @@
+## Name: run_calculate_setting_specific_contact_matrices
+## Description: Calculates setting-specific (home, school, work, other) contact matrices
+## Input file: clean_participants.rds, clean_contacts.rds
+## Functions:
+## Output file:
+
+# The following code loops through the weeks to date and constructs a contact matrix using get_matrix().
+# There is also a calculation of R assuming a probability of infection on contact of 0.1
+
+
+# Packages ----------------------------------------------------------------
+library(data.table)
+library(ggplot2)
+library(viridis)
+library(doParallel)
+
+# Source user written scripts ---------------------------------------------
+
+source('r/functions/get_minimal_data.R')
+source('r/functions/functions.R')
+# source('r/functions/get_react_data.R')
+source('r/functions/calc_cm.R')
+source('r/functions/compare_Rs.R')
+
+# Set up parallel computing environment -------------------------------------
+ncores = detectCores() - 1
+registerDoParallel(cores = ncores)
+
+# Input data ----------------------------------------------------------------
+
+# extract data with useful columns
+data = get_minimal_data()
+
+# decant data into relevant containers
+contacts = data[[1]]
+parts = data[[2]]
+
+start_date = lubridate::ymd('20200323')
+end_date = lubridate::ymd('20210421')
+
+parts = parts[between(date, start_date, end_date)]
+contacts = contacts[between(date, start_date, end_date)]
+
+
+# set breaks for age categories and get population proportions
+breaks = c(0,5,12,18,30,40,50,60,70,Inf)
+#breaks = c(0,18,65,Inf)
+max_ = 50 # upper limit for censoring/truncation
+popdata_totals = get_popvec(breaks, year_ = 2020)
+weeks_in_parts = sort(unique(parts$survey_round))
+week_range = c(1,11,19,24,34,37,39,42,51) #c(min(weeks_in_parts):max(weeks_in_parts))
+#week_range = 34:51
+nwk = c(10,8,5,10,3,2,3,9,6)
+samples_ = 1000
+fit_with_ = 'bs'
+trunc_flag_ = T # flag for whether or not to use truncation rather than uncorrected censoring
+
+# Filter data -------------------------------------------------------------
+unique_wave_pid <- unique(parts$part_wave_uid)
+contacts <- contacts[part_wave_uid %in% unique_wave_pid]
+
+parts[,part_id := paste(as.character(part_id), survey_round, sep = '_')]
+contacts[,part_id := paste(as.character(part_id), survey_round, sep = '_')]
+
+countries = list(c("uk"))
+country_names = c("uk")
+regions = list(c("North 
East", "Yorkshire and The Humber"), c("North West"), c("East Midlands", "West Midlands"), c("East of England"), c("South West"), c("South East"), c("Greater London")) +nations = list(unlist(regions), c('Scotland'), c('Wales'), c('Northern Ireland')) +nation_names = c("England", "Scotland", "Wales")[1] + +settings = c("home","school","work","other") + +weights = get_contact_age_weights() + +for (i in 1:length(country_names)){ +#for (i in 1:length(regions)){ +# for (i in 1:length(nation_names)){ + for (j in 2:length(settings)){ + # for (k in 1:length(week_range)){ + lcms = foreach(k=1:length(week_range)) %dopar% { + print(k) + + contacts_setting <- contacts[country %in% countries[[i]] & eval(parse(text=paste0("cnt_",settings[j])))] + unique_wave_pid <- unique(contacts_setting$part_wave_uid) + + parts_setting <- parts[part_wave_uid %in% unique_wave_pid] + + # parts_nation <- parts[area_3_name %in% regions[[i]]] + # parts_nation <- parts[area_3_name %in% nations[[i]]] + # parts_nation <- parts[country %in% countries[[i]]] + + # unique_wave_pid <- unique(parts_nation$part_wave_uid) + # contacts_nation <- contacts[part_wave_uid %in% unique_wave_pid] + + # print(nations[[i]]) + print(countries[[i]]) + + + # calculate contact matrices------------------------------------------------------------- + + + outfolder=paste0('outputs/setting_specific/', country_names[i], '/') + if(!dir.exists(outfolder)){ + dir.create(outfolder, recursive = TRUE) + } + + cms_max50 = calc_cm_general(parts_setting, contacts_setting, breaks, max_ = max_, popdata_totals, weeks_range = week_range[k], nwks=nwk[k], outfolder=outfolder, fitwith=fit_with_, samples=samples_, weights=NULL, trunc_flag=trunc_flag_, setting=settings[j]) + + } + + } + +} diff --git a/r/functions/calc_cm.R b/r/functions/calc_cm.R index 8afa9c5..f4e774b 100644 --- a/r/functions/calc_cm.R +++ b/r/functions/calc_cm.R @@ -1,7 +1,7 @@ ## calculate the contact matrices -calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals, weeks_range=23:33, nwks=2, samples=10, fitwith='bs', outfolder='outputs/regular/', model_path='stan/trunc_negbinom_matrix_bunchtrunc.stan', prior_pars_mu=NULL, prior_pars_k=NULL, weights = NULL){ +calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals, weeks_range=23:33, nwks=2, samples=10, fitwith='bs', outfolder='outputs/regular/', model_path='stan/trunc_negbinom_matrix_bunchtrunc.stan', prior_pars_mu=NULL, prior_pars_k=NULL, weights = NULL, trunc_flag = F, setting = ""){ print(nwks) if(!dir.exists(paste0(outfolder, 'contact_matrices/'))){ dir.create(paste0(outfolder, 'contact_matrices/'), recursive = TRUE) @@ -32,22 +32,26 @@ calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals if (nwks == 'ALL'){ weeks_range = list(weeks_range) - print(length(weeks_range))} + print(length(weeks_range)) + } - for(week in weeks_range){ + for(i in 1:length(weeks_range)){ + week <- weeks_range[i] + # for(week in weeks_range){ if (nwks != 'ALL'){ - i = week + # i = week #print(i) - weeks <- week:(week + nwks - 1) - } - else{ + weeks = week:(week + nwks - 1) + } else{ weeks = week - week = weeks[1] - } - filename_primer = paste0(outfolder, 'contact_matrices/', fitwith, samples, '_ngrps', length(breaks) - 1, '_cap', max_, '_nwks', length(weeks),'_sr', week, '_') + } + filename_primer = paste0(outfolder, 'contact_matrices/', fitwith, samples, '_ngrps', length(breaks) - 1, '_cap', max_, '_nwks', length(weeks),'_sr', week, '_', setting, '_') + if (trunc_flag){ + filename_primer = 
paste0(filename_primer,"trunc_") + } - if(i %in% c(1:6, 17,18)) weeks <- c(weeks, 700) + if(week %in% c(1:6, 17,18)) weeks <- c(weeks, 700) if(length(conts_weekend[survey_round %in% weeks]$part_id) == 0) { @@ -81,8 +85,8 @@ calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals } if (fitwith == 'bs'){ - outs_weekend = get_matrix_bs(cont_per_age_per_part_weekend, breaks, max_, bs=samples) - outs_weekday = get_matrix_bs(cont_per_age_per_part_weekday, breaks, max_, bs=samples) + outs_weekend = get_matrix_bs(cont_per_age_per_part_weekend, breaks, max_, bs=samples, trunc_flag=trunc_flag, setting=setting) + outs_weekday = get_matrix_bs(cont_per_age_per_part_weekday, breaks, max_, bs=samples, trunc_flag=trunc_flag, setting=setting) } mus = (outs_weekend[[2]] * 2./7) + (outs_weekday[[2]] * 5./7) diff --git a/r/functions/fitting_functions.R b/r/functions/fitting_functions.R index dd14956..41fc1ee 100644 --- a/r/functions/fitting_functions.R +++ b/r/functions/fitting_functions.R @@ -16,14 +16,14 @@ library(data.table) # Input data ---------------------------------------------------------------- -# this function calculates log of the complement of the sum of a list of liklihoods. -# It is used to find the log liklihood of a tail of a right censored distribution +# this function calculates log of the complement of (i.e. one minus) the sum of a list of likelihoods. +# It is used to find the log likelihood of a tail of a right censored distribution complementary_logprob <- function(x) { tryCatch(log1p(-sum(exp(x))), error=function(e) -Inf) } -# This funtion calculates the likelihood of a negarive binomial disrtibution given set of partameters 'par' and data 'x'. +# This function calculates the log-likelihood of a negative binomial distribution given set of parameters 'par' and data 'x'. nb_loglik <- function(x, par, n) { k <- par[["k"]] mean <- par[["mu"]] @@ -33,6 +33,35 @@ nb_loglik <- function(x, par, n) { return(-sum(ll)) } +# This function calculates the log-likelihood of a Poisson distribution given mean 'par' and data 'x'. +poiss_loglik <- function(x, par, n){ + ll <- rep(NA_real_, length(x)) + ll[x < n] <- x[x < n] * log(par) - par - log(factorial(x[x < n])) #dpois(x[x < n], par, log = TRUE) + ll[x >= n] <- n * log(par) - par - log(factorial(n)) #dpois(n, par, log = TRUE) + return(-sum(ll)) +} + + +# This function calculates the log-likelihood of a truncated negative binomial distribution given set of parameters 'par', data 'x', and upper truncation limit 'n'. +trunc_nb_loglik <- function(x, par, n) { + k <- par[["k"]] + mean <- par[["mu"]] + ll <- rep(NA_real_, length(x)) + ll[x <= n] <- dnbinom(x[x <= n], mu = mean, size = 1/k, log = TRUE) - pnbinom(n, mu = mean, size = 1/k, log.p = TRUE) + ll[x > n] <- 0 + return(-sum(ll)) +} + +# This function calculates the log-likelihood of a truncated Poisson distribution given mean 'par', data 'x', and upper truncation limit 'n'. 
+trunc_poiss_loglik <- function(x, par, n){ + ll <- rep(NA_real_, length(x)) + ll[x <= n] <- x[x <= n] * log(par) - par - log(factorial(x[x <= n])) - ppois(n, par, log.p = TRUE) + ll[x > n] <- 0 + return(-sum(ll)) +} + + + # This function optimises negative binomial parameters mu and k for contacts reported by age group i in age group j nbinom_optim_ <- function(i, j, param, n, count_frame) { diff --git a/r/functions/functions.R b/r/functions/functions.R index 9ee6c10..05395bb 100644 --- a/r/functions/functions.R +++ b/r/functions/functions.R @@ -103,7 +103,7 @@ sample_age_table <- function(prts, cnts, agegroupbreaks, weights=NULL){ cnts_ages = cnts[!is.na(cnt_age_est_min)][,c('part_id')] cnts_ages[,'cnt_assigned_age_groups' :=sampled_ages_cnts] - } + } @@ -368,43 +368,170 @@ get_eigs_from_means = function(eg, mus, popdata_totals, breaks) { # Sample with replacement ------------------------------------------------- +nb_optim <- function(counts_, n, param) { + if(sum(counts_) == 0){ + out = 0 + }else{ + # out = tryCatch({ + # outs = optim(c(mu = 0.5, k = 1), lower = c(mean = 1e-5, k = 1e-5), nb_loglik, x = counts_, n = n, method = "L-BFGS-B") + # as.numeric(outs$par[param]) + # }, + # error = function(e){ + # message("Optim convergence failed due to too few counts, returning 0") + # message("Original error message:") + # message(e) + # return(0) + # }) + outs = optim(c(mu = 0.5, k = 1), lower = c(mean = 1e-5, k = 1e-5), nb_loglik, x = counts_, n = n, method = "L-BFGS-B") + if(outs$convergence==0){ # return value from optimisation if it converges + out = as.numeric(outs$par[param]) + } else { # print error message and final value from optimisation if it fails to converge, return 0 + # print(outs$message) + # print(as.numeric(outs$par[param])) + out = 0 + } + } + return(out) +} + +poiss_optim <- function(counts_, n) { + if(sum(counts_) == 0){ + out = 0 + }else{ + outs = optim(0.5, lower = 1e-5, poiss_loglik, x = counts_, n = n, method = "L-BFGS-B") + if(outs$convergence==0){ # return value from optimisation if it converges + out = as.numeric(outs$par) + } else { # print error message and final value from optimisation if it fails to converge, return 0 + # print(outs$message) + # print(as.numeric(outs$par)) + out = 0 + } + } + return(out) +} + +trunc_nb_optim <- function(counts_, n, param) { + # Remove values greater than truncation limit + counts_ = counts_[counts_<=n] + if(sum(counts_) == 0){ + out = 0 + }else{ + # out = tryCatch({ + # outs = optim(c(mu = 0.5, k = 1), lower = c(mean = 1e-5, k = 1e-5), trunc_nb_loglik, x = counts_, n = n, method = "L-BFGS-B") + # as.numeric(outs$par[param]) + # }, + # error = function(e){ + # print(counts_) + # message("Original error message:") + # message(e) + # return(0) + # }) + outs = optim(c(mu = 0.5, k = 1), lower = c(mean = 1e-5, k = 1e-5), trunc_nb_loglik, x = counts_, n = n, method = "L-BFGS-B") + if(outs$convergence==0){ # return value from optimisation if it converges + out = as.numeric(outs$par[param]) + } else { # print error message and final value from optimisation if it fails to converge, return 0 + # print(outs$message) + # print(as.numeric(outs$par[param])) + out = 0 + } + } + return(out) +} + +trunc_poiss_optim <- function(counts_, n) { + # Remove values greater than truncation limit + counts_ = counts_[counts_<=n] + if(sum(counts_) == 0){ + out = 0 + }else{ + outs = optim(0.5, lower = 1e-5, trunc_poiss_loglik, x = counts_, n = n, method = "L-BFGS-B") + if(outs$convergence==0){ # return value from optimisation if it converges + out = 
as.numeric(outs$par) + } else { # print error message and final value from optimisation if it fails to converge, return 0 + # print(outs$message) + # print(as.numeric(outs$par)) + out = 0 + } + } + return(out) +} -nbinom_optim_bs <- function(i, j, param, n, count_frame, bs = 1) { +nbinom_optim_bs <- function(i, j, param, n, count_frame, bs = 1, trunc_flag = F) { counts = count_frame[count_frame$prt_age_group == i & count_frame$cnt_age_group == j]$V1 - bs_optim <- function(counts_ ) { - if(sum(counts_) == 0){ - return(0) - }else{ - outs = optim(c(mu = 0.5, k = 1), lower = c(mean = 1e-5, k = 1e-5), nb_loglik, x = counts_, n = n, method = "L-BFGS-B") - return(as.numeric(outs$par[param])) + if(sum(counts) != 0){ + # counts_mat = matrix(counts,nrow = 1,ncol = length(counts)) + # if (bs > 1){ + # for(k in 1:(bs-1)){ + # counts_mat = rbind(counts_mat, sample(counts, replace = TRUE)) + # } + # } + # if (trunc_flag){ + # outs_mat = apply(counts_mat[,,drop=F], FUN = trunc_nb_optim, MARGIN = 1, n = n) + # } else { + # outs_mat = apply(counts_mat[,,drop=F], FUN = nb_optim, MARGIN = 1, n = n) + # } + outs_mat = numeric(bs) + # outs_mat = foreach(k=1:bs,.combine = 'c') %dopar% { + for (k in 1:bs){ + counts_smpl = sample(counts, replace = TRUE) + if (trunc_flag){ + outs_mat[k] = trunc_nb_optim(counts_smpl, n, param) + } else { + outs_mat[k] = nb_optim(counts_smpl, n, param) } - + } + } else{ + outs_mat = rep(0,bs) } + return(outs_mat) +} + +poiss_optim_bs <- function(i, j, param, n, count_frame, bs = 1, trunc_flag = F) { - if(sum(counts != 0)){ - counts_mat = c(counts) - for(k in 1:(bs-1)){ - counts_mat = rbind(counts_mat, sample(counts, replace = TRUE)) + counts = count_frame[count_frame$prt_age_group == i & count_frame$cnt_age_group == j]$V1 + + if(sum(counts) != 0){ + # counts_mat = matrix(counts,nrow = 1,ncol = length(counts)) + # if (bs > 1){ + # for(k in 1:(bs-1)){ + # counts_mat = rbind(counts_mat, sample(counts, replace = TRUE)) + # } + # } + # if (trunc_flag){ + # outs_mat = apply(counts_mat[,,drop=F], FUN = trunc_poiss_optim, MARGIN = 1, n = n) + # } else { + # outs_mat = apply(counts_mat[,,drop=F], FUN = poiss_optim, MARGIN = 1, n = n) + # } + outs_mat = numeric(bs) + # outs_mat = foreach(k=1:bs,.combine = 'c') %dopar% { + for (k in 1:bs){ + counts_smpl = sample(counts, replace = TRUE) + if (trunc_flag){ + outs_mat[k] = trunc_poiss_optim(counts_smpl, n) + } else { + outs_mat[k] = poiss_optim(counts_smpl, n) + } } - - outs_mat = apply(counts_mat, FUN = bs_optim, MARGIN = 1) - outs_mat } else{ - return(rep(0,bs)) + outs_mat = rep(0,bs) } + return(outs_mat) } -get_matrix_bs = function(cont_per_age_per_part, breaks, trunc, param = 'mu', bs = 1) { - +get_matrix_bs = function(cont_per_age_per_part, breaks, trunc, param = 'mu', bs = 1, trunc_flag = F, setting = "") { levs <- unique(unlist(cut(seq(0,120),breaks, right=FALSE), use.names = FALSE)) # Get columns of age-groups to put into mapply eg = expand.grid(sort(levs),sort(levs)) names(eg) = c('age_group', 'age_group_cont') - # Get means from neg_binom opitmsisation - means_mat <- mapply(nbinom_optim_bs, eg$age_group, eg$age_group_cont, param=param, n=trunc, bs = bs, count_frame=list(cont_per_age_per_part )) + # Get means from neg_binom regression (or Poisson regression if setting="home") + if (setting=="home"){ + means_mat <- mapply(poiss_optim_bs, eg$age_group, eg$age_group_cont, param=param, n=trunc, bs = bs, count_frame=list(cont_per_age_per_part), trunc_flag = trunc_flag) + } else { + means_mat <- mapply(nbinom_optim_bs, eg$age_group, 
eg$age_group_cont, param=param, n=trunc, bs = bs, count_frame=list(cont_per_age_per_part), trunc_flag = trunc_flag) + } # eg['size'] = mapply(nbinom_optim, eg$age_group, eg$age_group_cont, param='size') means_mat <- matrix(unlist(means_mat), ncol = (length(breaks)-1)^2) eg = data.table(eg) diff --git a/r/presentation/plot_comparison_cms.R b/r/presentation/plot_comparison_cms.R index 7647752..098b088 100644 --- a/r/presentation/plot_comparison_cms.R +++ b/r/presentation/plot_comparison_cms.R @@ -90,7 +90,7 @@ plot_cms_comparison = function(all_egs, pair_sr, region=''){ geom_tile()+ geom_text(color='white', size=5)+ facet_wrap( ~sr, ncol=1) + - scale_fill_viridis(discrete=FALSE, name='Mean \ncontacts', begin=0, end=1., limits = c(0,4.))+ + scale_fill_viridis(discrete=FALSE, name='Mean \ncontacts', begin=0, end=1., limits = c(0,7.5))+ ylab('Contact age group') + xlab('Participant age group') + theme(axis.line=element_blank(), @@ -154,7 +154,7 @@ plot_all_cms = function(all_egs, region='', periods, breaks = c(0,5,12,18,30,40, geom_tile()+ geom_text(color='white', size=5)+ facet_wrap( ~sr) + - scale_fill_viridis(discrete=FALSE, name='Mean \ncontacts', begin=0, end=1., limits = c(0.02,4.5), trans='log', breaks=c(0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2., 4.), na.value='black')+ + scale_fill_viridis(discrete=FALSE, name='Mean \ncontacts', begin=0, end=1., limits = c(0.02,7.5), trans='log', breaks=c(0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2., 4.), na.value='black')+ guides(fill=guide_colorbar(barheight=30))+ ylab('Contact age group') + xlab('Participant age group') + @@ -199,9 +199,9 @@ augment_cms = function(cmss, swapouts = 2, breaks = c(0,5,12,18,30,40,50,60,70, eg = data.table(eg) - boolian_switch = matrix(eg[,Var1 == levs[swapouts] | Var2 == levs[swapouts]], nrow=81, ncol=1000) + boolean_switch = matrix(eg[,Var1 == levs[swapouts] | Var2 == levs[swapouts]], nrow=81, ncol=1000) - cms_aug = cms2 * data.table(boolian_switch) + cms1 * (-(data.table(boolian_switch) -1)) + cms_aug = cms2 * data.table(boolean_switch) + cms1 * (-(data.table(boolean_switch) -1)) cms_aug } diff --git a/r/presentation/setting_specific_contact_matrix_plots.R b/r/presentation/setting_specific_contact_matrix_plots.R new file mode 100644 index 0000000..b873855 --- /dev/null +++ b/r/presentation/setting_specific_contact_matrix_plots.R @@ -0,0 +1,206 @@ +library(data.table) +library(cowplot) +library(ggplot2) +library(patchwork) + +source('r/presentation/plot_comparison_cms.R') +source('r/functions/get_r_estimates.R') +source('r/functions/get_minimal_data.R') + +country_names = c("uk") +settings = c("home","school","work","other") + +breaks = c(0,5,12,18,30,40,50,60,70,Inf) +week_range = c(1,11,19,24,34,37,39,42,51) +nwk = c(10,8,5,10,3,2,3,9,6) +samples_ = 1000 +fit_with_ = 'bs' +max_ = 50 +trunc_flag_ = T + +outfolder=paste0('outputs/setting_specific/', country_names[i], '/') +filename_primer = paste0(outfolder, 'contact_matrices/', fit_with_, samples_, '_ngrps', length(breaks) - 1, '_cap', max_) +fnms = character(length(week_range)) +for (k in 1:length(week_range)){ + fnms[k] = paste0(filename_primer, '_nwks', nwk[k],'_sr', week_range[k]) +} + +periods = c('1. Lockdown 1', + '2. Lockdown 1 easing', + '3. Relaxed restrictions', + '4. School reopening', + '5. Lockdown 2', + '6. Lockdown 2 easing', + '7. Christmas', + '8. Lockdown 3', + '9. 
Lockdown 3 + schools') + +for (i in 1:length(country_names)){ + for (j in 1:length(settings)){ + + fnms_setting <- paste0(fnms, '_', settings[j]) + if (trunc_flag_){ + print("hello") + fnms_setting <- paste0(fnms_setting,"_trunc") + } + fnms_setting <- paste0(fnms_setting,'_scms.qs') + + all_egs = get_all_egs(filenames = fnms_setting, periods = periods, breaks = breaks) + + # compared_egs = plot_full_comparison(all_egs, periods = periods[1:3], scale1 = 1.7, scale2 = 1.2, orient='lower') + # + # mat3 = compared_egs[[1]] + plot_spacer() + plot_spacer() + + # compared_egs[[4]] + compared_egs[[5]] + plot_spacer() + + # compared_egs[[7]] + compared_egs[[8]] + compared_egs[[9]] + # + # compared_egs = plot_full_comparison(all_egs, periods = periods[4:6], scale1 = 4.5, scale2 = 3., orient='lower') + # + # mat4 = compared_egs[[1]] + plot_spacer() + plot_spacer() + + # compared_egs[[4]] + compared_egs[[5]] + plot_spacer() + + # compared_egs[[7]] + compared_egs[[8]] + compared_egs[[9]] + # + # compared_egs = plot_full_comparison(all_egs, periods = periods[7:9], scale1 = 4.5, scale2 = 3., orient='lower') + # + # mat5 = compared_egs[[1]] + plot_spacer() + plot_spacer() + + # compared_egs[[4]] + compared_egs[[5]] + plot_spacer() + + # compared_egs[[7]] + compared_egs[[8]] + compared_egs[[9]] + # + # ggsave('compare_mat_LD1_plus.pdf', mat3, width=20, height=20) + # ggsave('compare_mat_LD2_plus.pdf', mat4, width=20, height=20) + # ggsave('compare_mat_LD3_plus.pdf', mat5, width=20, height=20) + + + all_mats = plot_all_cms(all_egs = all_egs, periods = periods, title='A') + + # dates = data.table::transpose( + # data.table( + # c('20200324', '20200603'), + # c('20200603', '20200729'), + # c('20200729', '20200904'), + # c('20200904', '20201024'), + # c('20201105', '20201202'), + # c('20201202', '20201219'), + # c('20201219', '20210102'), + # c('20210105', '20210308'), + # c('20210308', '20210330'))) + # colnames(dates) = c('start', 'end') + # + # dates[,start := lubridate::ymd(start)] + # dates[,end := lubridate::ymd(end)] + # dates[,periods := periods] + # dates[,period_num:=1:length(periods)] + # + # + # + # eigs = data.table() + # eigs_ld1 = rbind(eigs,get_r_estimates('England', 'bs', 1000, 50, 1:1, nwks=11)) + # eigs = rbind(eigs,get_r_estimates('England', 'bs', 1000, 50, 11:11, nwks=8)/eigs_ld1) + # eigs = rbind(eigs,get_r_estimates('England', 'bs', 1000, 50, 19:19, nwks=5)/eigs_ld1) + # eigs = rbind(eigs,get_r_estimates('England', 'bs', 1000, 50, 24:24, nwks=8)/eigs_ld1) + # eigs = rbind(eigs,get_r_estimates('England', 'bs', 1000, 50, 33:33, nwks=5)/eigs_ld1) + # eigs = rbind(eigs,get_r_estimates('England', 'bs', 1000, 50, 37:37, nwks=2)/eigs_ld1) + # eigs = rbind(eigs,get_r_estimates('England', 'bs', 1000, 50, 39:39, nwks=3)/eigs_ld1) + # eigs = rbind(eigs,get_r_estimates('England', 'bs', 1000, 50, 41:41, nwks=9)/eigs_ld1) + # eigs = rbind(eigs,get_r_estimates('England', 'bs', 1000, 50, 50:50, nwks=2)/eigs_ld1) + # + # + # suscvec_davies = c(0.4, + # 0.4, + # 0.4, + # 0.79, + # 0.86, + # 0.8, + # 0.82, + # 0.88, + # 0.74) + # + # tranvec_davies = c(0.645, + # 0.645, + # 0.605, + # 0.635, + # 0.665, + # 0.7, + # 0.745, + # 0.815, + # 0.845) + # + # eigs_cvd = data.table() + # eigs_cvd_ld1 = get_r_estimates('England', 'bs', 1000, 50, 1:1, nwks=11, suscvec = suscvec_davies, tranvec = tranvec_davies) + # eigs_cvd = rbind(eigs_cvd,get_r_estimates('England', 'bs', 1000, 50, 11:11, nwks=8, suscvec = suscvec_davies, tranvec = tranvec_davies)/eigs_cvd_ld1) + # eigs_cvd = 
rbind(eigs_cvd,get_r_estimates('England', 'bs', 1000, 50, 19:19, nwks=5, suscvec = suscvec_davies, tranvec = tranvec_davies)/eigs_cvd_ld1) + # eigs_cvd = rbind(eigs_cvd,get_r_estimates('England', 'bs', 1000, 50, 24:24, nwks=8, suscvec = suscvec_davies, tranvec = tranvec_davies)/eigs_cvd_ld1) + # eigs_cvd = rbind(eigs_cvd,get_r_estimates('England', 'bs', 1000, 50, 33:33, nwks=5, suscvec = suscvec_davies, tranvec = tranvec_davies)/eigs_cvd_ld1) + # eigs_cvd = rbind(eigs_cvd,get_r_estimates('England', 'bs', 1000, 50, 37:37, nwks=2, suscvec = suscvec_davies, tranvec = tranvec_davies)/eigs_cvd_ld1) + # eigs_cvd = rbind(eigs_cvd,get_r_estimates('England', 'bs', 1000, 50, 39:39, nwks=3, suscvec = suscvec_davies, tranvec = tranvec_davies)/eigs_cvd_ld1) + # eigs_cvd = rbind(eigs_cvd,get_r_estimates('England', 'bs', 1000, 50, 41:41, nwks=9, suscvec = suscvec_davies, tranvec = tranvec_davies)/eigs_cvd_ld1) + # eigs_cvd = rbind(eigs_cvd,get_r_estimates('England', 'bs', 1000, 50, 50:50, nwks=2, suscvec = suscvec_davies, tranvec = tranvec_davies)/eigs_cvd_ld1) + # + # dates[,eigs_05s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs[X,]), 0.05)}))] + # dates[,eigs_75s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs[X,]), 0.75)}))] + # dates[,eigs_95s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs[X,]), 0.95)}))] + # dates[,eigs_25s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs[X,]), 0.25)}))] + # dates[,eigs_50s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs[X,]), 0.50)}))] + # dates[,eigs_cvd_05s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs_cvd[X,]), 0.05)}))] + # dates[,eigs_cvd_75s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs_cvd[X,]), 0.75)}))] + # dates[,eigs_cvd_95s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs_cvd[X,]), 0.95)}))] + # dates[,eigs_cvd_25s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs_cvd[X,]), 0.25)}))] + # dates[,eigs_cvd_50s := c(1,sapply(1:(8), function(X){quantile(ecdf(eigs_cvd[X,]), 0.50)}))] + # + # dates[, stringency:=c("Lockdown","Easing","Relaxed","Relaxed + schools","Lockdown + schools","Easing","Relaxed","Lockdown","Lockdown + schools")] + # + # color_list <- c("Equal" = "orange", "COVID-like" = "pink") + # + # dates_plot = ggplot(dates) + + # geom_rect(aes(xmin=start, xmax=end, ymin=0.4, ymax=0.9, fill=stringency), alpha=1.)+ + # geom_pointrange(aes(x= start + 0.5 * (end - start), y=eigs_50s, ymin=eigs_05s, ymax=eigs_95s, color='Equal'))+ + # geom_pointrange(aes(x= start + 0.5 * (end - start), y=eigs_cvd_50s, ymin=eigs_cvd_05s, ymax=eigs_cvd_95s, color='COVID-like'))+ + # scale_fill_brewer(palette='Pastel1')+ + # geom_text(aes(x=start + 0.5 * (end - start), y=0.65, label=period_num), color='white', size=9)+ + # scale_x_date(breaks='month', date_labels = "%b '%y",expand = expansion(0), name='')+ + # scale_y_continuous(name='Relative change \nin eigenvalue', limits=c(0.4,3.5), breaks=seq(1.,3.4,0.5))+ + # ggtitle('B')+ + # scale_color_manual(name='Transmissibility', values=color_list)+ + # theme( + # axis.line.x=element_blank(), + # panel.grid.major.y = element_line(colour='grey') + # ) + # layout = ' + # A + # A + # A + # A + # A + # A + # B' + layout = ' + A + A + A + A + A + A' + + # all_mats_dates = all_mats + dates_plot + + # plot_layout(design = layout) + # dates[,'R_inc_Equal' := paste0(round(eigs_50s,2), ' (', round(eigs_05s,2), ' - ', round(eigs_95s,2), ')')] + # dates[,'R_inc_COVID' := paste0(round(eigs_cvd_50s,2), ' (', round(eigs_cvd_05s,2), ' - ', round(eigs_cvd_95s,2), ')')] + # + # dates[,Date := 
paste0(format(start, "%d %b %Y"), ' - ', format(end, "%d %b %Y"))] + # + # dates_pres = dates[,c('Date', 'periods', 'R_inc_Equal', 'R_inc_COVID')] + # + # names(dates_pres) = c('Dates', 'Periods', 'Eigenvalue') + # dates_pres + # + # write.csv(dates_pres, 'periods_eigenvalues.csv') + # + # ggsave('all_mats.pdf', all_mats_dates, width=20, height=20) + if (trunc_flag_){ + ggsave(paste0('all_mats_',settings[j],'_trunc.pdf'), all_mats, width=20, height=20) + } else { + ggsave(paste0('all_mats_',settings[j],'.pdf'), all_mats, width=20, height=20) + } + + } +} + From 3e7f6f1ed7fd0cb72302752f7dd5df0bacce65cc Mon Sep 17 00:00:00 2001 From: Lloyd Chapman Date: Tue, 6 Jul 2021 22:23:11 +0100 Subject: [PATCH 2/3] Add code for fitting zero-inflated negative binomial to school, work and other contacts --- ...culate_setting_specific_contact_matrices.R | 21 +++--- r/functions/calc_cm.R | 20 ++++-- r/functions/fitting_functions.R | 12 ++++ r/functions/functions.R | 65 +++++++++++++++---- r/presentation/plot_comparison_cms.R | 4 +- .../setting_specific_contact_matrix_plots.R | 34 ++++++---- 6 files changed, 110 insertions(+), 46 deletions(-) diff --git a/r/analyses/run_calculate_setting_specific_contact_matrices.R b/r/analyses/run_calculate_setting_specific_contact_matrices.R index a4f8675..a1bed1d 100644 --- a/r/analyses/run_calculate_setting_specific_contact_matrices.R +++ b/r/analyses/run_calculate_setting_specific_contact_matrices.R @@ -53,7 +53,8 @@ week_range = c(1,11,19,24,34,37,39,42,51) #c(min(weeks_in_parts):max(weeks_in_pa nwk = c(10,8,5,10,3,2,3,9,6) samples_ = 1000 fit_with_ = 'bs' -trunc_flag_ = T # flag for whether or not to use truncation rather than uncorrected censoring +trunc_flag_ = F # flag for whether or not to use truncation rather than uncorrected censoring +zi_ = T # flag for fitting zero-inflated negative binomial vs negative binomial # Filter data ------------------------------------------------------------- unique_wave_pid <- unique(parts$part_wave_uid) @@ -75,22 +76,22 @@ weights = get_contact_age_weights() for (i in 1:length(country_names)){ #for (i in 1:length(regions)){ # for (i in 1:length(nation_names)){ - for (j in 2:length(settings)){ + for (j in 1:length(settings)){ # for (k in 1:length(week_range)){ lcms = foreach(k=1:length(week_range)) %dopar% { print(k) - contacts_setting <- contacts[country %in% countries[[i]] & eval(parse(text=paste0("cnt_",settings[j])))] - unique_wave_pid <- unique(contacts_setting$part_wave_uid) - - parts_setting <- parts[part_wave_uid %in% unique_wave_pid] + # contacts_nation <- contacts[country %in% countries[[i]] & eval(parse(text=paste0("cnt_",settings[j])))] + # unique_wave_pid <- unique(contacts_nation$part_wave_uid) + # + # parts_nation <- parts[part_wave_uid %in% unique_wave_pid] # parts_nation <- parts[area_3_name %in% regions[[i]]] # parts_nation <- parts[area_3_name %in% nations[[i]]] - # parts_nation <- parts[country %in% countries[[i]]] + parts_nation <- parts[country %in% countries[[i]]] - # unique_wave_pid <- unique(parts_nation$part_wave_uid) - # contacts_nation <- contacts[part_wave_uid %in% unique_wave_pid] + unique_wave_pid <- unique(parts_nation$part_wave_uid) + contacts_nation <- contacts[part_wave_uid %in% unique_wave_pid] # print(nations[[i]]) print(countries[[i]]) @@ -104,7 +105,7 @@ for (i in 1:length(country_names)){ dir.create(outfolder, recursive = TRUE) } - cms_max50 = calc_cm_general(parts_setting, contacts_setting, breaks, max_ = max_, popdata_totals, weeks_range = week_range[k], nwks=nwk[k], outfolder=outfolder, 
fitwith=fit_with_, samples=samples_, weights=NULL, trunc_flag=trunc_flag_, setting=settings[j]) + cms_max50 = calc_cm_general(parts_nation, contacts_nation, breaks, max_ = max_, popdata_totals, weeks_range = week_range[k], nwks=nwk[k], outfolder=outfolder, fitwith=fit_with_, samples=samples_, weights=NULL, trunc_flag=trunc_flag_, zi=zi_, setting=settings[j]) } diff --git a/r/functions/calc_cm.R b/r/functions/calc_cm.R index f4e774b..4817959 100644 --- a/r/functions/calc_cm.R +++ b/r/functions/calc_cm.R @@ -1,7 +1,7 @@ ## calculate the contact matrices -calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals, weeks_range=23:33, nwks=2, samples=10, fitwith='bs', outfolder='outputs/regular/', model_path='stan/trunc_negbinom_matrix_bunchtrunc.stan', prior_pars_mu=NULL, prior_pars_k=NULL, weights = NULL, trunc_flag = F, setting = ""){ +calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals, weeks_range=23:33, nwks=2, samples=10, fitwith='bs', outfolder='outputs/regular/', model_path='stan/trunc_negbinom_matrix_bunchtrunc.stan', prior_pars_mu=NULL, prior_pars_k=NULL, weights = NULL, trunc_flag = F, zi = F, setting = ""){ print(nwks) if(!dir.exists(paste0(outfolder, 'contact_matrices/'))){ dir.create(paste0(outfolder, 'contact_matrices/'), recursive = TRUE) @@ -47,23 +47,26 @@ calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals weeks = week } filename_primer = paste0(outfolder, 'contact_matrices/', fitwith, samples, '_ngrps', length(breaks) - 1, '_cap', max_, '_nwks', length(weeks),'_sr', week, '_', setting, '_') + if (zi){ + filename_primer = paste0(filename_primer,"zi_") + } if (trunc_flag){ filename_primer = paste0(filename_primer,"trunc_") } if(week %in% c(1:6, 17,18)) weeks <- c(weeks, 700) - + # Replace weekend contacts by week contacts if there were no surveys done at the weekend in that survey round if(length(conts_weekend[survey_round %in% weeks]$part_id) == 0) { conts_weekend = conts_weekday parts_weekend = parts_weekday } - ct_ac_weekend = get_age_table(conts_weekend, parts_weekend, weeks, breaks, weights = weights) + ct_ac_weekend = get_age_table(conts_weekend, parts_weekend, weeks, breaks, weights = weights, setting = setting) cont_per_age_per_part_weekend = ct_ac_weekend[[1]] all_conts_weekend = ct_ac_weekend[[2]] - ct_ac_weekday = get_age_table(conts_weekday, parts_weekday, weeks, breaks, weights = weights) + ct_ac_weekday = get_age_table(conts_weekday, parts_weekday, weeks, breaks, weights = weights, setting = setting) cont_per_age_per_part_weekday = ct_ac_weekday[[1]] all_conts_weekday = ct_ac_weekday[[2]] @@ -85,8 +88,13 @@ calc_cm_general <- function(parts_ , conts_, breaks, max_ = 1000, popdata_totals } if (fitwith == 'bs'){ - outs_weekend = get_matrix_bs(cont_per_age_per_part_weekend, breaks, max_, bs=samples, trunc_flag=trunc_flag, setting=setting) - outs_weekday = get_matrix_bs(cont_per_age_per_part_weekday, breaks, max_, bs=samples, trunc_flag=trunc_flag, setting=setting) + if (zi){ + param = c("p","mu") + } else { + param = "mu" + } + outs_weekend = get_matrix_bs(cont_per_age_per_part_weekend, breaks, max_, param=param, bs=samples, trunc_flag=trunc_flag, zi=zi, setting=setting) + outs_weekday = get_matrix_bs(cont_per_age_per_part_weekday, breaks, max_, param=param, bs=samples, trunc_flag=trunc_flag, zi=zi, setting=setting) } mus = (outs_weekend[[2]] * 2./7) + (outs_weekday[[2]] * 5./7) diff --git a/r/functions/fitting_functions.R b/r/functions/fitting_functions.R index 41fc1ee..bcf05c0 100644 
--- a/r/functions/fitting_functions.R +++ b/r/functions/fitting_functions.R @@ -41,6 +41,18 @@ poiss_loglik <- function(x, par, n){ return(-sum(ll)) } +# This function calculates the log-likelihood of a zero-inflated negative binomial distribution given set of parameters 'par' and data 'x'. +zinb_loglik <- function(x, par, n){ + p <- par[["p"]] + k <- par[["k"]] + mu <- par[["mu"]] + + ll <- rep(NA_real_, length(x)) + ll[x == 0] <- log(p + (1-p)*dnbinom(0, mu = mu, size = 1/k)) + ll[x > 0 & x < n] <- log(1-p) + dnbinom(x[x > 0 & x < n], mu = mu, size = 1/k, log = T) + ll[x >= n] <- log(1-p) + dnbinom(n, mu = mu, size = 1/k, log = T) + return(-sum(ll)) +} # This function calculates the log-likelihood of a truncated negative binomial distribution given set of parameters 'par', data 'x', and upper truncation limit 'n'. trunc_nb_loglik <- function(x, par, n) { diff --git a/r/functions/functions.R b/r/functions/functions.R index 05395bb..89f58de 100644 --- a/r/functions/functions.R +++ b/r/functions/functions.R @@ -81,7 +81,7 @@ get_contact_age_weights = function(weightforage=42, weightoffage=1){ } -sample_age_table <- function(prts, cnts, agegroupbreaks, weights=NULL){ +sample_age_table <- function(prts, cnts, agegroupbreaks, weights=NULL, setting=""){ #create column of ages of participants sampled_ages_prts = mapply(FUN = function(X, Y){sample(rep(seq(X,Y),2), 1)}, prts[!is.na(part_age_est_min)]$part_age_est_min, prts[!is.na(part_age_est_min)]$part_age_est_max) @@ -94,13 +94,23 @@ sample_age_table <- function(prts, cnts, agegroupbreaks, weights=NULL){ if(is.null(weights)){ #create column of ages of contacts sampled_ages_cnts = mapply(FUN = function(X, Y){sample(rep(seq(X,Y),2), 1)}, cnts[!is.na(cnt_age_est_min)]$cnt_age_est_min, cnts[!is.na(cnt_age_est_min)]$cnt_age_est_max) - cnts_ages = cnts[!is.na(cnt_age_est_min)][,c('part_id')] + if (setting==""){ + cnts_ages = cnts[!is.na(cnt_age_est_min)][,c('part_id')] + } else { + cols = c("part_id",paste0("cnt_",setting)) + cnts_ages = cnts[!is.na(cnt_age_est_min)][,..cols] + } cnts_ages[,'cnt_assigned_age_groups' :=sampled_ages_cnts] } else { probs = sapply(cnts$prt_assigned_age_groups, FUN = function(X){weights[,X]}) #create column of ages of contacts sampled_ages_cnts = mapply(FUN = function(X, Y, Z){sample(rep(seq(X,Y),2), 1, prob = nafill(rep(probs[[Z]][(X+1):(Y+1)],2),fill=1))}, cnts[!is.na(cnt_age_est_min)]$cnt_age_est_min, cnts[!is.na(cnt_age_est_min)]$cnt_age_est_max, 1:length(cnts[!is.na(cnt_age_est_min)]$cnt_age_est_max)) - cnts_ages = cnts[!is.na(cnt_age_est_min)][,c('part_id')] + if (setting==""){ + cnts_ages = cnts[!is.na(cnt_age_est_min)][,c('part_id')] + } else { + cols = c("part_id",paste0("cnt_",setting)) + cnts_ages = cnts[!is.na(cnt_age_est_min)][,..cols] + } cnts_ages[,'cnt_assigned_age_groups' :=sampled_ages_cnts] } @@ -205,7 +215,7 @@ symetricise_matrix = function(eg, popdata_totals, breaks) { } -get_age_table = function(conts, parts, week_choice=NULL, breaks, weights=NULL){ +get_age_table = function(conts, parts, week_choice=NULL, breaks, weights=NULL, setting=""){ if (!is.null(week_choice)){ @@ -217,7 +227,7 @@ get_age_table = function(conts, parts, week_choice=NULL, breaks, weights=NULL){ # Get numeric age ranges for every participant and contact possible #parts_conts <- assign_min_max_ages(parts, conts) # sample ages from ranges and construct a table of part_id, participant age and contact age - age_table <- sample_age_table(parts, conts, breaks, weights=weights) + age_table <- sample_age_table(parts, conts, breaks, 
weights=weights, setting=setting) #unique_parts = unique(age_table[,c("part_id", "prt_age_group")]) #sample_props = unique(unique_parts[,count:=.N, by=prt_age_group][,c("prt_age_group", "count")])[order(prt_age_group)] #sample_props[,props:=count/sum(count)] @@ -225,7 +235,11 @@ get_age_table = function(conts, parts, week_choice=NULL, breaks, weights=NULL){ age_table = age_table[complete.cases(age_table[,'prt_age_group']),] - age_table[,contacts := .N, by = c("part_id")] + if (setting==""){ + age_table[,contacts := .N, by = c("part_id")] + } else { + age_table[,contacts := sum(eval(parse(text = paste0("cnt_",setting)))), by = c("part_id")] + } all_conts = unique(age_table[,c('part_id', 'prt_age_group', 'contacts' )]) age_group_table = age_table[,c('prt_age_group', 'cnt_age_group')] @@ -235,7 +249,11 @@ get_age_table = function(conts, parts, week_choice=NULL, breaks, weights=NULL){ count_mat_total = table(lapply(age_group_table, factor, sort(levs))) #Get counts per age group per participant - cont_per_age_per_part = age_table[, (contacts = .N), by = .(part_id, prt_age_group, cnt_age_group)] + if (setting==""){ + cont_per_age_per_part = age_table[, (contacts = .N), by = .(part_id, prt_age_group, cnt_age_group)] + } else { + cont_per_age_per_part = age_table[, (contacts = sum(eval(parse(text = paste0("cnt_",setting))))), by = .(part_id, prt_age_group, cnt_age_group)] + } # this bit is required to add 0s into the table i.e. where no contacts were reported with some age groups by some contacts. # make 'complete' data.table with all combinations of part_id and cnt_age_group @@ -410,6 +428,19 @@ poiss_optim <- function(counts_, n) { return(out) } +zinb_optim <- function(counts_, n, param) { + outs = optim(c(p = 0.8, mu = 0.5, k = 1), lower = c(p = 1e-8, mu = 1e-8, k = 1e-8), upper = c(p = 1 - 1e-8, mu = Inf, k = Inf), zinb_loglik, x = counts_, n = n, method = "L-BFGS-B") + if(outs$convergence==0){ # return value from optimisation if it converges + out = outs$par[param] + } else { # print error message and final value from optimisation if it fails to converge, return 0 + # print(outs$message) + # print(outs$par[param]) + out = rep(0,length(param)) + names(out) <- param + } + return(out) +} + trunc_nb_optim <- function(counts_, n, param) { # Remove values greater than truncation limit counts_ = counts_[counts_<=n] @@ -456,7 +487,7 @@ trunc_poiss_optim <- function(counts_, n) { return(out) } -nbinom_optim_bs <- function(i, j, param, n, count_frame, bs = 1, trunc_flag = F) { +nbinom_optim_bs <- function(i, j, param, n, count_frame, bs = 1, trunc_flag = F, zi = F) { counts = count_frame[count_frame$prt_age_group == i & count_frame$cnt_age_group == j]$V1 @@ -476,15 +507,21 @@ nbinom_optim_bs <- function(i, j, param, n, count_frame, bs = 1, trunc_flag = F) # outs_mat = foreach(k=1:bs,.combine = 'c') %dopar% { for (k in 1:bs){ counts_smpl = sample(counts, replace = TRUE) - if (trunc_flag){ - outs_mat[k] = trunc_nb_optim(counts_smpl, n, param) + if (zi){ + out = zinb_optim(counts_smpl, n, param) + outs_mat[k] = (1 - out["p"]) * out["mu"] } else { - outs_mat[k] = nb_optim(counts_smpl, n, param) + if (trunc_flag){ + outs_mat[k] = trunc_nb_optim(counts_smpl, n, param) + } else { + outs_mat[k] = nb_optim(counts_smpl, n, param) + } } } } else{ outs_mat = rep(0,bs) } + return(outs_mat) } @@ -520,7 +557,7 @@ poiss_optim_bs <- function(i, j, param, n, count_frame, bs = 1, trunc_flag = F) return(outs_mat) } -get_matrix_bs = function(cont_per_age_per_part, breaks, trunc, param = 'mu', bs = 1, trunc_flag = F, 
setting = "") { +get_matrix_bs = function(cont_per_age_per_part, breaks, trunc, param = 'mu', bs = 1, trunc_flag = F, zi = F, setting = "") { levs <- unique(unlist(cut(seq(0,120),breaks, right=FALSE), use.names = FALSE)) # Get columns of age-groups to put into mapply @@ -528,9 +565,9 @@ get_matrix_bs = function(cont_per_age_per_part, breaks, trunc, param = 'mu', bs names(eg) = c('age_group', 'age_group_cont') # Get means from neg_binom regression (or Poisson regression if setting="home") if (setting=="home"){ - means_mat <- mapply(poiss_optim_bs, eg$age_group, eg$age_group_cont, param=param, n=trunc, bs = bs, count_frame=list(cont_per_age_per_part), trunc_flag = trunc_flag) + means_mat <- mapply(poiss_optim_bs, eg$age_group, eg$age_group_cont, MoreArgs = list(param=param, n=trunc, bs = bs, count_frame=cont_per_age_per_part, trunc_flag = trunc_flag)) } else { - means_mat <- mapply(nbinom_optim_bs, eg$age_group, eg$age_group_cont, param=param, n=trunc, bs = bs, count_frame=list(cont_per_age_per_part), trunc_flag = trunc_flag) + means_mat <- mapply(nbinom_optim_bs, eg$age_group, eg$age_group_cont, MoreArgs = list(param=param, n=trunc, bs = bs, count_frame=cont_per_age_per_part, trunc_flag = trunc_flag, zi = zi)) } # eg['size'] = mapply(nbinom_optim, eg$age_group, eg$age_group_cont, param='size') means_mat <- matrix(unlist(means_mat), ncol = (length(breaks)-1)^2) diff --git a/r/presentation/plot_comparison_cms.R b/r/presentation/plot_comparison_cms.R index 098b088..fb93451 100644 --- a/r/presentation/plot_comparison_cms.R +++ b/r/presentation/plot_comparison_cms.R @@ -90,7 +90,7 @@ plot_cms_comparison = function(all_egs, pair_sr, region=''){ geom_tile()+ geom_text(color='white', size=5)+ facet_wrap( ~sr, ncol=1) + - scale_fill_viridis(discrete=FALSE, name='Mean \ncontacts', begin=0, end=1., limits = c(0,7.5))+ + scale_fill_viridis(discrete=FALSE, name='Mean \ncontacts', begin=0, end=1., limits = c(0,4.))+ ylab('Contact age group') + xlab('Participant age group') + theme(axis.line=element_blank(), @@ -154,7 +154,7 @@ plot_all_cms = function(all_egs, region='', periods, breaks = c(0,5,12,18,30,40, geom_tile()+ geom_text(color='white', size=5)+ facet_wrap( ~sr) + - scale_fill_viridis(discrete=FALSE, name='Mean \ncontacts', begin=0, end=1., limits = c(0.02,7.5), trans='log', breaks=c(0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2., 4.), na.value='black')+ + scale_fill_viridis(discrete=FALSE, name='Mean \ncontacts', begin=0, end=1., limits = c(0.02,4.5), trans='log', breaks=c(0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2., 4.), na.value='black')+ guides(fill=guide_colorbar(barheight=30))+ ylab('Contact age group') + xlab('Participant age group') + diff --git a/r/presentation/setting_specific_contact_matrix_plots.R b/r/presentation/setting_specific_contact_matrix_plots.R index b873855..867476d 100644 --- a/r/presentation/setting_specific_contact_matrix_plots.R +++ b/r/presentation/setting_specific_contact_matrix_plots.R @@ -16,14 +16,8 @@ nwk = c(10,8,5,10,3,2,3,9,6) samples_ = 1000 fit_with_ = 'bs' max_ = 50 -trunc_flag_ = T - -outfolder=paste0('outputs/setting_specific/', country_names[i], '/') -filename_primer = paste0(outfolder, 'contact_matrices/', fit_with_, samples_, '_ngrps', length(breaks) - 1, '_cap', max_) -fnms = character(length(week_range)) -for (k in 1:length(week_range)){ - fnms[k] = paste0(filename_primer, '_nwks', nwk[k],'_sr', week_range[k]) -} +trunc_flag_ = F +zi_ = T periods = c('1. Lockdown 1', '2. Lockdown 1 easing', @@ -36,11 +30,21 @@ periods = c('1. Lockdown 1', '9. 
Lockdown 3 + schools') for (i in 1:length(country_names)){ - for (j in 1:length(settings)){ + outfolder=paste0('outputs/setting_specific/', country_names[i], '/') + filename_primer = paste0(outfolder, 'contact_matrices/', fit_with_, samples_, '_ngrps', length(breaks) - 1, '_cap', max_) + fnms = character(length(week_range)) + for (k in 1:length(week_range)){ + fnms[k] = paste0(filename_primer, '_nwks', nwk[k],'_sr', week_range[k]) + } + + for (j in 1:length(settings)){ + fnms_setting <- paste0(fnms, '_', settings[j]) + if (zi_){ + fnms_setting <- paste0(fnms_setting,"_zi") + } if (trunc_flag_){ - print("hello") fnms_setting <- paste0(fnms_setting,"_trunc") } fnms_setting <- paste0(fnms_setting,'_scms.qs') @@ -195,12 +199,14 @@ for (i in 1:length(country_names)){ # write.csv(dates_pres, 'periods_eigenvalues.csv') # # ggsave('all_mats.pdf', all_mats_dates, width=20, height=20) + figname_primer <- paste0("all_mats_",settings[j]) + if (zi_){ + figname_primer <- paste0(figname_primer,"_zi") + } if (trunc_flag_){ - ggsave(paste0('all_mats_',settings[j],'_trunc.pdf'), all_mats, width=20, height=20) - } else { - ggsave(paste0('all_mats_',settings[j],'.pdf'), all_mats, width=20, height=20) + figname_primer <- paste0(figname_primer,"_trunc") } - + ggsave(paste0(figname_primer,'.pdf'), all_mats, width=20, height=20) } } From 7f0cb53739e30fedfedb8fd7f15e7d25b16d2164 Mon Sep 17 00:00:00 2001 From: Lloyd Chapman Date: Fri, 9 Jul 2021 15:04:57 +0100 Subject: [PATCH 3/3] Calculate setting-specific matrices for contact-adjusted immunity analysis --- ...culate_setting_specific_contact_matrices.R | 24 +++++++------ .../setting_specific_contact_matrix_plots.R | 36 +++++++++++-------- 2 files changed, 34 insertions(+), 26 deletions(-) diff --git a/r/analyses/run_calculate_setting_specific_contact_matrices.R b/r/analyses/run_calculate_setting_specific_contact_matrices.R index a1bed1d..68fbc6c 100644 --- a/r/analyses/run_calculate_setting_specific_contact_matrices.R +++ b/r/analyses/run_calculate_setting_specific_contact_matrices.R @@ -36,7 +36,7 @@ contacts = data[[1]] parts = data[[2]] start_date = lubridate::ymd('20200323') -end_date = lubridate::ymd('20210421') +end_date = lubridate::ymd('20210623') parts = parts[between(date, start_date, end_date)] contacts = contacts[between(date, start_date, end_date)] @@ -48,9 +48,9 @@ breaks = c(0,5,12,18,30,40,50,60,70,Inf) max_ = 50 # upper limit for censoring/truncation popdata_totals = get_popvec(breaks, year_ = 2020) weeks_in_parts = sort(unique(parts$survey_round)) -week_range = c(1,11,19,24,34,37,39,42,51) #c(min(weeks_in_parts):max(weeks_in_parts)) +week_range = c(53,54,57:63) #c(1,11,19,24,34,37,39,42,51) #c(min(weeks_in_parts):max(weeks_in_parts)) #week_range = 34:51 -nwk = c(10,8,5,10,3,2,3,9,6) +nwk = rep(2,length(week_range)) #c(10,8,5,10,3,2,3,9,6) samples_ = 1000 fit_with_ = 'bs' trunc_flag_ = F # flag for whether or not to use truncation rather than uncorrected censoring @@ -73,9 +73,9 @@ settings = c("home","school","work","other") weights = get_contact_age_weights() -for (i in 1:length(country_names)){ +# for (i in 1:length(country_names)){ #for (i in 1:length(regions)){ -# for (i in 1:length(nation_names)){ +for (i in 1:length(nation_names)){ for (j in 1:length(settings)){ # for (k in 1:length(week_range)){ lcms = foreach(k=1:length(week_range)) %dopar% { @@ -87,25 +87,27 @@ for (i in 1:length(country_names)){ # parts_nation <- parts[part_wave_uid %in% unique_wave_pid] # parts_nation <- parts[area_3_name %in% regions[[i]]] - # parts_nation <- 
parts[area_3_name %in% nations[[i]]] - parts_nation <- parts[country %in% countries[[i]]] + parts_nation <- parts[area_3_name %in% nations[[i]]] + # parts_nation <- parts[country %in% countries[[i]]] unique_wave_pid <- unique(parts_nation$part_wave_uid) contacts_nation <- contacts[part_wave_uid %in% unique_wave_pid] - # print(nations[[i]]) - print(countries[[i]]) + print(nations[[i]]) + # print(countries[[i]]) # calculate contact matrices------------------------------------------------------------- - outfolder=paste0('outputs/setting_specific/', country_names[i], '/') + outfolder=paste0('outputs/setting_specific/', nation_names[i], '/') + # outfolder=paste0('outputs/setting_specific/', country_names[i], '/') if(!dir.exists(outfolder)){ dir.create(outfolder, recursive = TRUE) } - cms_max50 = calc_cm_general(parts_nation, contacts_nation, breaks, max_ = max_, popdata_totals, weeks_range = week_range[k], nwks=nwk[k], outfolder=outfolder, fitwith=fit_with_, samples=samples_, weights=NULL, trunc_flag=trunc_flag_, zi=zi_, setting=settings[j]) + # cms_max50 = calc_cm_general(parts_nation, contacts_nation, breaks, max_ = max_, popdata_totals, weeks_range = week_range[k], nwks=nwk[k], outfolder=outfolder, fitwith=fit_with_, samples=samples_, weights=NULL, trunc_flag=trunc_flag_, zi=zi_, setting=settings[j]) + calc_cm_general(parts_nation, contacts_nation, breaks, max_ = max_, popdata_totals, weeks_range = week_range[k], nwks=nwk[k], outfolder=outfolder, fitwith=fit_with_, samples=samples_, weights=NULL, trunc_flag=trunc_flag_, zi=zi_, setting=settings[j]) } diff --git a/r/presentation/setting_specific_contact_matrix_plots.R b/r/presentation/setting_specific_contact_matrix_plots.R index 867476d..d5f2bfc 100644 --- a/r/presentation/setting_specific_contact_matrix_plots.R +++ b/r/presentation/setting_specific_contact_matrix_plots.R @@ -11,27 +11,32 @@ country_names = c("uk") settings = c("home","school","work","other") breaks = c(0,5,12,18,30,40,50,60,70,Inf) -week_range = c(1,11,19,24,34,37,39,42,51) -nwk = c(10,8,5,10,3,2,3,9,6) +week_range = c(53,54,57:63) #c(1,11,19,24,34,37,39,42,51) +nwk = rep(2,length(week_range)) #c(10,8,5,10,3,2,3,9,6) samples_ = 1000 fit_with_ = 'bs' max_ = 50 trunc_flag_ = F zi_ = T -periods = c('1. Lockdown 1', - '2. Lockdown 1 easing', - '3. Relaxed restrictions', - '4. School reopening', - '5. Lockdown 2', - '6. Lockdown 2 easing', - '7. Christmas', - '8. Lockdown 3', - '9. Lockdown 3 + schools') +# periods = c('1. Lockdown 1', +# '2. Lockdown 1 easing', +# '3. Relaxed restrictions', +# '4. School reopening', +# '5. Lockdown 2', +# '6. Lockdown 2 easing', +# '7. Christmas', +# '8. Lockdown 3', +# '9. 
Lockdown 3 + schools') +periods = as.character(week_range) -for (i in 1:length(country_names)){ +nation_names = c("England", "Scotland", "Wales")[1] + +for (i in 1:length(nation_names)){ +# for (i in 1:length(country_names)){ - outfolder=paste0('outputs/setting_specific/', country_names[i], '/') + outfolder=paste0('outputs/setting_specific/', nation_names[i], '/') + # outfolder=paste0('outputs/setting_specific/', country_names[i], '/') filename_primer = paste0(outfolder, 'contact_matrices/', fit_with_, samples_, '_ngrps', length(breaks) - 1, '_cap', max_) fnms = character(length(week_range)) for (k in 1:length(week_range)){ @@ -74,7 +79,7 @@ for (i in 1:length(country_names)){ # ggsave('compare_mat_LD3_plus.pdf', mat5, width=20, height=20) - all_mats = plot_all_cms(all_egs = all_egs, periods = periods, title='A') + all_mats = plot_all_cms(all_egs = all_egs, periods = periods) #, title='A') # dates = data.table::transpose( # data.table( @@ -199,7 +204,8 @@ for (i in 1:length(country_names)){ # write.csv(dates_pres, 'periods_eigenvalues.csv') # # ggsave('all_mats.pdf', all_mats_dates, width=20, height=20) - figname_primer <- paste0("all_mats_",settings[j]) + figname_primer <- paste0(outfolder,"contact_matrices/all_mats_",settings[j]) + # figname_primer <- paste0("all_mats_",settings[j]) if (zi_){ figname_primer <- paste0(figname_primer,"_zi") }
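
As a sanity check on the zero-inflated negative binomial fit introduced in the second patch, the standalone R sketch below (not part of the patch itself) restates the same likelihood as zinb_loglik(), simulates zero-inflated counts, refits them with the same L-BFGS-B call used in zinb_optim(), and recovers the marginal mean (1 - p) * mu that nbinom_optim_bs() stores when zi = TRUE. The function name zinb_nll and the simulated values (p_true, mu_true, k_true, the sample size, and the censoring limit n_cens) are illustrative choices only, not values taken from the patch.

# Minimal, self-contained sketch: simulate zero-inflated negative binomial counts
# and refit them with the same L-BFGS-B set-up that zinb_optim() in
# r/functions/functions.R uses (size = 1/k parameterisation, censoring at n).
set.seed(1)
n_cens <- 50      # censoring limit, playing the role of max_ in the patch
p_true <- 0.6     # zero-inflation probability (arbitrary illustrative value)
mu_true <- 3      # negative binomial mean (arbitrary illustrative value)
k_true <- 0.8     # overdispersion; the patch uses size = 1/k

# Zero-inflate by replacing a fraction p_true of draws with structural zeros
x <- ifelse(runif(500) < p_true, 0L,
            rnbinom(500, mu = mu_true, size = 1 / k_true))

# Same negative log-likelihood as zinb_loglik() in the patch
zinb_nll <- function(x, par, n) {
  p <- par[["p"]]; k <- par[["k"]]; mu <- par[["mu"]]
  ll <- rep(NA_real_, length(x))
  ll[x == 0] <- log(p + (1 - p) * dnbinom(0, mu = mu, size = 1 / k))
  ll[x > 0 & x < n] <- log(1 - p) + dnbinom(x[x > 0 & x < n], mu = mu, size = 1 / k, log = TRUE)
  ll[x >= n] <- log(1 - p) + dnbinom(n, mu = mu, size = 1 / k, log = TRUE)
  -sum(ll)
}

# Same optimiser call and bounds as zinb_optim()
fit <- optim(c(p = 0.8, mu = 0.5, k = 1),
             lower = c(p = 1e-8, mu = 1e-8, k = 1e-8),
             upper = c(p = 1 - 1e-8, mu = Inf, k = Inf),
             zinb_nll, x = x, n = n_cens, method = "L-BFGS-B")

# The bootstrap code stores the marginal mean (1 - p) * mu per resample;
# here it should land close to (1 - p_true) * mu_true = 1.2
(1 - fit$par["p"]) * fit$par["mu"]

Because calc_cm_general() combines these per-cell means into the weekday/weekend-weighted matrix, a quick check like this on simulated data is a cheap way to confirm the parameterisation before running the full bootstrap.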