This document contains supplementary material for the paper titled “Prognostic model to identify and quantify risk factors for mortality among hospitalised patients with COVID-19 in the USA”. It contains all tables and figures presented in the paper, as well as supplemental results. All code for the statistical analyses can be viewed in this document and is available at our GitHub repository; note, however, that we are unfortunately not allowed to publicly share the data.
The figures and tables in the main text are contained in the following sections:
Analyses were run using the following R
packages and settings.
# R packages
library("corrr")
library("dplyr")
library("DT")
library("ggplot2")
library("glmnet")
library("oem")
library("gridExtra")
library("knitr")
library("kableExtra")
library("magrittr")
library("mice")
library("purrr")
library("rcompanion")
library("rms")
library("splines")
library("tableone")
library("tibble")
library("tidyr")
# Namespace clash
# Bind the bare name `select` to dplyr's version (presumably another attached
# package masks it -- confirm which one before removing this line)
select <- dplyr::select
# Settings
# Fixed seed for the random number generator
set.seed(123)
# Default ggplot2 theme for all figures produced below
theme_set(theme_bw())
n_imputations <- 5 # Number of multiple imputations with MICE
n_boot_val <- 50 # Number of bootstraps with rms validation and calibration
n_boot_probs <- 100 # Number of bootstraps for computing predicted probability CIs
n_rep <- 20 # Number of repeats with group lasso (for each of the n_imputations)
We begin by loading the “training” dataset and restrict to (i) patients age 18 and older and (ii) with an index date more than 2 weeks prior to the data release. We will also do a small amount of data “cleaning” and create nice labels for possible predictor variables.
# Apply the inclusion/exclusion criteria to a patient-level dataset.
#
# Keeps (i) patients aged 18 or older and (ii) patients whose `index_date`
# lies more than `min_followup_days` days before `release_date`, and reports
# via message() how many adult patients were removed by criterion (ii).
#
# Arguments:
#   data              Data frame with columns `age` and `index_date`.
#   release_date      Date of the data release (default: 2020-06-05).
#   min_followup_days Minimum number of days required between the index date
#                     and the data release (default: 14).
# Returns: the filtered data frame.
filter_ie <- function(data,
                      release_date = as.Date("2020-06-05"),
                      min_followup_days = 14) {
  adults <- data %>%
    filter(age >= 18)
  # `.env$` forces the function arguments to be used even if `data` happens to
  # contain columns with the same names
  eligible <- adults %>%
    filter(.env$release_date - index_date > .env$min_followup_days)

  n_dropped <- nrow(adults) - nrow(eligible)
  # Guard against 0/0 (which would print "NaN%") when no adults are present
  share_dropped <- if (nrow(adults) > 0) n_dropped / nrow(adults) else 0
  percent_dropped <- formatC(100 * share_dropped, format = "f", digits = 2)

  # Keep the historical wording for the default 14-day window
  cutoff_label <- if (min_followup_days == 14) {
    "2-week"
  } else {
    paste0(min_followup_days, "-day")
  }
  message(n_dropped, " (", percent_dropped, "%)",
          " patients were dropped due to the ", cutoff_label, " cutoff.")
  eligible
}

# Load the training data and apply the inclusion/exclusion criteria
train_data <- readRDS("train_data.rds") %>%
  filter_ie()
## 750 (5.21%) patients were dropped due to the 2-week cutoff.
# Add a combined `race_ethnicity` factor built from the separate `race` and
# `ethnicity` columns. Kept as its own function because it is re-applied after
# race and ethnicity have been imputed separately.
add_race_ethnicity <- function(data) {
  data %>%
    mutate(
      race_ethnicity = case_when(
        # Undefined whenever either component is missing
        is.na(race) | is.na(ethnicity) ~ NA_character_,
        ethnicity != "Hispanic" & race == "Caucasian" ~ "Non-Hispanic white",
        ethnicity != "Hispanic" & race == "African American" ~ "Non-Hispanic black",
        ethnicity != "Hispanic" & race == "Asian" ~ "Asian",
        # Every remaining row (Hispanic ethnicity, or a race not listed above)
        TRUE ~ "Hispanic"
      )
    ) %>%
    # Non-Hispanic white is the reference level
    mutate(
      race_ethnicity = relevel(factor(race_ethnicity),
                               ref = "Non-Hispanic white")
    )
}
# Clean a raw patient-level dataset and add derived variables.
#
# Step 1 recodes existing columns (missing-value codes, implausible oxygen
# saturation values, small geographic divisions, comorbidity flags); step 2
# adds derived columns (calendar time, months, BMI category, combined
# diabetes/hypertension flags, race/ethnicity). Returns the modified data.
# NOTE(review): the ifelse() recodes assume race/ethnicity/sex/division are
# character columns (not factors) -- confirm against the raw data.
clean_data <- function(data){
# Recoding
data <- data %>%
mutate(
## "died" should be an integer
died = as.integer(died),
## Convert unknown or other/unknown to missing
race = ifelse(race == "Other/Unknown", NA, race),
ethnicity = ifelse(ethnicity == "Unknown", NA, ethnicity),
sex = ifelse(sex == "Unknown", NA, sex),
## Oxygen saturation should have plausible values
## (0 is treated as missing; values above 100 are capped at 100)
spo2 = ifelse(spo2 == 0, NA_real_, spo2),
spo2 = ifelse(spo2 > 100, 100, spo2),
## Division
## Set "Other" region to missing since all 9 geographic regions are in the data
division = ifelse(division == "Other", NA, division),
## Move small categories to other
## (must run after the step above, otherwise the new "Other" would be set to NA)
division = ifelse(division %in% c("East South Central", "Mountain"),
"Other",
division)
) %>%
## Better names for some variables
rename(diabunc = diab, cci = score) %>%
## CCI should be an integer
mutate(cci = as.integer(cci)) %>%
## Convert comorbidities to character
## (1 becomes "Yes", any other non-missing value becomes "No")
mutate_at(
c("ami", "chf", "pvd", "cevd", "dementia", "cpd", "rheumd", "pud",
"mld", "diabunc", "diabwc", "hp", "rend", "canc", "msld", "metacanc",
"aids", "hypc", "hypunc"),
function (x) ifelse(x == 1, "Yes", "No")
)
# Create "derived" variables
data <- data %>%
mutate(
## Time elapsed since the earliest index date in `data`
calendar_time = as.numeric(index_date - min(index_date)),
## Calendar month (1-12) of the index date and of death
index_month = as.integer(format(index_date, "%m")),
death_month = as.integer(format(date_of_death,"%m")),
## Time from index date to death, floored at 0
os_days = pmax(0, date_of_death - index_date),
## Categorize BMI
bmi_cat = case_when(
bmi < 18.5 ~ "Underweight",
bmi >= 18.5 & bmi < 25 ~ "Normal",
bmi >= 25 & bmi < 30 ~ "Overweight",
bmi >= 30 ~ "Obese",
TRUE ~ NA_character_
),
## Combine comorbidities
## ("Yes" if either the complicated or uncomplicated flag is "Yes")
diab = case_when(
diabunc == "Yes" | diabwc == "Yes" ~ "Yes",
TRUE ~ "No"
),
hyp = case_when(
hypc == "Yes" | hypunc == "Yes" ~ "Yes",
TRUE ~ "No"
)
) %>%
## Create race/ethnicity variable
## Later to be created after imputation from imputed race + ethnicity
add_race_ethnicity()
# Return
return(data)
}
# Clean the training data in place
train_data <- clean_data(train_data)
# Labels
# Lookup tables mapping each candidate predictor (`var`) to a display label
# (`varlab`) and a thematic `group`; combined into `vars` at the end.
## Categorical variables
demographic_cat_vars <- tribble(
~var, ~varlab,
"sex", "Sex",
"race", "Race",
"ethnicity", "Ethnicity",
"division", "Geographic division",
"smoke", "Smoking",
"race_ethnicity", 'Race/Ethnicity'
) %>%
mutate(group = "Demographics")
comorbidity_cat_vars <- tribble(
~var, ~varlab,
"ami", "Acute myocardial infarction",
"chf", "Congestive heart failure",
"pvd", "Peripheral vascular disease",
"cevd", "Cerebrovascular disease",
"dementia", "Dementia",
"cpd", "Chronic pulmonary disease",
"rheumd", "Rheumatoid disease",
"pud", "Peptic ulcer disease",
"mld", "Mild liver disease",
"diabunc", "Diabetes (no complications)",
"diabwc", "Diabetes (complications)",
"hp", "Hemiplegia or paraplegia",
"rend", "Renal disease",
"canc", "Cancer",
"msld", "Moderate/severe liver disease",
"metacanc", "Metastatic cancer",
"aids", "AIDS/HIV",
"hypunc", "Hypertension (no complications)",
"hypc", "Hypertension (complications)",
"diab", "Diabetes", # Combines diabunc and diabwc
"hyp", "Hypertension", # Combine hypunc and hypc
) %>%
mutate(group = "Comorbidities")
# NOTE(review): "bmi_cat" shares its label with the continuous "bmi" below, so
# anything grouped or matched by `varlab` treats the two as one variable.
vital_cat_vars <- tribble(
~var, ~varlab,
"bmi_cat", "Body Mass Index (BMI)",
) %>%
mutate(group = "Vitals")
cat_vars <- bind_rows(demographic_cat_vars,
comorbidity_cat_vars,
vital_cat_vars)
## Continuous variables
demographic_continuous_vars <- tribble(
~var, ~varlab,
"age", "Age",
"calendar_time", "Calendar time"
) %>%
mutate(group = "Demographics")
comorbidity_continuous_vars <- tribble(
~var, ~varlab,
"cci", "CCI",
) %>%
mutate(group = "Comorbidities")
vital_continuous_vars <- tribble(
~var, ~varlab,
"bmi", "Body Mass Index (BMI)",
"dbp", "Diastolic blood pressure",
"sbp", "Systolic blood pressure",
"hr", "Heart rate",
"resp", "Respiration rate",
"spo2", "Oxygen saturation",
"temp", "Temperature",
) %>%
mutate(group = "Vitals")
lab_vars <- tribble(
~var, ~varlab,
"alt", "Alanine aminotransferase (ALT)",
"ast", "Aspartate aminotransferase (AST)",
"crp", "C-reactive protein (CRP)",
"creatinine", "Creatinine",
"ferritin", "Ferritin",
"d_dimer", "Fibrin D-Dimer",
"ldh", "Lactate dehydrogenase (LDH)",
"lymphocyte", "Lymphocyte count",
"neutrophil", "Neutrophil count",
"pct", "Procalcitonin",
"tni", "Troponin I",
"plt", "Platelet count (PLT)",
"wbc", "White blood cell count (WBC)"
) %>%
mutate(group = "Labs")
continuous_vars <- bind_rows(
demographic_continuous_vars,
comorbidity_continuous_vars,
vital_continuous_vars,
lab_vars
)
# All variables
# (categorical first, then continuous; used by get_var_labs() for label lookup)
vars <- bind_rows(cat_vars,
continuous_vars)
# Look up the display label of each variable name in `v` using the global
# `vars` table; names without an entry give NA.
get_var_labs <- function(v) {
  label_idx <- match(v, vars$var)
  vars$varlab[label_idx]
}
Before starting modeling, we will carefully inspect the data. This includes (i) checking the sample size, (ii) inspecting the extent of missing data, (iii) checking the distributions of possible predictor variables (and of death) including presence of potential outliers, (iv) assessing collinearities through plots of bivariate relationships between predictors, (v) summarizing the distribution of variables with a “Table 1”, and (vi) looking at univariate fits between death and the continuous predictors to assess potential non-linearities and the number of knots that should be included when using splines.
# One-row table reporting the number of patients in the training data
train_data %>%
  count(name = "Sample size") %>%
  # Put the dataset label in front of the count
  mutate(Data = "Optum training set", .before = 1) %>%
  kable() %>%
  kable_styling()
Data | Sample size |
---|---|
Optum training set | 13658 |
# Long-format missingness indicators: one row per patient (`id`) and predictor,
# with `missing` equal to 1 when the value is NA and 0 otherwise
missing_df <- train_data %>%
  select(one_of(vars$var)) %>%
  mutate(across(everything(), function(x) ifelse(is.na(x), 1, 0))) %>%
  mutate(id = factor(1:n())) %>%
  pivot_longer(cols = vars$var, names_to = "var", values_to = "missing") %>%
  left_join(vars, by = "var")

# Proportion of patients with a missing value, by variable label
prop_missing <- missing_df %>%
  group_by(varlab) %>%
  summarise(prop = mean(missing))

# Bar chart of the proportion missing for each predictor
prop_missing %>%
  ggplot(aes(x = varlab, y = prop)) +
  geom_col() +
  geom_text(aes(label = formatC(prop, format = "f", digits = 2)),
            nudge_y = .03, size = 3) +
  scale_x_discrete(limits = rev(sort(vars$varlab))) +
  ylim(0, 1) +
  labs(x = "", y = "Proportion") +
  coord_flip()

# Patient-by-variable raster showing which values are missing
missing_df %>%
  ggplot(aes(x = id, y = varlab, fill = factor(missing))) +
  geom_raster() +
  scale_fill_manual(name = "Missing",
                    values = c("lightgrey", "steelblue"),
                    labels = c("No", "Yes")) +
  scale_y_discrete(limits = rev(sort(vars$varlab))) +
  theme(legend.position = "bottom",
        axis.title.x = element_blank(),
        axis.text.x = element_blank(),
        axis.ticks.x = element_blank(),
        axis.title.y = element_blank())

# Histogram of the share of predictors missing per patient
missing_by_patient <- missing_df %>%
  group_by(id) %>%
  summarise(n_missing = sum(missing),
            prop_missing = n_missing/n())
missing_by_patient %>%
  ggplot(aes(x = prop_missing)) +
  geom_histogram(binwidth = .03, color = "white") +
  scale_x_continuous(breaks = seq(0, 1, .05)) +
  scale_y_continuous(n.breaks = 20) +
  labs(x = "Proportion of predictors that are missing", y = "Count")
# Relative frequency of every level of each categorical predictor
# (missing values are excluded from the denominator)
cat_var_df <- train_data %>%
  select(one_of("ptid", cat_vars$var)) %>%
  pivot_longer(cols = cat_vars$var, names_to = "var", values_to = "value") %>%
  left_join(cat_vars, by = "var") %>%
  filter(!is.na(value)) %>%
  count(var, varlab, value) %>%
  group_by(varlab) %>%
  mutate(freq = n / sum(n)) %>%
  ungroup() %>%
  # Place the text label to the right of small proportions and to the left of
  # large ones so that it stays inside the panel
  mutate(
    nudge_x = case_when(
      freq < 0.5 ~ 0.15,
      TRUE ~ -0.15
    )
  )

# Dot plot of the proportions, one panel per variable
cat_var_df %>%
  ggplot(aes(x = freq, y = value)) +
  geom_point() +
  geom_text(aes(label = formatC(freq, format = "f", digits = 2)),
            nudge_x = cat_var_df$nudge_x, size = 3.5) +
  facet_wrap(~varlab, scales = "free_y", ncol = 4) +
  xlim(0, 1) +
  labs(x = "Proportion", y = "") +
  theme(strip.text.x = element_text(size = 7),
        axis.text.x = element_text(size = 10),
        axis.text.y = element_text(size = 10))
# Reshape the continuous variables of `data` into long format.
#
# Arguments:
#   data  Data frame containing `ptid` and the columns named in `vars$var`.
#   vars  Label table with columns `var` (column name) and `varlab` (label).
# Returns: one row per patient and variable (columns `ptid`, `var`, `value`
#   plus the label columns of `vars`), with missing values dropped.
pivot_continuous_longer <- function(data, vars) {
  col_names <- vars$var
  # Use the `data` argument (previously the global `train_data` was read
  # regardless of what the caller passed in)
  data %>%
    select(one_of("ptid", col_names)) %>%
    # all_of() makes explicit that `col_names` is an external character vector
    pivot_longer(cols = all_of(col_names),
                 names_to = "var",
                 values_to = "value") %>%
    left_join(vars, by = "var") %>%
    filter(!is.na(value))
}

continuous_var_df <- pivot_continuous_longer(train_data,
                                             vars = continuous_vars)
# Box plots of long-format continuous data (columns `varlab` and `value`),
# one free-scaled panel per variable
plot_box <- function(data) {
  data %>%
    ggplot(aes(x = varlab, y = value)) +
    geom_boxplot(outlier.size = 1) +
    facet_wrap(~varlab, scales = "free") +
    labs(x = "", y = "Value") +
    # The panel strip already names the variable, so hide the x axis entirely
    theme(strip.text = element_text(size = 7),
          axis.title.x = element_blank(),
          axis.text.x = element_blank(),
          axis.ticks.x = element_blank())
}

plot_box(continuous_var_df)
# Histograms (40 bins) of long-format continuous data (columns `varlab` and
# `value`), one free-scaled panel per variable, four panels per row
plot_hist <- function(data) {
  data %>%
    ggplot(aes(x = value)) +
    geom_histogram(bins = 40, color = "white") +
    facet_wrap(~varlab, scales = "free", ncol = 4) +
    labs(x = "", y = "Frequency") +
    theme(strip.text = element_text(size = 7))
}

plot_hist(continuous_var_df)
Visual inspection of the box plots and histograms suggests that there are significant outliers in the labs. Let’s look at how many observations lie above the 99th percentile of the data and the “outer fence” (defined as the 3rd quartile plus 3 times the interquartile range). We will then create new lab variables truncated from above at the outer fence and replot the histograms.
# Upper outer fence of a numeric vector: third quartile plus three times the
# interquartile range. Missing values are ignored; the result is an unnamed
# numeric scalar.
outer_fence <- function(v) {
  quartiles <- quantile(v, c(.25, .75), na.rm = TRUE, names = FALSE)
  lower_q <- quartiles[1]
  upper_q <- quartiles[2]
  as.numeric(upper_q + 3 * (upper_q - lower_q))
}
# Format proportions as percentage strings with one decimal place,
# e.g. 0.1234 becomes "12.3%"
format_percent <- function(x) {
  pct <- formatC(100 * x, format = "f", digits = 1)
  paste0(pct, "%")
}
# Per-lab summary of extreme values: maximum, 99th percentile, outer fence and
# the share of non-missing observations above the outer fence
train_data %>%
  select(one_of(lab_vars$var)) %>%
  pivot_longer(cols = lab_vars$var, names_to = "Lab") %>%
  drop_na(value) %>%
  group_by(Lab) %>%
  summarise(
    Maximum = max(value),
    `99%` = quantile(value, .99),
    `Outer fence` = outer_fence(value),
    # Reuses the fence computed on the previous line
    `% above outer fence` = format_percent(mean(value > `Outer fence`))
  ) %>%
  # Replace variable names by their display labels
  mutate(Lab = get_var_labs(Lab)) %>%
  kable() %>%
  kable_styling()
Lab | Maximum | 99% | Outer fence | % above outer fence |
---|---|---|---|---|
Alanine aminotransferase (ALT) | 3017.00 | 224.700000 | 130.00 | 3.4% |
Aspartate aminotransferase (AST) | 7000.01 | 336.000000 | 157.00 | 3.6% |
Creatinine | 27.35 | 10.880500 | 3.24 | 6.4% |
C-reactive protein (CRP) | 638.00 | 359.030000 | 458.00 | 0.2% |
Fibrin D-Dimer | 114475.00 | 19499.300000 | 4993.00 | 7.0% |
Ferritin | 100000.01 | 7746.620000 | 3648.00 | 3.4% |
Lactate dehydrogenase (LDH) | 14007.00 | 1262.440000 | 1050.00 | 1.9% |
Lymphocyte count | 120.35 | 3.740000 | 3.50 | 1.2% |
Neutrophil count | 82.40 | 18.900000 | 18.29 | 1.2% |
Procalcitonin | 753.66 | 30.122600 | 1.27 | 10.2% |
Platelet count (PLT) | 1213.00 | 527.860000 | 569.00 | 0.7% |
Troponin I | 95.40 | 2.190225 | 0.17 | 8.1% |
White blood cell count (WBC) | 127.50 | 22.600000 | 21.70 | 1.2% |
# Truncate labs using outer fence

# Upper truncation point of a lab vector (currently its outer fence)
truncate_max <- function(v) outer_fence(v)

# Add an upper-truncated copy of each lab to `data`.
#
# For every column name in `v` a new column "<name>_t" is added in which
# values above truncate_max() of that column are replaced by the truncation
# point; missing values stay missing.
#
# Arguments:
#   data  Data frame containing the columns named in `v`.
#   v     Character vector of column names; may be empty, in which case
#         `data` is returned unchanged.
# Returns: `data` with the additional "_t" columns.
add_truncated_lab_vars <- function(data, v) {
  # Iterating over `v` directly (instead of 1:length(v)) is safe for empty `v`
  for (original_var in v) {
    values <- data[[original_var]]
    cap <- truncate_max(values)
    data[[paste0(original_var, "_t")]] <- ifelse(values > cap, cap, values)
  }
  data
}
# Add a truncated "<lab>_t" column to the training data for every lab variable
train_data <- add_truncated_lab_vars(train_data, v = lab_vars$var)
After truncating the labs, the probability distributions appear more reasonable.
# Label table for the truncated labs: same labels as `lab_vars`, with "_t"
# appended to the variable names
lab_vars_t <- lab_vars %>%
mutate(var = paste0(var, "_t"))
# Register the truncated labs in the global label table so get_var_labs() can
# resolve them. NOTE: re-running this line appends the rows again.
vars <- bind_rows(vars, lab_vars_t)
# Continuous variables with labs replaced by their truncated versions.
# NOTE(review): unlike `continuous_vars`, the comorbidity table (CCI) is not
# included here -- confirm this is intentional.
continuous_vars_t <- bind_rows(
demographic_continuous_vars,
vital_continuous_vars,
lab_vars_t
)
continuous_var_t_df <- pivot_continuous_longer(train_data,
vars = continuous_vars_t)
# Histograms after truncation
plot_hist(continuous_var_t_df)
# Bar chart of the proportion of patients who died vs. survived
train_data %>%
  count(died) %>%
  mutate(prop = n/sum(n),
         died = ifelse(died == 1, "Yes", "No")) %>%
  ggplot(aes(x = died, y = prop)) +
  geom_col() +
  geom_text(aes(label = formatC(prop, format = "f", digits = 2)),
            nudge_y = .03, size = 3) +
  labs(x = "Died", y = "Proportion")
# Number of deaths by the difference between the calendar month of death and
# the calendar month of the index date (patients without a death month are
# excluded)
train_data %>%
  mutate(months_to_death = death_month - index_month) %>%
  filter(!is.na(months_to_death)) %>%
  count(months_to_death) %>%
  ggplot(aes(x = factor(months_to_death), y = n)) +
  geom_col(position = "dodge") +
  labs(x = "Months from index date to death", y = "Count")
We start with the Charlson Comorbidity Index (CCI).
# Scatter plot of CCI against age with a smoothed trend line
train_data %>%
  ggplot(aes(x = age, y = cci)) +
  geom_point() +
  geom_smooth() +
  labs(x = get_var_labs("age"), y = get_var_labs("cci"))
Next, we will examine each comorbidity separately by plotting the probability that a patient has each comorbidity as a function of their age.
# Smoothed share of patients with each comorbidity as a function of age
comorbidity_cols <- comorbidity_cat_vars$var
train_data[, c("age", comorbidity_cols)] %>%
  pivot_longer(cols = comorbidity_cat_vars$var,
               names_to = "comorbidity") %>%
  mutate(
    # 0 for "No", 1 for any other non-missing value
    value = as.integer(value != "No"),
    comorbidity = get_var_labs(comorbidity)
  ) %>%
  ggplot(aes(x = age, y = value)) +
  geom_smooth() +
  facet_wrap(~comorbidity, ncol = 4) +
  labs(x = get_var_labs("age"), y = "Probability")
# Smoothed relationship between one continuous variable (`x_var`) and each of
# several others (`y_vars`) in the global `train_data`, one panel per y
# variable with free y scales. Both arguments are column names given as
# strings.
plot_scatter_continuous <- function(x_var, y_vars) {
  long_df <- train_data[, c(x_var, y_vars)] %>%
    pivot_longer(cols = all_of(y_vars)) %>%
    mutate(name = get_var_labs(name))

  ggplot(long_df, aes(x = .data[[x_var]], y = value)) +
    geom_smooth(se = FALSE) +
    facet_wrap(~name, ncol = 4, scales = "free_y") +
    labs(x = get_var_labs(x_var), y = "Value")
}

# Vitals and (truncated) labs against age
plot_scatter_continuous("age", vital_continuous_vars$var)
plot_scatter_continuous("age", lab_vars_t$var) +
  theme(strip.text.x = element_text(size = 7))

# Vitals and (truncated) labs against CCI, with one axis break per CCI value
plot_scatter_continuous("cci", c(vital_continuous_vars$var, lab_vars_t$var)) +
  scale_x_continuous(breaks = sort(unique(train_data$cci))) +
  theme(strip.text.x = element_text(size = 7))
# Select variables for table 1
# Components that are reported only in combined form (race/ethnicity, diabetes,
# hypertension) and the BMI category are left out of the table
varnames_to_remove <- c("race", "ethnicity", "bmi_cat", "hypunc", "hypc",
"diabunc", "diabwc")
# Row order of the table: variables are sorted alphabetically by label within
# each group
tbl1_varnames <- bind_rows(
demographic_continuous_vars,
demographic_cat_vars %>% arrange(varlab),
comorbidity_cat_vars %>% arrange(varlab),
comorbidity_continuous_vars,
vital_continuous_vars %>% arrange(varlab),
lab_vars %>% arrange(varlab)
) %>%
filter(!var %in% varnames_to_remove) %>%
pull(var)
# Variables to be summarised as categorical (counts and percentages)
tbl1_cat_varnames <- bind_rows(
demographic_cat_vars,
comorbidity_cat_vars,
vital_cat_vars
) %>%
filter(!var %in% varnames_to_remove) %>%
pull(var)
# Continuous variables passed to print() as `nonnormal`
tbl1_non_normal_varnames <- c(
demographic_continuous_vars$var,
vital_continuous_vars$var,
comorbidity_continuous_vars$var,
lab_vars$var
)
## Create a TableOne object
## Columns are renamed to their display labels first so the labels appear as
## row names; the outcome `died` is renamed to "Survivor" and used as the
## stratification variable
tbl1_train <- train_data %>%
select(one_of(tbl1_varnames, "died")) %>%
rename_with(get_var_labs, .cols = all_of(tbl1_varnames)) %>%
rename(Survivor = died) %>%
CreateTableOne(vars = get_var_labs(tbl1_varnames),
factorVars = get_var_labs(tbl1_cat_varnames),
strata = "Survivor", addOverall = TRUE,
data = .)
## Print table 1
## NOTE(review): the column names below assume the strata are printed in the
## order died = 0 then died = 1 -- confirm if the coding of `died` changes
print(tbl1_train, nonnormal = get_var_labs(tbl1_non_normal_varnames),
cramVars = c("Sex"),
contDigits = 1, missing = TRUE, printToggle = FALSE) %>%
set_colnames(c("Overall", "Survivor", "Non-survivor",
"p", "Test", "Missing")) %>%
kable() %>%
kable_styling()
Overall | Survivor | Non-survivor | p | Test | Missing | |
---|---|---|---|---|---|---|
n | 13658 | 11495 | 2163 | |||
Age (median [IQR]) | 62.0 [49.0, 75.0] | 59.0 [46.0, 71.0] | 77.0 [67.0, 85.0] | <0.001 | nonnorm | 0.0 |
Calendar time (median [IQR]) | 47.0 [38.0, 64.0] | 47.0 [38.0, 64.0] | 46.0 [37.0, 60.0] | <0.001 | nonnorm | 0.0 |
Geographic division (%) | <0.001 | 2.7 | ||||
East North Central | 4627 (34.8) | 3954 (35.3) | 673 (31.9) | |||
Middle Atlantic | 4636 (34.9) | 3844 (34.4) | 792 (37.6) | |||
New England | 1583 (11.9) | 1272 (11.4) | 311 (14.8) | |||
Other | 191 ( 1.4) | 175 ( 1.6) | 16 ( 0.8) | |||
Pacific | 511 ( 3.8) | 438 ( 3.9) | 73 ( 3.5) | |||
South Atl/West South Crl | 364 ( 2.7) | 317 ( 2.8) | 47 ( 2.2) | |||
West North Central | 1384 (10.4) | 1189 (10.6) | 195 ( 9.3) | |||
Race/Ethnicity (%) | <0.001 | 26.1 | ||||
Non-Hispanic white | 5647 (56.0) | 4455 (53.1) | 1192 (70.3) | |||
Asian | 362 ( 3.6) | 307 ( 3.7) | 55 ( 3.2) | |||
Hispanic | 533 ( 5.3) | 478 ( 5.7) | 55 ( 3.2) | |||
Non-Hispanic black | 3547 (35.2) | 3153 (37.6) | 394 (23.2) | |||
Sex = Female/Male (%) | 6563/7091 (48.1/51.9) | 5635/5856 (49.0/51.0) | 928/1235 (42.9/57.1) | <0.001 | 0.0 | |
Smoking (%) | <0.001 | 25.6 | ||||
Current | 866 ( 8.5) | 785 ( 9.0) | 81 ( 5.5) | |||
Never | 6207 (61.1) | 5450 (62.8) | 757 (51.1) | |||
Previous | 3092 (30.4) | 2450 (28.2) | 642 (43.4) | |||
Acute myocardial infarction = Yes (%) | 1535 (11.2) | 1028 ( 8.9) | 507 (23.4) | <0.001 | 0.0 | |
AIDS/HIV = Yes (%) | 101 ( 0.7) | 89 ( 0.8) | 12 ( 0.6) | 0.339 | 0.0 | |
Cancer = Yes (%) | 1678 (12.3) | 1282 (11.2) | 396 (18.3) | <0.001 | 0.0 | |
Cerebrovascular disease = Yes (%) | 1439 (10.5) | 1023 ( 8.9) | 416 (19.2) | <0.001 | 0.0 | |
Chronic pulmonary disease = Yes (%) | 3627 (26.6) | 2933 (25.5) | 694 (32.1) | <0.001 | 0.0 | |
Congestive heart failure = Yes (%) | 2325 (17.0) | 1604 (14.0) | 721 (33.3) | <0.001 | 0.0 | |
Dementia = Yes (%) | 1394 (10.2) | 854 ( 7.4) | 540 (25.0) | <0.001 | 0.0 | |
Diabetes = Yes (%) | 4612 (33.8) | 3669 (31.9) | 943 (43.6) | <0.001 | 0.0 | |
Hemiplegia or paraplegia = Yes (%) | 330 ( 2.4) | 228 ( 2.0) | 102 ( 4.7) | <0.001 | 0.0 | |
Hypertension = Yes (%) | 8003 (58.6) | 6333 (55.1) | 1670 (77.2) | <0.001 | 0.0 | |
Metastatic cancer = Yes (%) | 277 ( 2.0) | 188 ( 1.6) | 89 ( 4.1) | <0.001 | 0.0 | |
Mild liver disease = Yes (%) | 879 ( 6.4) | 711 ( 6.2) | 168 ( 7.8) | 0.007 | 0.0 | |
Moderate/severe liver disease = Yes (%) | 128 ( 0.9) | 88 ( 0.8) | 40 ( 1.8) | <0.001 | 0.0 | |
Peptic ulcer disease = Yes (%) | 206 ( 1.5) | 160 ( 1.4) | 46 ( 2.1) | 0.013 | 0.0 | |
Peripheral vascular disease = Yes (%) | 1671 (12.2) | 1176 (10.2) | 495 (22.9) | <0.001 | 0.0 | |
Renal disease = Yes (%) | 2833 (20.7) | 1984 (17.3) | 849 (39.3) | <0.001 | 0.0 | |
Rheumatoid disease = Yes (%) | 398 ( 2.9) | 315 ( 2.7) | 83 ( 3.8) | 0.007 | 0.0 | |
CCI (median [IQR]) | 1.0 [0.0, 3.0] | 1.0 [0.0, 2.0] | 3.0 [1.0, 5.0] | <0.001 | nonnorm | 0.0 |
Body Mass Index (BMI) (median [IQR]) | 29.7 [25.5, 35.1] | 30.0 [25.8, 35.4] | 28.1 [24.0, 33.5] | <0.001 | nonnorm | 11.9 |
Diastolic blood pressure (median [IQR]) | 73.0 [65.5, 80.5] | 74.0 [66.5, 81.0] | 68.0 [60.0, 75.5] | <0.001 | nonnorm | 3.1 |
Heart rate (median [IQR]) | 87.5 [77.5, 98.0] | 87.0 [77.5, 98.0] | 89.0 [77.5, 102.0] | <0.001 | nonnorm | 3.1 |
Oxygen saturation (median [IQR]) | 96.0 [94.0, 98.0] | 96.0 [94.5, 98.0] | 95.0 [93.0, 97.0] | <0.001 | nonnorm | 3.9 |
Respiration rate (median [IQR]) | 20.0 [18.0, 22.0] | 19.5 [18.0, 21.0] | 22.0 [19.0, 26.0] | <0.001 | nonnorm | 3.9 |
Systolic blood pressure (median [IQR]) | 126.0 [115.0, 139.0] | 127.0 [116.0, 139.0] | 122.0 [109.0, 136.5] | <0.001 | nonnorm | 3.2 |
Temperature (median [IQR]) | 37.0 [36.7, 37.4] | 37.0 [36.7, 37.4] | 37.1 [36.7, 37.6] | <0.001 | nonnorm | 3.1 |
Alanine aminotransferase (ALT) (median [IQR]) | 28.0 [18.0, 46.0] | 28.0 [18.0, 46.0] | 27.0 [18.0, 44.0] | 0.112 | nonnorm | 20.1 |
Aspartate aminotransferase (AST) (median [IQR]) | 37.0 [25.0, 58.0] | 35.0 [25.0, 54.0] | 46.0 [30.0, 73.0] | <0.001 | nonnorm | 21.0 |
C-reactive protein (CRP) (median [IQR]) | 79.1 [34.0, 140.0] | 72.2 [30.0, 130.0] | 116.0 [63.0, 184.0] | <0.001 | nonnorm | 38.7 |
Creatinine (median [IQR]) | 1.0 [0.8, 1.4] | 1.0 [0.8, 1.3] | 1.3 [1.0, 2.1] | <0.001 | nonnorm | 10.4 |
Ferritin (median [IQR]) | 510.0 [224.0, 1080.0] | 470.0 [207.0, 992.0] | 747.5 [320.8, 1501.5] | <0.001 | nonnorm | 43.6 |
Fibrin D-Dimer (median [IQR]) | 750.0 [390.0, 1540.8] | 692.5 [370.0, 1346.5] | 1345.0 [668.2, 3315.0] | <0.001 | nonnorm | 90.4 |
Lactate dehydrogenase (LDH) (median [IQR]) | 321.0 [238.0, 441.0] | 308.0 [232.0, 415.0] | 404.0 [284.0, 556.5] | <0.001 | nonnorm | 45.2 |
Lymphocyte count (median [IQR]) | 1.0 [0.7, 1.4] | 1.0 [0.7, 1.4] | 0.8 [0.5, 1.1] | <0.001 | nonnorm | 11.2 |
Neutrophil count (median [IQR]) | 4.9 [3.4, 7.1] | 4.7 [3.2, 6.7] | 6.1 [4.1, 9.2] | <0.001 | nonnorm | 11.2 |
Platelet count (PLT) (median [IQR]) | 202.0 [157.0, 260.0] | 205.0 [160.0, 262.0] | 187.5 [143.0, 245.0] | <0.001 | nonnorm | 9.8 |
Procalcitonin (median [IQR]) | 0.1 [0.1, 0.4] | 0.1 [0.1, 0.3] | 0.3 [0.1, 1.0] | <0.001 | nonnorm | 49.3 |
Troponin I (median [IQR]) | 0.0 [0.0, 0.0] | 0.0 [0.0, 0.0] | 0.0 [0.0, 0.1] | <0.001 | nonnorm | 41.2 |
White blood cell count (WBC) (median [IQR]) | 6.7 [4.9, 9.1] | 6.5 [4.8, 8.7] | 7.7 [5.6, 11.1] | <0.001 | nonnorm | 9.7 |
The functional form of the relationship between mortality and the continuous variables is assessed using a series of univariate fits. We mainly rely on visual inspection of the graphs but also report the Bayesian information criterion (BIC).
# Evaluate `expr` and return its value, or NULL if evaluation raises an error.
# The error message is still printed by try(); execution simply continues.
tryNULL <- function(expr) {
  value <- NULL
  # The assignment only takes effect if `expr` evaluates without error
  try(value <- expr)
  value
}
# Fit four univariate logistic regressions of `died` on a single predictor:
# a linear term and restricted cubic splines with 3, 4 and 5 knots.
#
# Arguments:
#   var     Name of the predictor column (string).
#   data    Data frame containing `died` and the predictor.
#   spline  Not used by the function body.
# Returns: a named list of lrm() fits ("Linear", "Spline 3 knots", ...).
fit_univariate_logit <- function(var, data, spline) {
  # Right-hand side of each candidate model, in display order
  rhs_by_model <- list(
    `Linear` = var,
    `Spline 3 knots` = sprintf("rcs(%s, 3)", var),
    `Spline 4 knots` = sprintf("rcs(%s, 4)", var),
    `Spline 5 knots` = sprintf("rcs(%s, 5)", var)
  )
  lapply(rhs_by_model, function(rhs) {
    f <- as.formula(paste("died", rhs, sep = " ~ "))
    lrm(f, data = data)
  })
}
# Predict from each univariate model over a grid of predictor values.
#
# Arguments:
#   models      Named list of fitted models (see fit_univariate_logit()).
#   var         Name of the predictor (string); used as the column name of
#               the prediction data.
#   var_values  Values of the predictor at which to predict.
#   type        Passed on to predict().
#               NOTE(review): the callers in this file pass "lp" or "fitted";
#               confirm that the default "response" is accepted by
#               predict() for these fits.
# Returns: a long tibble with columns `var` (the grid values; literally named
#   "var"), `Model` and `y`. Models whose prediction fails are dropped: the
#   error is printed by tryNULL() and the model contributes no column.
predict_univariate_logit <- function(models, var, var_values, type = "response") {
  newdata <- data.frame(var = var_values)
  colnames(newdata) <- var

  # One column of predictions per model that could be predicted from
  pred_df <- map_dfc(models, function(x) {
    tryNULL(predict(x, newdata = newdata, type = type))
  })
  model_names <- colnames(pred_df)

  pred_df %>%
    mutate(var = var_values) %>%
    # all_of() marks `model_names` as an external character vector; a bare
    # name here is ambiguous with a column of the same name
    pivot_longer(cols = all_of(model_names),
                 names_to = "Model",
                 values_to = "y")
}
# Midpoint of interval labels such as "(1,3]" (the format produced by cut()),
# rounded to `digits` decimal places. `x` may be a character vector or factor.
midpoint <- function(x, digits = 2) {
  # Strip the opening/closing brackets, leaving "lower,upper"
  bounds <- gsub("\\(|\\[|\\)|\\]", "", x)
  lo <- as.numeric(gsub(",.*", "", bounds))
  hi <- as.numeric(gsub(".*,", "", bounds))
  round(lo + (hi - lo) / 2, digits)
}
# Observed mortality within 20 equal-width bins of a predictor.
#
# Reads the global `train_data`, drops rows where `var` is missing, cuts `var`
# into 20 intervals and returns one row per interval with the interval
# midpoint (in a column named after `var`), the mean of `died` (`y`) and the
# number of patients (`n`). `var_values` is not used by the function body.
bin_y <- function(var, var_values) {
  observed <- train_data[, c(var, "died")] %>%
    filter(!is.na(.data[[var]]))

  binned <- observed %>%
    mutate(x_cat = cut(.data[[var]], breaks = 20),
           x_midpoint = midpoint(x_cat)) %>%
    group_by(x_midpoint) %>%
    summarise(y = mean(died),
              n = n())

  # Name the midpoint column after the predictor
  colnames(binned)[1] <- var
  binned
}
# Plot the fitted univariate models against binned observed mortality.
#
# Arguments:
#   models      Named list of fits from fit_univariate_logit().
#   var         Name of the predictor (string).
#   var_values  Grid of predictor values at which the models are evaluated.
#   var_lab     x-axis label.
#   type        "lp" (log odds scale) or "fitted" (probability scale); any
#               other value, including the default, raises an error.
#   ylab        y-axis label. When not supplied, the label is derived from
#               `type` ("Log odds" or "Probability of death").
# Returns: a ggplot object with one panel per model; points show the binned
#   observed outcome, sized by the number of patients in the bin.
plot_univariate_logit <- function(models, var, var_values, var_lab = "Variable",
                                  type = "response", ylab = "Probability of death") {
  # Validate `type` up front, before any prediction work is done
  default_ylab <- switch(type,
    "lp" = "Log odds",
    "fitted" = "Probability of death",
    stop("Type must be 'lp' or 'fitted'")
  )
  # Honour an explicitly supplied label (it used to be silently overwritten)
  if (missing(ylab)) {
    ylab <- default_ylab
  }

  # Plotting data
  predicted_probs <- predict_univariate_logit(models, var, var_values, type = type)
  binned_y <- bin_y(var, var_values)
  if (type == "lp") {
    # Move observed proportions of exactly 0 or 1 off the boundary so that the
    # logit is finite.
    # NOTE(review): the two bounds are asymmetric (.001 vs .99) -- confirm
    # whether .999 was intended for the upper bound.
    binned_y$y <- ifelse(binned_y$y == 0, .001, binned_y$y)
    binned_y$y <- ifelse(binned_y$y == 1, .99, binned_y$y)
    binned_y$y <- qlogis(binned_y$y)
  }

  # Plotting scales
  y_min <- min(c(binned_y$y, predicted_probs$y))
  y_max <- max(c(binned_y$y, predicted_probs$y))
  size_breaks <- seq(min(binned_y$n), max(binned_y$n),
                     length.out = 6)

  # Plot
  # In `predicted_probs` the grid values live in a column literally named
  # "var"; in `binned_y` the x column carries the predictor's own name.
  ggplot(predicted_probs,
         aes(x = var, y = y)) +
    geom_line() +
    geom_point(data = binned_y,
               aes(x = .data[[var]], y = y, size = n)) +
    facet_wrap(~Model, ncol = 2) +
    xlab(var_lab) +
    ylab(ylab) +
    ylim(floor(y_min), ceiling(y_max)) +
    scale_size(name = "Sample size", range = c(0.3, 3),
               breaks = round(size_breaks, 0))
}
# Grid of 100 equally spaced values spanning the observed (non-missing) range
# of column `var` in the global `train_data`
make_seq <- function(var) {
  observed_range <- range(train_data[[var]], na.rm = TRUE)
  seq(observed_range[1], observed_range[2], length.out = 100)
}
# Fit, plot and compare the univariate models for one predictor.
#
# Fits the linear and spline models for `var` on the global `train_data`,
# builds the log-odds and probability plots, and collects the BIC of every
# model for which BIC() succeeds. When `print` is TRUE both plots and the BIC
# vector are printed. Returns a list with elements `fits`, `p_link`,
# `p_probs` and `bic`.
evaluate_univariate_logit <- function(var, print = TRUE) {
  grid <- make_seq(var)
  label <- get_var_labs(var)

  # Do evaluations
  fits <- fit_univariate_logit(var, data = train_data)
  p_link <- plot_univariate_logit(fits, var, grid, label, type = "lp")
  p_probs <- plot_univariate_logit(fits, var, grid, label, type = "fitted")
  # Models for which BIC() errors yield NULL and are dropped by unlist()
  bic_by_model <- lapply(fits, function(z) tryNULL(BIC(z)))
  bic <- unlist(bic_by_model)

  # Print and return
  if (print) {
    print(p_link)
    print(p_probs)
    print(bic)
  }
  list(fits = fits, p_link = p_link, p_probs = p_probs, bic = bic)
}
# Summarise the distributions of the training data with rms::datadist() and
# register the summary object by name in the global `datadist` option
# (must run before the univariate evaluations below)
dd <- datadist(train_data)
options(datadist = "dd")
uv_age <- evaluate_univariate_logit("age")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10153.39 10155.66 10164.65 10172.37
uv_calendar_time <- evaluate_univariate_logit("calendar_time")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 11927.48 11923.47 11929.38 11938.75
uv_bmi <- evaluate_univariate_logit("bmi")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10634.18 10602.78 10610.83 10619.81
uv_dbp <- evaluate_univariate_logit("dbp")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 11072.71 11040.40 11048.04 11056.58
uv_sbp <- evaluate_univariate_logit("sbp")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 11463.15 11257.48 11234.40 11241.32
uv_hr <- evaluate_univariate_logit("hr")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 11501.00 11391.35 11386.31 11393.71
uv_resp <- evaluate_univariate_logit("resp")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10747.57 10757.05 10511.84 10519.00
uv_spo2 <- evaluate_univariate_logit("spo2")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 11141.19 11109.05 11109.31 11117.92
uv_temp <- evaluate_univariate_logit("temp")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 11594.91 11404.12 11413.08 11419.88
uv_alt <- evaluate_univariate_logit("alt")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10129.18 10134.74 10143.54 10150.31
uv_alt_t <- evaluate_univariate_logit("alt_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10127.15 10135.79 10144.48 10151.88
uv_ast <- evaluate_univariate_logit("ast")
## singular information matrix in lrm.fit (rank= 4 ). Offending variable(s):
## ast'''
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Linear Spline 3 knots Spline 4 knots
## 9998.672 9825.157 9825.636
uv_ast_t <- evaluate_univariate_logit("ast_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 9850.176 9818.026 9825.865 9832.198
uv_crp <- evaluate_univariate_logit("crp")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 7611.060 7581.414 7584.386 7589.655
uv_crp_t <- evaluate_univariate_logit("crp_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 7609.387 7581.499 7584.233 7589.694
uv_creatinine <- evaluate_univariate_logit("creatinine")
## singular information matrix in lrm.fit (rank= 4 ). Offending variable(s):
## creatinine'''
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Linear Spline 3 knots Spline 4 knots
## 10993.75 10492.07 10471.88
uv_creatinine_t <- evaluate_univariate_logit("creatinine_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10601.63 10466.05 10476.77 10472.93
uv_ferritin <- evaluate_univariate_logit("ferritin")
## singular information matrix in lrm.fit (rank= 4 ). Offending variable(s):
## ferritin'''
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Linear Spline 3 knots Spline 4 knots
## 7309.122 7194.479 7199.344
uv_ferritin_t <- evaluate_univariate_logit("ferritin_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 7214.138 7190.977 7198.924 7204.215
uv_d_dimer <- evaluate_univariate_logit("d_dimer")
## singular information matrix in lrm.fit (rank= 3 ). Offending variable(s):
## d_dimer''
## singular information matrix in lrm.fit (rank= 4 ). Offending variable(s):
## d_dimer''
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Linear Spline 3 knots
## 1201.612 1161.649
# Univariate fits for the remaining labs, one call per variable. Each call
# prints one value per functional form (linear, then splines with 3, 4 and 5
# knots).
# NOTE(review): the printed values look like an information criterion (lower is
# better) - confirm against evaluate_univariate_logit().
uv_d_dimer_t <- evaluate_univariate_logit("d_dimer_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 1154.281 1157.531 1164.385 1170.708
uv_ldh <- evaluate_univariate_logit("ldh")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 6916.380 6748.649 6715.047 6709.253
uv_ldh_t <- evaluate_univariate_logit("ldh_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 6722.740 6704.054 6704.866 6706.272
uv_lymphocyte <- evaluate_univariate_logit("lymphocyte")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 11021.34 10663.85 10676.34 10683.20
uv_lymphocyte_t <- evaluate_univariate_logit("lymphocyte_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10839.18 10657.41 10657.86 10667.01
uv_neutrophil <- evaluate_univariate_logit("neutrophil")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10741.03 10716.37 10709.02 10717.48
uv_neutrophil_t <- evaluate_univariate_logit("neutrophil_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10709.62 10704.29 10705.94 10714.40
# For pct, pct_t and tni some of the higher-knot spline fits fail with a
# singular information matrix, so fewer values are printed for them.
uv_pct <- evaluate_univariate_logit("pct")
## singular information matrix in lrm.fit (rank= 3 ). Offending variable(s):
## pct''
## singular information matrix in lrm.fit (rank= 3 ). Offending variable(s):
## pct''' pct''
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Linear Spline 3 knots
## 6956.604 6494.028
uv_pct_t <- evaluate_univariate_logit("pct_t")
## singular information matrix in lrm.fit (rank= 4 ). Offending variable(s):
## pct_t'''
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Linear Spline 3 knots Spline 4 knots
## 6565.428 6491.988 6483.149
uv_tni <- evaluate_univariate_logit("tni")
## singular information matrix in lrm.fit (rank= 3 ). Offending variable(s):
## tni''
## singular information matrix in lrm.fit (rank= 3 ). Offending variable(s):
## tni''
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in terms.default(form) : no terms component nor attribute
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Error in UseMethod("logLik") :
## no applicable method for 'logLik' applied to an object of class "lrm"
## Linear Spline 3 knots
## 7750.579 7041.228
uv_tni_t <- evaluate_univariate_logit("tni_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 7074.305 7011.201 6970.225 6944.451
uv_plt <- evaluate_univariate_logit("plt")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 11102.84 11052.10 11059.55 11067.80
uv_plt_t <- evaluate_univariate_logit("plt_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 11099.10 11051.69 11059.24 11067.10
uv_wbc <- evaluate_univariate_logit("wbc")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10987.01 10986.79 10936.24 10942.97
uv_wbc_t <- evaluate_univariate_logit("wbc_t")
## Linear Spline 3 knots Spline 4 knots Spline 5 knots
## 10934.25 10943.26 10926.22 10930.45
Variables that will be candidates for inclusion in our model and used during variable selection are specified here. We will combine race and ethnicity into a single variable, but will wait to do this until after multiple imputation (see next section). Alanine aminotransferase (ALT) will also be excluded due to (i) a strong correlation with AST (0.83) and (ii) weak univariate associations with mortality (see above).
# Candidate predictors, listed by variable group
demographics_to_include <- c(
  "age", "sex", "race", "ethnicity", "division", "smoke", "calendar_time"
)
# All comorbidity indicators except the four diabetes/hypertension
# sub-categories, which are replaced by the "diab" and "hyp" variables
comorbidities_to_include <- c(
  comorbidity_cat_vars %>%
    filter(!var %in% c("diabunc", "diabwc", "hypunc", "hypc")) %>%
    pull(var),
  "diab", "hyp"
)
vitals_to_include <- c("bmi", "temp", "hr", "resp", "spo2", "sbp")
labs_to_include <- c(
  "crp_t", "tni_t", "ast_t", "ferritin_t", "creatinine_t",
  "ldh_t", "lymphocyte_t", "neutrophil_t", "plt_t", "wbc_t"
)
# Flag the candidates in the variable table (1 = candidate, 0 = not)
vars <- vars %>%
  mutate(include = as.numeric(var %in% c(demographics_to_include,
                                         comorbidities_to_include,
                                         vitals_to_include,
                                         labs_to_include)))
get_included_vars <- function(){
  # Names of the variables currently flagged with include == 1 in the
  # global `vars` table
  is_included <- vars$include == 1
  vars[is_included, ]$var
}
make_rhs <- function(vars){
  # Build a one-sided formula `~ v1 + v2 + ...` from a character vector of
  # variable names
  rhs_string <- paste(vars, collapse = " + ")
  as.formula(paste0("~", rhs_string))
}
# One-sided formula containing every candidate predictor
candidate_model_rhs <- get_included_vars() %>% make_rhs()
Imputation will be performed using multivariate imputation by chained equations (MICE). Before imputing, we will first re-level variables so that we have preferred reference categories.
# Convert to factors and set the preferred reference category of each
train_data <- train_data %>%
  mutate(
    sex = factor(sex) %>% relevel(ref = "Male"),
    division = factor(division) %>% relevel(ref = "Pacific"),
    smoke = factor(smoke) %>% relevel(ref = "Never"),
    bmi_cat = factor(bmi_cat) %>% relevel(ref = "Normal")
  )
We can then impute using the mice()
function.
# Impute the candidate predictors with MICE; character columns are converted
# to factors first
mice_out <- train_data %>%
  select(one_of(get_included_vars())) %>%
  mutate(across(where(is.character), as.factor)) %>%
  mice(m = n_imputations, maxit = 5)
## Save variables for test set imputation
mice_vars <- c(get_included_vars())
# Stack the original data (.imp == 0) and the imputed datasets in long format,
# then attach the outcome, repeated once per stacked dataset
mi_df <- mice_out %>%
  complete(action = "long", include = TRUE) %>%
  as_tibble() %>%
  mutate(died = rep(train_data$died, times = mice_out$m + 1))
# To compare MICE to aregImpute
# areg_out <- aregImpute(update.formula(candidate_model_rhs, ~.),
# n.impute = 2, data = train_data)
The distributions of the imputed and observed data are compared as a diagnostic for the imputation. They look pretty similar suggesting that there is nothing terribly wrong with the imputation.
make_imp_df <- function(object){
  # Collect the imputed values stored in an imputation fit into long-format
  # data frames for plotting.
  #
  # Args:
  #   object: A "mids" object from mice(); any other object is treated as
  #     aregImpute output (imputations read from `$imputed`).
  # Returns:
  #   A list with tibbles `continuous` and `cat`, each with columns var, imp
  #   (imputation number), value, obs ("Imputed") and varlab.

  # Get imputations
  if (inherits(object, "mids")){
    imp <- object$imp
  } else{ # aregImpute
    imp <- object$imputed
    # seq_along() so that an empty list is skipped instead of indexed as 1:0
    for (i in seq_along(imp)){
      cat_levels_i <- object$cat.levels[[i]]
      if (!is.null(cat_levels_i) && !is.null(imp[[i]])){
        # Map the integer codes back to their category labels
        levels <- sort(unique(c(imp[[i]])))
        imp[[i]] <- apply(imp[[i]],
                          2,
                          function(x) factor(x, levels = levels,
                                             labels = cat_levels_i))
      }
    }
  }

  # One flag per variable; vapply() always returns a logical vector
  is_numeric <- vapply(imp, function(x) is.numeric(x[, 1]), logical(1))

  # One long data frame per variable (NULL when nothing was imputed for it)
  imp_dfs <- lapply(seq_along(imp), function(i){
    imp_i <- imp[[i]]
    if (is.null(nrow(imp_i)) || nrow(imp_i) == 0){
      return(NULL)
    }
    data.frame(var = names(imp)[i],
               imp = rep(seq_len(ncol(imp_i)), each = nrow(imp_i)),
               value = c(as.matrix(imp_i))) %>%
      as_tibble()
  })

  # Row bind data frames (bind_rows() drops the NULL entries)
  continuous_df <- bind_rows(imp_dfs[is_numeric]) %>%
    mutate(obs = "Imputed",
           varlab = get_var_labs(var))
  cat_df <- bind_rows(imp_dfs[!is_numeric]) %>%
    mutate(obs = "Imputed",
           varlab = get_var_labs(var))

  list(continuous = continuous_df,
       cat = cat_df)
}
imp_df <- make_imp_df(mice_out)
#imp_df <- make_imp_df(areg_out)
Note that there are some differences for sex, but this is because there are very few (n = sum(is.na(train_data$sex))
) observations with missing sex.
# Compare observed and imputed distributions of the continuous variables
## Data for plotting: imputed values first, then observed values tagged imp = 0
obsimp_df_continuous <- continuous_var_df %>%
  select(var, value, varlab) %>%
  mutate(imp = 0, obs = "Observed") %>%
  bind_rows(imp_df$continuous, .) %>%
  mutate(imp = ifelse(imp == 0, "Observed", paste0("Imputation ", imp))) %>%
  # Only show variables that actually had values imputed
  filter(var %in% unique(imp_df$continuous$var))
## Plot: one density per imputation plus the observed data, by variable
ggplot(obsimp_df_continuous, aes(x = value, colour = imp)) +
  geom_density(position = "jitter") +
  facet_wrap(~varlab, scales = "free", ncol = 3) +
  labs(x = "", y = "Density") +
  scale_color_discrete(name = "") +
  theme(legend.position = "bottom")
# Compare observed and imputed category proportions
## Data for plotting: imputed proportions (per variable and imputation) first,
## then the observed proportions tagged imp = 0
obsimp_df_cat <- cat_var_df %>%
  select(var, value, varlab, n, freq) %>%
  mutate(imp = 0, obs = "Observed") %>%
  bind_rows(
    imp_df$cat %>%
      group_by(var, varlab, value, imp) %>%
      summarise(n = n()) %>%
      group_by(var, varlab, imp) %>%
      mutate(freq = n / sum(n)),
    .
  ) %>%
  mutate(imp = ifelse(imp == 0, "Observed", paste0("Imputation ", imp))) %>%
  # Only show variables that actually had values imputed
  filter(var %in% unique(imp_df$cat$var))
# Plot: dodged bars, one per imputation plus the observed data
ggplot(obsimp_df_cat, aes(x = value, y = freq, fill = imp)) +
  geom_bar(position = "dodge", stat = "identity") +
  facet_wrap(~varlab, scales = "free_x") +
  scale_fill_discrete(name = "") +
  labs(x = "", y = "Proportion") +
  theme(legend.position = "bottom",
        axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1))
We create the race/ethnicity variable after imputation of missing values of race and ethnicity. We will also combine the diabetes and hypertension variables.
# Derive the combined race/ethnicity variable on the stacked imputed data
mi_df <- add_race_ethnicity(mi_df)
# Category shares in the rows with .imp == 0, first for diabetes ...
prop.table(table(pull(filter(mi_df, .imp == 0), diab)))
##
## No Yes
## 0.6623224 0.3376776
# ... and then for hypertension
prop.table(table(pull(filter(mi_df, .imp == 0), hyp)))
##
## No Yes
## 0.4140431 0.5859569
Now let’s use the combined race and ethnicity category in our model.
# Swap the separate race and ethnicity candidates for the combined variable;
# all other include flags are left untouched
vars <- vars %>%
  mutate(
    include = case_when(
      var %in% c("race", "ethnicity") ~ 0,
      var == "race_ethnicity" ~ 1,
      TRUE ~ include
    )
  )
# Rebuild the candidate formula from the updated flags
candidate_model_rhs <- get_included_vars() %>% make_rhs()
Finally, we will (i) add the new race/ethnicity variable to the “MICE” object and (ii) create a list of imputed datasets for analysis.
# Rebuild the "mids" object so that it carries the new race/ethnicity variable
mice_out <- mi_df %>% as.mids()
# One data frame per imputed dataset (the .imp == 0 rows are dropped)
mi_list <- mi_df %>%
  filter(.imp > 0) %>%
  {
    split(., list(.$.imp))
  }
We select variables for inclusion in the model by repeatedly fitting a logistic regression with a group lasso penalty. Each group lasso model is trained using 10-fold cross validation to select a value of the penalty parameter, lambda. Coefficients are extracted from the fit with lambda = 1se
; that is, the model with deviance within one standard error of the minimum deviance. This process was repeated \(N\) times for each of the \(M\) imputed datasets. Variables with non-zero coefficients in at least 90 percent of the \(N \times M\) iterations were deemed suitable for inclusion in the model.
We now set up the group lasso model, which will be implemented with oem
. The continuous variables are transformed using restricted cubic splines with 3 knots (4 knots for respiration rate).
# Replace the linear term of every continuous candidate with a restricted
# cubic spline (4 knots for respiration rate, 3 knots otherwise).
# All terms must sit in ONE formula joined by `+`: update.formula() silently
# ignores additional arguments, so a comma inside the chain would leave every
# term after it (lymphocyte, neutrophil, platelet and WBC) untransformed.
# NOTE(review): confirm that risk-factors-terms.csv lists the spline terms of
# these four labs, since downstream labels are looked up by term name.
candidate_model_rhs <- candidate_model_rhs %>%
  update.formula( ~. + rcs(age, 3) - age +
                    rcs(calendar_time, 3) - calendar_time +
                    rcs(bmi, 3) - bmi +
                    rcs(temp, 3) - temp +
                    rcs(hr, 3) - hr +
                    rcs(resp, 4) - resp +
                    rcs(sbp, 3) - sbp +
                    rcs(spo2, 3) - spo2 +
                    rcs(crp_t, 3) - crp_t +
                    rcs(tni_t, 3) - tni_t +
                    rcs(ast_t, 3) - ast_t +
                    rcs(creatinine_t, 3) - creatinine_t +
                    rcs(ferritin_t, 3) - ferritin_t +
                    rcs(ldh_t, 3) - ldh_t +
                    rcs(lymphocyte_t, 3) - lymphocyte_t +
                    rcs(neutrophil_t, 3) - neutrophil_t +
                    rcs(plt_t, 3) - plt_t +
                    rcs(wbc_t, 3) - wbc_t)
With oem
we need to create an x
matrix since there is no formula interface.
rename_rcs <- function(v){
  # Strip the leading "rcs(...)" part from spline column names, leaving only
  # the text after the last closing parenthesis (e.g. "age" or "age'")
  has_rcs <- grepl("rcs", v)
  v[has_rcs] <- sub("rcs.*)", "", v[has_rcs])
  v
}
rename_terms <- function(v){
  # Clean model-matrix column names: spaces become underscores, "-" and "/"
  # are removed, and the rcs(...) prefix of spline terms is stripped
  underscored <- gsub(" ", "_", v)
  rename_rcs(gsub("/", "", gsub("-", "", underscored)))
}
make_x <- function(data, rhs){
  # Build the design matrix for `rhs` without the intercept column, with
  # cleaned column names and the "assign" attribute (term index per column)
  design <- model.matrix(rhs, data)
  term_index <- attr(design, "assign")
  colnames(design) <- rename_terms(colnames(design))
  # Drop the intercept; "assign" is set again afterwards, shortened to match
  design <- design[, -1]
  attr(design, "assign") <- term_index[-1]
  design
}
# Design matrix (x) and outcome vector (y) for each imputed dataset
x <- map(mi_list, make_x, rhs = candidate_model_rhs)
y <- map(mi_list, function(imputed) imputed[["died"]])
To make the graphs look nice, we will import labels for the model terms.
# Term labels, joined with the variable group of each underlying variable
terms <- left_join(
  read.csv("risk-factors-terms.csv"),
  vars[, c("var", "group")],
  by = "var"
)
get_term_labs <- function(v, term_name = "term"){
  # Look up the display label of each term in `v`, matching against the
  # `term_name` column of the global `terms` table (NA when not found)
  row_idx <- match(v, terms[[term_name]])
  terms$termlab[row_idx]
}
match_terms_to_vars <- function(t){
  # Map model terms to the variable each belongs to, using the global `terms`
  # table (NA when the term is not listed)
  row_idx <- match(t, terms$term)
  terms$var[row_idx]
}
Finally, we repeatedly train the group lasso model with cross-validation.
# Cross-validation folds and the share of fits a term must be selected in
n_folds <- 10
inclusion_threshold <- 0.9
# Group index of each column of x (columns of one model term share an index)
groups <- attr(x[[1]], "assign")
# Selection indicators: one row per fit, one column per coefficient
# (the extra column is for the intercept)
inclusion_sim <- matrix(0,
                        nrow = n_rep * n_imputations,
                        ncol = ncol(x[[1]]) + 1)
# Convenience function returning the group-lasso coefficients of a cv.oem fit
# at the lambda stored in `lambda.1se.models`
coef_cv_oem <- function(object){
  # Position of that lambda within the fitted lambda sequence
  is_1se <- object$lambda[[1]] == object$lambda.1se.models
  object$oem.fit$beta$grp.lasso[, which(is_1se)]
}
# Variable selection via group-lasso: n_rep cross-validated fits on each
# imputed dataset, recording which coefficients are nonzero in every fit
cntr <- 1
for (i in 1:n_imputations){
  for (j in 1:n_rep){
    # Cross-validated fit on imputed dataset i
    oem_cvfit <- cv.oem(
      x = x[[i]],
      y = y[[i]],
      penalty = "grp.lasso",
      groups = groups,
      family = "binomial",
      type.measure = "deviance",
      nfolds = n_folds
    )
    # Mark the nonzero coefficients in this fit's row
    nonzero_idx <- which(coef_cv_oem(oem_cvfit) != 0)
    inclusion_sim[cntr, nonzero_idx] <- 1
    # Move to the next row
    cntr <- cntr + 1
  } # End repeated CV loop
} # End imputation loop
We can then plot the proportion of times that coefficients for each variable were nonzero.
# Share of fits in which each term had a nonzero coefficient
inclusion_sim <- inclusion_sim[, -1] # Remove intercept
colnames(inclusion_sim) <- colnames(x[[1]])
inclusion_summary <- tibble(
  term = colnames(inclusion_sim),
  prob = apply(inclusion_sim, 2, mean)
)
# Terms that reach the inclusion threshold
model_terms <- inclusion_summary %>%
  filter(prob >= inclusion_threshold) %>%
  pull(term)
# Collapse from terms to variables (one row per distinct variable/probability)
inclusion_summary <- inclusion_summary %>%
  mutate(var = match_terms_to_vars(term),
         varlab = get_var_labs(var)) %>%
  distinct(prob, var, varlab)
# Plot inclusion probabilities with the threshold as a dotted red line
ggplot(inclusion_summary, aes(x = reorder(varlab, prob), y = prob)) +
  geom_bar(stat = "identity") +
  geom_hline(yintercept = inclusion_threshold, linetype = "dotted",
             color = "red", size = 1) +
  labs(y = "Probability of inclusion") +
  coord_flip() +
  theme(axis.title.y = element_blank())
Variables are included in the model based on the group lasso simulation implemented above.
# Variables selected less often than the inclusion threshold
vars_to_exclude <- setdiff(
  inclusion_summary %>%
    filter(prob < inclusion_threshold) %>%
    pull(var),
  "sex" # Keep sex even if not picked by group lasso
)
remove_terms_from_rhs <- function(f, vars_to_exclude){
  # Drop from a one-sided formula every term that uses an excluded variable.
  #
  # Args:
  #   f: A one-sided formula whose terms are joined by `+`.
  #   vars_to_exclude: Character vector of variable names to remove (NA
  #     entries are ignored; may be empty).
  # Returns:
  #   A one-sided formula with the remaining terms, or `~1` if none remain.

  # First convert formula to string separated by +
  f_string <- Reduce(paste, deparse(f))
  f_string <- gsub("~", "", f_string)
  f_string <- gsub(" ", "", f_string)
  # Then convert string to vector
  f_vec <- unlist(strsplit(f_string, "\\+"))
  # Match on the variables each term actually uses rather than on substrings:
  # excluding "canc" must not also drop "metacanc", and an empty exclusion
  # list must keep every term.
  vars_to_exclude <- vars_to_exclude[!is.na(vars_to_exclude)]
  uses_excluded <- vapply(
    f_vec,
    function(term) {
      any(all.vars(as.formula(paste0("~", term))) %in% vars_to_exclude)
    },
    logical(1),
    USE.NAMES = FALSE
  )
  f_vec <- f_vec[!uses_excluded]
  # An empty right-hand side is not a valid formula; fall back to intercept
  if (length(f_vec) == 0){
    f_vec <- "1"
  }
  # Convert string back to formula
  f_new <- paste0("~", paste(f_vec, collapse = " + "))
  return(as.formula(f_new))
}
# Final model formula: the candidate formula without the excluded variables
model_rhs <- remove_terms_from_rhs(candidate_model_rhs, vars_to_exclude)
# Attach a display label to every column of the stacked imputed data.
# NOTE(review): the inner `label(mi_df[, x]) <-` only modifies a copy local to
# the anonymous function; what map() collects is the assigned value, i.e.
# get_var_labs(x). This looks equivalent to map(colnames(mi_df), get_var_labs)
# - confirm before simplifying.
label(mi_df) <- map(colnames(mi_df),
function(x) label(mi_df[, x]) <- get_var_labs(x))
# Store the predictor distributions and register them via the `datadist`
# option; categorical variables adjust to their first level
dd <- datadist(mi_df, adjto.cat = "first")
options(datadist = "dd")
As described in the paper, five models are fit that include the following predictors: (i) age only, (ii) comorbidities only, (iii) all demographics (and calendar time), (iv) demographics (and calendar time) and comorbidities, and (v) all variables selected by the group lasso.
# The five models
lrm_names <- c("Age only",
"Comorbidities only",
"All demographics",
"Demographics and comorbidities",
"All variables")
# Each model is fit with fit.mult.impute() using the imputations in mice_out;
# x = TRUE and y = TRUE keep the design matrix and response in the fit.
## (1): fit_age: Only includes age
f_lrm_age <- died ~ rcs(age, 3)
lrm_fit_age <- fit.mult.impute(f_lrm_age, fitter = lrm, xtrans = mice_out, pr = FALSE,
x = TRUE, y = TRUE)
## (2): fit_c: Only includes comorbidities
# Comorbidities that are candidates and were not excluded by the group lasso
c_vars <- vars %>%
filter(group == "Comorbidities" & include == 1 & !var %in% vars_to_exclude) %>%
pull(var)
f_lrm_c <- as.formula(paste0("died ~", paste(c_vars, collapse = "+")))
lrm_fit_c <- fit.mult.impute(f_lrm_c, fitter = lrm, xtrans = mice_out,
pr = FALSE, x = TRUE, y = TRUE)
## (3): fit_d: All demographics including age
f_lrm_d <- update.formula(f_lrm_age, ~. + sex + rcs(calendar_time, 3) +
race_ethnicity + division)
lrm_fit_d <- fit.mult.impute(f_lrm_d, fitter = lrm, xtrans = mice_out, pr = FALSE,
x = TRUE, y = TRUE)
## (4): fit_dc: Demographics and comorbidities
f_lrm_dc <- update.formula(f_lrm_d,
as.formula(paste0("~.+", paste(c_vars, collapse = "+"))))
lrm_fit_dc <- fit.mult.impute(f_lrm_dc, fitter = lrm, xtrans = mice_out, pr = FALSE,
x = TRUE, y = TRUE)
## (5): fit_all: The main model including demographics, comorbidities, vitals, and labs
# model_rhs is one-sided, so add the outcome on the left-hand side
f_lrm_all <- update.formula(model_rhs, died ~ .)
lrm_fit_all <- fit.mult.impute(f_lrm_all, fitter = lrm, xtrans = mice_out,
pr = FALSE, x = TRUE, y = TRUE)
### Note that we can also estimate the models with stats::glm
# One glm per imputed dataset, then pooled across imputations
glm_fits_all <- mi_list %>%
map(function (x) glm(f_lrm_all, data = x, family = "binomial"))
glm_fit_all <- glm_fits_all %>% pool()
We will print the results of the full model for the interested reader
# Print the full model fitted across the imputed datasets
lrm_fit_all
## Logistic Regression Model
##
## fit.mult.impute(formula = f_lrm_all, fitter = lrm, xtrans = mice_out,
## pr = FALSE, x = TRUE, y = TRUE)
##
## Model Likelihood Discrimination Rank Discrim.
## Ratio Test Indexes Indexes
## Obs 13658 LR chi2 4048.36 R2 0.440 C 0.883
## 0 11495 d.f. 60 g 2.319 Dxy 0.766
## 1 2163 Pr(> chi2) <0.0001 gr 10.169 gamma 0.766
## max |deriv| 0.0001 gp 0.202 tau-a 0.204
## Brier 0.089
##
## Coef S.E. Wald Z Pr(>|Z|)
## Intercept 20.5996 2.6589 7.75 <0.0001
## sex=Female -0.0254 0.0655 -0.39 0.6987
## division=East North Central -0.3699 0.1720 -2.15 0.0315
## division=Middle Atlantic -0.2384 0.1685 -1.42 0.1571
## division=New England -0.1311 0.1799 -0.73 0.4662
## division=Other -0.9793 0.3667 -2.67 0.0076
## division=South Atl/West South Crl -0.5107 0.2463 -2.07 0.0382
## division=West North Central -0.3931 0.1852 -2.12 0.0338
## smoke=Current -0.0319 0.1680 -0.19 0.8493
## smoke=Previous 0.0974 0.0680 1.43 0.1521
## race_ethnicity=Asian -0.0936 0.1933 -0.48 0.6282
## race_ethnicity=Hispanic -0.1631 0.1064 -1.53 0.1251
## race_ethnicity=Non-Hispanic black -0.4232 0.0839 -5.04 <0.0001
## ami=Yes 0.1123 0.0830 1.35 0.1762
## chf=Yes 0.2046 0.0767 2.67 0.0076
## pvd=Yes 0.0693 0.0803 0.86 0.3880
## cevd=Yes 0.0603 0.0877 0.69 0.4915
## dementia=Yes 0.3877 0.0826 4.69 <0.0001
## cpd=Yes 0.0110 0.0699 0.16 0.8747
## rheumd=Yes 0.0722 0.1556 0.46 0.6428
## mld=Yes 0.0777 0.1222 0.64 0.5249
## hp=Yes 0.7016 0.1577 4.45 <0.0001
## rend=Yes 0.1141 0.0814 1.40 0.1613
## canc=Yes 0.0501 0.0856 0.59 0.5583
## msld=Yes 0.8054 0.2534 3.18 0.0015
## metacanc=Yes 0.8261 0.1759 4.70 <0.0001
## aids=Yes 0.5226 0.3588 1.46 0.1453
## diab=Yes 0.1058 0.0651 1.62 0.1043
## hyp=Yes 0.1126 0.0779 1.44 0.1486
## lymphocyte_t -0.2381 0.0609 -3.91 <0.0001
## plt_t -0.0023 0.0004 -5.69 <0.0001
## wbc_t 0.0547 0.0096 5.69 <0.0001
## age 0.0673 0.0044 15.12 <0.0001
## age' -0.0054 0.0055 -0.99 0.3246
## calendar_time -0.0169 0.0046 -3.69 0.0002
## calendar_time' 0.0069 0.0077 0.89 0.3728
## bmi -0.0242 0.0111 -2.19 0.0287
## bmi' 0.0445 0.0137 3.25 0.0012
## temp -0.2181 0.0534 -4.08 <0.0001
## temp' 0.9317 0.0839 11.11 <0.0001
## hr -0.0003 0.0043 -0.08 0.9399
## hr' 0.0127 0.0049 2.57 0.0103
## resp -0.2170 0.0383 -5.67 <0.0001
## resp' 3.4132 0.4051 8.43 <0.0001
## resp'' -7.4092 0.8839 -8.38 <0.0001
## sbp -0.0320 0.0036 -8.94 <0.0001
## sbp' 0.0304 0.0043 7.03 <0.0001
## spo2 -0.1319 0.0174 -7.56 <0.0001
## spo2' 0.1258 0.0279 4.51 <0.0001
## crp_t 0.0030 0.0014 2.12 0.0341
## crp_t' -0.0034 0.0019 -1.82 0.0686
## tni_t 9.4970 4.1262 2.30 0.0214
## tni_t' -25.9739 15.8831 -1.64 0.1020
## ast_t 0.0136 0.0049 2.76 0.0058
## ast_t' -0.0212 0.0080 -2.65 0.0081
## creatinine_t 0.3322 0.1108 3.00 0.0027
## creatinine_t' -0.1832 0.1929 -0.95 0.3424
## ferritin_t 0.0001 0.0002 0.73 0.4661
## ferritin_t' -0.0002 0.0004 -0.49 0.6267
## ldh_t 0.0015 0.0011 1.35 0.1770
## ldh_t' -0.0006 0.0014 -0.45 0.6548
##
Let’s examine the extent to which our predictors are collinear: for categorical and continuous variables, we use an ANOVA model; for categorical and categorical variables, we use Cramér’s V; and for continuous and continuous variables, we use Spearman correlation. (Note that this takes a long time to run and could probably be made more efficient.)
## from https://stackoverflow.com/questions/52554336/plot-the-equivalent-of-correlation-matrix-for-factors-categorical-data-and-mi
# Pairwise association between all columns of a data frame of mixed type.
#
# Args:
#   df: Data frame whose columns are factor/character (nominal) or
#     integer/double (numeric).
#   cor_method: Method passed to cor() for numeric-numeric pairs.
#   adjust_cramersv_bias: Passed to cramerV() as `bias.correct`.
# Returns:
#   A data frame with one row per ordered column pair: x, y, assoc, type
#   ("cramersV", "correlation" or "anova"), complete_obs_pairs and
#   complete_obs_ratio. Pairs involving a column of any other type are skipped
#   with a warning.
mixed_assoc <- function(df, cor_method = "spearman", adjust_cramersv_bias = TRUE){
  df_comb <- expand.grid(names(df), names(df), stringsAsFactors = FALSE) %>%
    set_names("X1", "X2")
  is_nominal <- function(x) inherits(x, c("factor", "character"))
  # https://community.rstudio.com/t/why-is-purr-is-numeric-deprecated/3559
  # https://github.com/r-lib/rlang/issues/781
  is_numeric <- function(x) is.integer(x) || is.double(x)
  f <- function(x_name, y_name) {
    x <- pull(df, x_name)
    y <- pull(df, y_name)
    if (is_nominal(x) && is_nominal(y)){
      # use bias corrected cramersV as described in https://rdrr.io/cran/rcompanion/man/cramerV.html
      cv <- cramerV(as.character(x), as.character(y),
                    bias.correct = adjust_cramersv_bias)
      result <- data.frame(x_name, y_name, assoc = cv, type = "cramersV")
    } else if (is_numeric(x) && is_numeric(y)){
      correlation <- cor(x, y, method = cor_method, use = "complete.obs")
      result <- data.frame(x_name, y_name, assoc = correlation,
                           type = "correlation")
    } else if (is_numeric(x) && is_nominal(y)){
      # from https://stats.stackexchange.com/questions/119835/correlation-between-a-nominal-iv-and-a-continuous-dv-variable/124618#124618
      r_squared <- summary(lm(x ~ y))$r.squared
      result <- data.frame(x_name, y_name, assoc = sqrt(r_squared),
                           type = "anova")
    } else if (is_nominal(x) && is_numeric(y)){
      r_squared <- summary(lm(y ~ x))$r.squared
      result <- data.frame(x_name, y_name, assoc = sqrt(r_squared),
                           type = "anova")
    } else {
      # Unsupported type (e.g. logical): warn and skip the pair instead of
      # passing the warning text on to mutate(), which would fail
      warning("unmatched column type combination: ",
              paste(class(x), collapse = "/"), " ",
              paste(class(y), collapse = "/"),
              call. = FALSE)
      return(NULL)
    }
    # finally add complete obs number and ratio to table
    result %>%
      mutate(
        complete_obs_pairs = sum(!is.na(x) & !is.na(y)),
        complete_obs_ratio = complete_obs_pairs/length(x)) %>%
      rename(x = x_name, y = y_name)
  }
  # apply function to each variable combination (NULL results are dropped)
  map2_df(df_comb$X1, df_comb$X2, f)
}
# Matrix of pairwise associations between the candidate predictors, computed
# on the first imputed dataset
corr_mat <- mi_df %>%
  filter(.imp == 1) %>%
  select(any_of(get_included_vars())) %>%
  mixed_assoc() %>%
  select(x, y, assoc) %>%
  pivot_wider(names_from = y, values_from = assoc) %>%
  column_to_rownames("x") %>%
  as.matrix()
# Tile plot of the absolute associations
m <- abs(corr_mat)
## Long format: one row per matrix cell, labelled with display names
heatmap_df <- tibble(
  row = rownames(m)[row(m)],
  col = colnames(m)[col(m)],
  corr = as.vector(m)
) %>%
  mutate(row = get_var_labs(row),
         col = get_var_labs(col))
ggplot(heatmap_df, aes(x = row, y = col, fill = corr)) +
  geom_tile() +
  scale_fill_continuous("Correlation") +
  theme(axis.title = element_blank(),
        axis.text.x = element_text(angle = 90, vjust = 0.5, hjust = 1))
Variable importance is assessed with a Wald test using rms::anova()
.
# Compute variable importance with Wald chi-square
lrm_anova_all <- anova(lrm_fit_all)
# Plot the result
## Data frame with chi-square minus degrees of freedom per variable; the
## summary rows (totals and nonlinear components) are dropped
lrm_anova_all_df <- as_tibble(lrm_anova_all) %>%
  mutate(var = gsub(" ", "", rownames(lrm_anova_all))) %>%
  mutate(varlab = get_var_labs(var),
         value = as.double(`Chi-Square` - `d.f.`)) %>%
  filter(!var %in% c("TOTAL", "Nonlinear", "TOTALNONLINEAR"))
## Dot plot ordered by importance
ggplot(lrm_anova_all_df, aes(x = value, y = reorder(varlab, value))) +
  geom_point() +
  labs(x = expression(chi^2-df)) +
  theme(axis.title.y = element_blank())
my_datatable <- function(data, filename) {
  # Interactive table with column filters, 20 rows per page, and buttons to
  # copy the data or download it as a CSV file named `filename`
  export_buttons <- list(
    list(extend = "copy"),
    list(extend = "csv", filename = filename)
  )
  table_options <- list(
    pageLength = 20,
    dom = "Bfrtip",
    buttons = export_buttons
  )
  datatable(
    data,
    rownames = FALSE,
    filter = "top",
    extensions = "Buttons",
    options = table_options
  )
}
# Variable importance table, most important variable first
lrm_anova_all_df %>%
  select(
    Variable = varlab,
    `Chi-Square`,
    `DF` = `d.f.`,
    `Chi-Square - DF` = value,
    `P`
  ) %>%
  arrange(desc(`Chi-Square - DF`)) %>%
  my_datatable(filename = "varimp") %>%
  formatRound(c("Chi-Square", "Chi-Square - DF", "P"), 3)
The model is summarized with odds ratios using rms::summary()
. We will start by assessing the full model. Next, since labs and vitals may themselves be “caused” by comorbidities, we will see how the odds ratios for the comorbidities change after dropping labs and vitals from the model.
The plot displays odds ratios and 95% confidence intervals for each variable in the full model. Note that odds ratios for continuous variables reflect a change from upper:lower
. For example, the odds ratio for age (75:49
) is for a change from age 49 to age 75.
# Effect summary of the full model, used for the odds ratios below
lrm_summary_all <- summary(lrm_fit_all)
# Odds ratios
format_or_range <- function(x, term){
  # Format the values shown in odds-ratio labels: 2 decimals below 10,
  # 1 decimal for temperature, and whole numbers otherwise
  two_decimals <- formatC(x, format = "f", digits = 2)
  one_decimal <- formatC(x, format = "f", digits = 1)
  no_decimals <- formatC(x, format = "f", digits = 0)
  case_when(
    x < 10 ~ two_decimals,
    term == "temp" ~ one_decimal,
    TRUE ~ no_decimals
  )
}
make_tidy_or <- function(object, model_name = NULL){
  # Turn a model summary into a tidy table of odds ratios with 95% limits.
  #
  # Args:
  #   object: Summary matrix with columns Low, High, Diff., Effect,
  #     Lower 0.95, Upper 0.95 and Type, and terms as row names.
  #   model_name: Label stored in the `model` column (default "Model").
  # Returns:
  #   A tibble with term, termlab, or, or_lower, or_upper and model, sorted by
  #   decreasing odds ratio.
  if (is.null(model_name)) model_name <- "Model"
  summary_tbl <- as_tibble(as.data.frame(object))
  summary_tbl %>%
    mutate(term = rownames(object)) %>%
    mutate(High = format_or_range(High, term),
           Low = format_or_range(Low, term)) %>%
    # Rows with a non-missing Diff. get "High:Low" appended to their label
    mutate(termlab = get_term_labs(term, "term2"),
           termlab = ifelse(!is.na(`Diff.`),
                            paste0(termlab, " - ", High, ":", Low),
                            termlab)) %>%
    mutate(or = exp(Effect),
           or_lower = as.double(exp(`Lower 0.95`)),
           or_upper = exp(`Upper 0.95`)) %>%
    # Keep only the Type == 1 rows, whose effects were exponentiated above
    filter(Type == 1) %>%
    select(term, termlab, or, or_lower, or_upper) %>%
    arrange(desc(or)) %>%
    mutate(model = model_name)
}
lrm_or_all <- lrm_summary_all %>% make_tidy_or("All variables")
# Odds ratio plot: point estimates with 95% intervals; dashed line at OR = 1
ggplot(lrm_or_all, aes(x = or, y = reorder(termlab, or))) +
  geom_point() +
  geom_errorbarh(aes(xmax = or_upper, xmin = or_lower, height = .2)) +
  geom_vline(xintercept = 1, linetype = "dashed", colour = "grey") +
  labs(x = "Odds ratio") +
  theme(axis.title.y = element_blank())
lrm_summary_dc <- summary(lrm_fit_dc)
lrm_or_dc <- make_tidy_or(lrm_summary_dc, "Demographics + comorbidities")
# Odds ratio comparison plot
## Demographic and comorbidity terms from both models, ordered by mean OR
lrm_or_comp <- bind_rows(lrm_or_all, lrm_or_dc) %>%
  filter(
    term %in% terms$term2[terms$group %in% c("Demographics", "Comorbidities")]
  ) %>%
  mutate(termlab = factor(termlab),
         termlab = reorder(termlab, or, function (x) -mean(x)))
## One strip per term with the two models dodged side by side
ggplot(lrm_or_comp, aes(x = termlab, y = or, colour = model)) +
  geom_point(position = position_dodge(width = 1)) +
  geom_errorbar(aes(ymax = or_upper, ymin = or_lower, width = .2),
                position = position_dodge(width = 1)) +
  facet_wrap(~termlab, strip.position = "left", ncol = 1, scales = "free_y") +
  geom_hline(yintercept = 1, linetype = "dashed") +
  scale_color_discrete(name = "Model") +
  theme(axis.title.y = element_blank(),
        axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        strip.text.y.left = element_text(hjust = 0, vjust = 1,
                                         angle = 0, size = 8),
        legend.position = "bottom") +
  labs(y = "Odds ratio") +
  coord_flip()
# The same comparison as a downloadable table
lrm_or_comp %>%
  select(-term) %>%
  rename(
    Term = termlab,
    `Odds ratio` = or,
    `Odds ratio (lower)` = or_lower,
    `Odds ratio (upper)` = or_upper,
    Model = model
  ) %>%
  my_datatable(filename = "coef_comp") %>%
  formatRound(
    c("Odds ratio", "Odds ratio (lower)", "Odds ratio (upper)"),
    4
  )
One limitation of the odds ratio plots is that it is hard to examine the potentially non-linear relationships between mortality and the continuous variable. The rms::Predict()
function is especially helpful in this regard. We use it to vary all of the predictors and plot predicted log odds across the different values of each predictor.
Each prediction is made with all other variables at their “Adjust to” value as specified with datadist()
above.
# Table of the datadist limits: one row per labelled variable, sorted by label
dd$limits %>%
  t() %>%
  as_tibble() %>%
  mutate(Variable = get_var_labs(colnames(dd$limits))) %>%
  relocate(Variable, .before = "Low:effect") %>%
  # Columns without a display label are left out
  filter(!is.na(Variable)) %>%
  arrange(Variable) %>%
  kable() %>%
  kable_styling()
Variable | Low:effect | Adjust to | High:effect | Low:prediction | High:prediction | Low | High |
---|---|---|---|---|---|---|---|
Acute myocardial infarction | NA | No | NA | No | Yes | No | Yes |
Age | 49 | 62 | 75 | 18 | 89 | 18 | 89 |
AIDS/HIV | NA | No | NA | No | Yes | No | Yes |
Aspartate aminotransferase (AST) | 25 | 36 | 57 | 6 | 157 | 4 | 157 |
Body Mass Index (BMI) | 25.52 | 29.69 | 35.13 | 13.08 | 94.84 | 11.84 | 149.50 |
C-reactive protein (CRP) | 28.8 | 71.8 | 132.0 | 0.2 | 458.0 | 0.0 | 458.0 |
Calendar time | 38 | 47 | 64 | 0 | 91 | 0 | 91 |
Cancer | NA | No | NA | No | Yes | No | Yes |
Cerebrovascular disease | NA | No | NA | No | Yes | No | Yes |
Chronic pulmonary disease | NA | No | NA | No | Yes | No | Yes |
Congestive heart failure | NA | No | NA | No | Yes | No | Yes |
Creatinine | 0.80 | 1.00 | 1.39 | 0.19 | 3.24 | 0.19 | 3.24 |
Dementia | NA | No | NA | No | Yes | No | Yes |
Diabetes | NA | No | NA | No | Yes | No | Yes |
Ethnicity | NA | Hispanic | NA | Hispanic | Not Hispanic | Hispanic | Not Hispanic |
Ferritin | 202.000000 | 459.000000 | 985.000000 | 3.273076 | 3648.000000 | 3.000000 | 3648.000000 |
Geographic division | NA | Pacific | NA | Pacific | West North Central | Pacific | West North Central |
Heart rate | 77.5 | 87.5 | 98.0 | 36.0 | 165.0 | 20.0 | 203.0 |
Hemiplegia or paraplegia | NA | No | NA | No | Yes | No | Yes |
Hypertension | NA | No | NA | No | Yes | No | Yes |
Lactate dehydrogenase (LDH) | 227.00 | 304.00 | 420.00 | 24.99 | 1050.00 | 24.99 | 1050.00 |
Lymphocyte count | 0.7 | 1.0 | 1.4 | 0.0 | 3.5 | 0.0 | 3.5 |
Metastatic cancer | NA | No | NA | No | Yes | No | Yes |
Mild liver disease | NA | No | NA | No | Yes | No | Yes |
Moderate/severe liver disease | NA | No | NA | No | Yes | No | Yes |
Neutrophil count | 3.32 | 4.87 | 7.10 | 0.00 | 18.29 | 0.00 | 18.29 |
Oxygen saturation | 94 | 96 | 98 | 34 | 100 | 26 | 100 |
Peptic ulcer disease | NA | No | NA | No | Yes | No | Yes |
Peripheral vascular disease | NA | No | NA | No | Yes | No | Yes |
Platelet count (PLT) | 158 | 203 | 261 | 2 | 569 | 1 | 569 |
Race | NA | African American | NA | African American | Caucasian | African American | Caucasian |
Race/Ethnicity | NA | Non-Hispanic white | NA | Non-Hispanic white | Non-Hispanic black | Non-Hispanic white | Non-Hispanic black |
Renal disease | NA | No | NA | No | Yes | No | Yes |
Respiration rate | 18.0 | 20.0 | 22.0 | 2.0 | 65.5 | 2.0 | 99.0 |
Rheumatoid disease | NA | No | NA | No | Yes | No | Yes |
Sex | NA | Male | NA | Male | Female | Male | Female |
Smoking | NA | Never | NA | Never | Previous | Never | Previous |
Systolic blood pressure | 115 | 126 | 139 | 33 | 242 | 30 | 266 |
Temperature | 36.7 | 37.0 | 37.4 | 16.0 | 41.8 | 16.0 | 42.2 |
Troponin I | 0.010 | 0.010 | 0.042 | 0.000 | 0.170 | 0.000 | 0.170 |
White blood cell count (WBC) | 4.90 | 6.70 | 9.10 | 0.21 | 21.70 | 0.20 | 21.70 |
Let’s make the predictions.
# Predictions from the full model, varying one predictor at a time
lrm_log_odds <- Predict(lrm_fit_all, ref.zero = TRUE)
# Get plotting data, split into continuous and discrete predictors
p_log_odds <- ggplot(lrm_log_odds, sepdiscrete = "list")

# Continuous plot
## Symmetric y-axis wide enough for every confidence limit
log_odds_limit <- max(ceiling(abs(c(p_log_odds$continuous$data$lower,
                                    p_log_odds$continuous$data$upper))))
log_odds_breaks <- seq(-log_odds_limit, log_odds_limit, by = 2)
p_log_odds_continuous <- p_log_odds$continuous$data %>%
  as_tibble() %>%
  mutate(varlab = get_var_labs(.predictor.)) %>%
  ggplot(aes(x = .xx., y = yhat)) +
  geom_line() +
  geom_ribbon(aes(ymin = lower, ymax = upper), alpha = 0.3) +
  geom_hline(yintercept = 0, linetype = "dashed", color = "grey") +
  facet_wrap(~varlab, scales = "free_x", ncol = 4) +
  scale_y_continuous(breaks = log_odds_breaks,
                     limits = c(-log_odds_limit, log_odds_limit)) +
  labs(y = "Log odds") +
  theme(axis.title.x = element_blank(),
        strip.text = element_text(size = 7))

# Discrete plot
## Same construction, with the axis limit taken from the discrete predictors
log_odds_limit <- max(ceiling(abs(c(p_log_odds$discrete$data$lower,
                                    p_log_odds$discrete$data$upper))))
log_odds_breaks <- seq(-log_odds_limit, log_odds_limit, by = 1)
p_log_odds_discrete <- p_log_odds$discrete$data %>%
  as_tibble() %>%
  mutate(varlab = get_var_labs(.predictor.)) %>%
  ggplot(aes(x = yhat, y = .xx.)) +
  geom_point(size = .9) +
  geom_errorbarh(aes(xmin = lower , xmax = upper, height = 0)) +
  geom_vline(xintercept = 0, linetype = "dashed", color = "grey") +
  facet_wrap(~varlab, scales = "free_y", ncol = 4) +
  scale_x_continuous(breaks = log_odds_breaks,
                     limits = c(-log_odds_limit, log_odds_limit)) +
  labs(x = "Log odds") +
  theme(axis.title.y = element_blank(),
        strip.text = element_text(size = 7))

# Combine plots: discrete panels on top, continuous panels below
grid.arrange(p_log_odds_discrete, p_log_odds_continuous,
             heights = c(4, 6))
The predicted log odds plot is a nice way to summarize non-linear effects, but it’s not the ultimate quantity of interest. We really care about predicted probabilities. But to compute a predicted probability we need to set all predictor variables in the model to certain values. Rather than choosing specific values (i.e., creating a representative patient), we will make predictions over a representative sample and average predictions across the sample.
Note that these plots show the impact of changes in predictor variables adjusted for all other variables in the model. The impact of a change in age from 55 to 60 on the probability of mortality is, for instance, estimated after conditioning on the values of the laboratory results, vital signs, comorbidities, other demographics, and calendar time. To assess predicted probabilities without adjustment refer to the univariate fits in the transformation of continuous variables section.
# Make newdata
## Random sample of patients drawn from the first imputed dataset
n_samples <- 1000
train_sample <- sample_n(
  filter(mi_list[[1]], .imp > 0),
  size = n_samples
)
We will write a general function to predict mortality as a function of (i) our fitted model, (ii) the representative sample, and (iii) the variables that we are varying.
expand_newdata <- function(data, vars_to_vary){
  # Pair every row of `data` with every combination of the values to vary.
  #
  # Args:
  #   data: Data frame of patients.
  #   vars_to_vary: Named list of value vectors; the names must be columns of
  #     `data`, whose original values are replaced.
  # Returns:
  #   A tibble with nrow(data) times the number of value combinations rows.
  varnames_to_vary <- names(vars_to_vary)
  expanded_vars_to_vary <- expand.grid(vars_to_vary)
  # expand_grid() keeps every input row. crossing() would de-duplicate (and
  # sort) the rows of `data`, silently dropping the repeated rows of a
  # bootstrap resample.
  data %>%
    select(-all_of(varnames_to_vary)) %>%
    expand_grid(expanded_vars_to_vary)
}
predict_mortality <- function(fit, newdata, vars_to_vary){
  # Average predicted mortality over `newdata` for each combination of the
  # values in `vars_to_vary`.
  #
  # Args:
  #   fit: Fitted model accepted by predict(..., type = "fitted").
  #   newdata: Data frame of patients to average over.
  #   vars_to_vary: Named list of value vectors to set for all patients.
  # Returns:
  #   One row per value combination with the mean predicted probability `prob`.
  grid_data <- expand_newdata(newdata, vars_to_vary)
  fit_probs <- predict(fit,
                       newdata = data.frame(grid_data),
                       type = "fitted",
                       se.fit = FALSE)
  # Average predictions for variables to vary
  varying_names <- names(vars_to_vary)
  grid_data %>%
    mutate(prob = fit_probs) %>%
    group_by(across(all_of(varying_names))) %>%
    summarise(prob = mean(prob))
}
# Bootstrap the averaged predictions from predict_mortality() by resampling
# the rows of `newdata` with replacement.
#
# fit, newdata, vars_to_vary: forwarded to predict_mortality().
# B: number of bootstrap replicates.
#
# Returns the B replicate results stacked into one data frame, with the
# replicate number stored in column `b`.
predict_mortality_boot <- function(fit, newdata, vars_to_vary, B = 100){
  n_obs <- nrow(newdata)
  boot_probs <- vector(mode = "list", length = B)
  # seq_len() so that B = 0 runs no iterations (1:0 would iterate over 1, 0)
  for (b in seq_len(B)){
    # Same draws as sample(1:n_obs, ...), without building the index vector
    indices <- sample.int(n_obs, size = n_obs, replace = TRUE)
    newdata_boot <- newdata[indices, ]
    boot_probs[[b]] <- predict_mortality(fit, newdata_boot, vars_to_vary)
    boot_probs[[b]][, "b"] <- b
  }
  bind_rows(boot_probs)
}
# Evaluate ages from the minimum to the maximum training age in steps of 1
ages_to_vary <- seq(from = min(train_data$age), to = max(train_data$age),
                    by = 1)
probs_age <- rename(
  predict_mortality(lrm_fit_all, newdata = train_sample,
                    vars_to_vary = list(age = ages_to_vary)),
  Age = age,
  `Mortality probability` = prob
)
# Table
formatRound(
  my_datatable(probs_age, filename = "mortprob_by_age"),
  "Mortality probability", 4
)
# Plot
ggplot(probs_age, aes(x = Age, y = `Mortality probability`)) +
  geom_line()
# Calendar time is expressed as days since the earliest index date in the
# training data; evaluate three reference dates
min_index_date <- min(train_data$index_date)
calendar_times_to_vary <- as.numeric(
  as.Date(c("2020-03-01", "2020-04-01", "2020-05-01")) - min_index_date
)
probs_calendar_time <- predict_mortality(
  lrm_fit_all,
  newdata = train_sample,
  vars_to_vary = list(calendar_time = calendar_times_to_vary)
) %>%
  mutate(Date = min_index_date + calendar_time) %>%
  rename(`Calendar time` = calendar_time,
         `Mortality probability` = prob) %>%
  relocate(Date)
kable_styling(kable(probs_calendar_time))
Date | Calendar time | Mortality probability |
---|---|---|
2020-03-01 | 10 | 0.2239372 |
2020-04-01 | 41 | 0.1721382 |
2020-05-01 | 71 | 0.1401583 |
We will start by predicting point estimates only.
# Predict point estimates over the age x calendar-time grid
probs_age_calendar <- predict_mortality(
  lrm_fit_all,
  newdata = train_sample,
  vars_to_vary = list(
    age = ages_to_vary,
    calendar_time = calendar_times_to_vary
  )
)
# Summarize in table (calendar time shown as a date)
probs_age_calendar %>%
  arrange(calendar_time, age) %>%
  mutate(Date = min_index_date + calendar_time) %>%
  select(-calendar_time) %>%
  relocate(Date) %>%
  rename(Age = age, `Mortality probability` = prob) %>%
  my_datatable(filename = "mortprob_by_age_time") %>%
  formatRound("Mortality probability", 4)
We will also compute 95% confidence intervals by bootstrapping.
# Bootstrap replicates of the age x calendar-time predictions
bootprobs_age_calendar <- predict_mortality_boot(
  fit = lrm_fit_all,
  newdata = train_sample,
  vars_to_vary = list(
    age = ages_to_vary,
    calendar_time = calendar_times_to_vary
  ),
  B = n_boot_probs
)
Let’s plot the results.
bootprobs_age_calendar %>%
  # Summarize bootstrap results: mean and 2.5% / 97.5% quantiles across
  # replicates for each age / calendar-time combination
  group_by(age, calendar_time) %>%
  summarise(
    prob_mean = mean(prob),
    prob_lower = quantile(prob, probs = 0.025),
    prob_upper = quantile(prob, probs = 0.975)
  ) %>%
  mutate(date = factor(min_index_date + calendar_time)) %>%
  # Plot: one line and ribbon per date
  ggplot(aes(x = age, y = prob_mean)) +
  geom_line(aes(color = date)) +
  geom_ribbon(
    aes(ymin = prob_lower, ymax = prob_upper, fill = date),
    alpha = 0.2
  ) +
  labs(x = "Age", y = "Mortality probability") +
  scale_color_discrete(name = "") +
  scale_fill_discrete(name = "") +
  theme(legend.position = "bottom")
We will start by using bootstrapping to estimate model performance and check whether our in-sample fits are too optimistic. Specifically, we will use the following algorithm implemented in the rms
package:
A shrinkage factor can also be estimated within each bootstrap sample to gauge the extent of overfitting. This is done by fitting \(g(Y) = \gamma_0 + \gamma_1 X\hat{\beta}\) where \(X\) and \(Y\) are the predictors and outcome, respectively, in the test sample (i.e., in step 2c) and \(\hat{\beta}\) is estimated in the training sample (i.e., step 2b). If there is no overfitting, then \(\gamma_0 = 0\) and \(\gamma_1 = 1\); conversely, if there is overfitting, then \(\gamma_1 < 1\) and \(\gamma_0 \neq 0\) to compensate.
# Validate each fitted model with n_boot_val bootstrap repetitions.
# NOTE(review): the calls are presumably order-sensitive because each one draws
# bootstrap samples from the shared RNG stream -- keep the order when editing.
lrm_val_age <- validate(lrm_fit_age, B = n_boot_val)
lrm_val_c <- validate(lrm_fit_c, B = n_boot_val)
lrm_val_d <- validate(lrm_fit_d, B = n_boot_val)
lrm_val_dc <- validate(lrm_fit_dc, B = n_boot_val)
lrm_val_all <- validate(lrm_fit_all, B = n_boot_val)
# Unnamed list of the validation results, in the order age, c, d, dc, all
lrm_val_train <- list(lrm_val_age, lrm_val_c, lrm_val_d,
lrm_val_dc, lrm_val_all)
# Append a "c_index" row to a validation matrix.
#
# object: matrix with a row named "Dxy" and at least six columns; columns 1-3
#   are rescaled, columns 4-5 are derived from them, column 6 is copied.
#
# The c-index is (Dxy + 1) / 2 for columns 1-3. Column 4 is column 2 minus
# column 3, column 5 is column 1 minus column 4, and column 6 is taken from
# the first row of `object`.
bind_cindex <- function(object){
  cstat <- (object["Dxy", 1:3] + 1)/2
  optimism <- cstat[2] - cstat[3]
  corrected <- cstat[1] - optimism
  rbind(object, c_index = c(cstat, optimism, corrected, object[1, 6]))
}
# Render a validation result as a styled kable table, with the c-index row
# appended by bind_cindex() and descriptive column headers.
make_validation_tbl <- function(object){
  column_labels <- c("(1) Original", "(2) Bootstrap training",
                     "(3) Bootstrap test", "Optimism: (2) - (3)",
                     "Original (corrected): (1) - (4)", "N")
  labelled <- set_colnames(bind_cindex(object), column_labels)
  kable_styling(kable(labelled))
}
# Validation table for lrm_val_age
make_validation_tbl(object = lrm_val_age)
(1) Original | (2) Bootstrap training | (3) Bootstrap test | Optimism: (2) - (3) | Original (corrected): (1) - (4) | N | |
---|---|---|---|---|---|---|
Dxy | 0.5491275 | 0.5504411 | 0.5491275 | 0.0013137 | 0.5478138 | 50 |
R2 | 0.2128736 | 0.2140936 | 0.2127503 | 0.0013434 | 0.2115302 | 50 |
Intercept | 0.0000000 | 0.0000000 | -0.0055930 | 0.0055930 | -0.0055930 | 50 |
Slope | 1.0000000 | 1.0000000 | 0.9966123 | 0.0033877 | 0.9966123 | 50 |
Emax | 0.0000000 | 0.0000000 | 0.0017803 | 0.0017803 | 0.0017803 | 50 |
D | 0.1323592 | 0.1332082 | 0.1322772 | 0.0009311 | 0.1314282 | 50 |
U | -0.0001464 | -0.0001464 | -0.0000055 | -0.0001409 | -0.0000055 | 50 |
Q | 0.1325056 | 0.1333547 | 0.1322827 | 0.0010720 | 0.1314337 | 50 |
B | 0.1159297 | 0.1158128 | 0.1159645 | -0.0001517 | 0.1160814 | 50 |
g | 1.4670934 | 1.4720104 | 1.4649418 | 0.0070686 | 1.4600248 | 50 |
gp | 0.1460443 | 0.1464139 | 0.1459707 | 0.0004432 | 0.1456011 | 50 |
c_index | 0.7745637 | 0.7752206 | 0.7745637 | 0.0006568 | 0.7739069 | 50 |
# Validation table for lrm_val_c
make_validation_tbl(object = lrm_val_c)
(1) Original | (2) Bootstrap training | (3) Bootstrap test | Optimism: (2) - (3) | Original (corrected): (1) - (4) | N | |
---|---|---|---|---|---|---|
Dxy | 0.4620295 | 0.4580554 | 0.4560464 | 0.0020090 | 0.4600205 | 50 |
R2 | 0.1382186 | 0.1384063 | 0.1363094 | 0.0020969 | 0.1361217 | 50 |
Intercept | 0.0000000 | 0.0000000 | -0.0164858 | 0.0164858 | -0.0164858 | 50 |
Slope | 1.0000000 | 1.0000000 | 0.9900233 | 0.0099767 | 0.9900233 | 50 |
Emax | 0.0000000 | 0.0000000 | 0.0052634 | 0.0052634 | 0.0052634 | 50 |
D | 0.0838930 | 0.0840050 | 0.0826839 | 0.0013211 | 0.0825718 | 50 |
U | -0.0001464 | -0.0001464 | -0.0000121 | -0.0001343 | -0.0000121 | 50 |
Q | 0.0840394 | 0.0841515 | 0.0826960 | 0.0014555 | 0.0825839 | 50 |
B | 0.1215780 | 0.1214799 | 0.1217953 | -0.0003154 | 0.1218934 | 50 |
g | 0.7727543 | 0.7727824 | 0.7661209 | 0.0066614 | 0.7660929 | 50 |
gp | 0.1097598 | 0.1097123 | 0.1088877 | 0.0008245 | 0.1089352 | 50 |
c_index | 0.7310148 | 0.7290277 | 0.7280232 | 0.0010045 | 0.7300103 | 50 |
# Validation table for lrm_val_all
make_validation_tbl(object = lrm_val_all)
(1) Original | (2) Bootstrap training | (3) Bootstrap test | Optimism: (2) - (3) | Original (corrected): (1) - (4) | N | |
---|---|---|---|---|---|---|
Dxy | 0.7650537 | 0.7697927 | 0.7615384 | 0.0082543 | 0.7567993 | 50 |
R2 | 0.4402338 | 0.4456944 | 0.4340642 | 0.0116302 | 0.4286036 | 50 |
Intercept | 0.0000000 | 0.0000000 | -0.0304069 | 0.0304069 | -0.0304069 | 50 |
Slope | 1.0000000 | 1.0000000 | 0.9712878 | 0.0287122 | 0.9712878 | 50 |
Emax | 0.0000000 | 0.0000000 | 0.0117978 | 0.0117978 | 0.0117978 | 50 |
D | 0.2963362 | 0.3006789 | 0.2915127 | 0.0091662 | 0.2871700 | 50 |
U | -0.0001464 | -0.0001464 | 0.0001691 | -0.0003156 | 0.0001691 | 50 |
Q | 0.2964827 | 0.3008253 | 0.2913435 | 0.0094817 | 0.2870009 | 50 |
B | 0.0896838 | 0.0888244 | 0.0903855 | -0.0015610 | 0.0912448 | 50 |
g | 2.3192859 | 2.3558240 | 2.2874964 | 0.0683277 | 2.2509582 | 50 |
gp | 0.2020541 | 0.2032803 | 0.2008164 | 0.0024639 | 0.1995902 | 50 |
c_index | 0.8825268 | 0.8848963 | 0.8807692 | 0.0041272 | 0.8783997 | 50 |
We will first use the bootstrap so that we can generate bias adjusted predictions.
# Calibrate
# Each fitted model is calibrated with n_boot_val bootstrap repetitions.
# NOTE(review): presumably order-sensitive because the bootstrap draws share
# one RNG stream -- keep the call order when editing.
lrm_cal_age <- calibrate(lrm_fit_age, B = n_boot_val)
lrm_cal_c <- calibrate(lrm_fit_c, B = n_boot_val)
lrm_cal_d <- calibrate(lrm_fit_d, B = n_boot_val)
lrm_cal_dc <- calibrate(lrm_fit_dc, B = n_boot_val)
lrm_cal_all <- calibrate(lrm_fit_all, B = n_boot_val)
# Collect the results (order: age, c, d, dc, all) and label them with the
# display names in lrm_names, which is defined outside this section
lrm_cal_list <- list(lrm_cal_age, lrm_cal_c, lrm_cal_d, lrm_cal_dc, lrm_cal_all)
names(lrm_cal_list) <- lrm_names
Calibration plots can then be created.
# Plot apparent and bias-corrected calibration curves, one facet per model.
#
# object: named list of calibration results; each element is indexed as a
#   matrix and must provide columns predy, calibrated.orig and
#   calibrated.corrected. The names are used as model labels and ordered by
#   lrm_names.
plot_calibration <- function(object){
  # Make tibble: one block of rows per model, tagged with the model name
  to_model_tbl <- function(cal, model_name){
    mutate(as_tibble(cal[, ]), model = model_name)
  }
  cal_df <- bind_rows(map2(object, names(object), to_model_tbl))
  cal_df <- mutate(cal_df, model = factor(model, levels = lrm_names))

  # Plot
  axis_breaks <- seq(0, 1, .2)
  unit_range <- c(0, 1)
  ggplot() +
    geom_line(
      data = cal_df,
      mapping = aes(x = predy, y = calibrated.orig, color = "Apparent")
    ) +
    geom_line(
      data = cal_df,
      mapping = aes(x = predy, y = calibrated.corrected,
                    color = "Bias-corrected")
    ) +
    # 45-degree reference line
    geom_abline(intercept = 0, slope = 1, linetype = "dashed",
                color = "grey") +
    facet_wrap(~model) +
    scale_x_continuous(breaks = axis_breaks, limits = unit_range) +
    scale_y_continuous(breaks = axis_breaks, limits = unit_range) +
    labs(x = "Predicted probability", y = "Actual probability") +
    scale_colour_manual(
      name = "",
      values = c("Apparent" = "black", "Bias-corrected" = "red")
    ) +
    theme(legend.position = "bottom")
}
# Calibration curves for all models in lrm_cal_list
plot_calibration(object = lrm_cal_list)
Calibration at the tails of the predictive distribution might be based on fewer patients, so it can be useful to plot the distribution of predicted probabilities. We start with histograms.
# Stack the "predicted" attribute of every calibration object into one long
# tibble, tagged with the model name
predprobs_df <- bind_rows(
  map2(lrm_cal_list, names(lrm_cal_list), function(cal, model_name){
    tibble(model = model_name, predicted = attr(cal, "predicted"))
  })
)
predprobs_df <- mutate(
  predprobs_df,
  model = factor(model, levels = names(lrm_cal_list))
)
# Histogram of predicted probabilities, one facet per model with its own
# y-axis scale.
#
# data: data frame with columns `predicted` and `model`.
plot_predprobs_hist <- function(data){
  # Seven count breaks from 0 up to 90% of the axis maximum, rounded down
  count_breaks <- function (x) {
    floor(seq(0, .9 * max(x), length.out = 7))
  }
  ggplot(data, aes(x = predicted)) +
    facet_wrap(~model, scales = "free_y") +
    geom_histogram(fill = "black", bins = 200) +
    scale_x_continuous(limits = c(0, 1), breaks = seq(0, 1, by = 0.1)) +
    scale_y_continuous(breaks = count_breaks) +
    labs(x = "Predicted Probability", y = "Count")
}
# Histograms of the training-set predicted probabilities
plot_predprobs_hist(data = predprobs_df)
The cumulative distribution function (CDF) is also informative since it provides the proportion of patients with predicted probabilities above and below given thresholds. For example, we can see that in the full model only a proportion of 0.09 have predicted probabilities above 0.5.
# Empirical CDF of predicted probabilities, one colored curve per model.
#
# data: data frame with columns `predicted` and `model`.
plot_predprobs_cdf <- function(data){
  tenths <- seq(0, 1, by = 0.1)
  ggplot(data, aes(x = predicted, col = model)) +
    stat_ecdf() +
    scale_x_continuous(limits = c(0, 1), breaks = tenths) +
    scale_y_continuous(breaks = tenths) +
    labs(x = "Predicted Probability", y = "CDF") +
    scale_color_discrete(name = "") +
    theme(legend.position = "bottom")
}
# CDFs of the training-set predicted probabilities
plot_predprobs_cdf(data = predprobs_df)
The specification of the model has been finalized so we are ready to assess it on the test set.
To make predictions on the test set, we must preprocess the data in the same way that we preprocessed the training data.
# Load the test set and apply the same inclusion/exclusion filter that was
# used for the training data
test_data <- filter_ie(readRDS("test_data.rds"))
## 183 (5.07%) patients were dropped due to the 2-week cutoff.
# Clean data, then truncate labs
test_data <- add_truncated_lab_vars(
  clean_data(test_data),
  v = lab_vars$var
)
Multiple imputation will be performed on the combined train and test set.
# Missing data imputation
## Tag each row with its source set, then stack the training and test sets
train_data$set <- "train"
test_data$set <- "test"
train_test_data <- bind_rows(train_data, test_data)

## Multiple imputation via MICE (character columns converted to factors)
mice_train_test <- train_test_data %>%
  select(one_of(mice_vars)) %>%
  mutate_if(is.character, as.factor) %>%
  mice(m = n_imputations, maxit = 5)

## Stack the completed datasets, re-attach death and the set label (repeated
## once per imputation), keep only the test rows, and create the
## race/ethnicity variable
mi_df_test <- complete(mice_train_test, action = "long", include = FALSE) %>%
  as_tibble() %>%
  mutate(
    died = rep(train_test_data$died, times = mice_train_test$m),
    set = rep(train_test_data$set, times = mice_train_test$m)
  ) %>%
  filter(set == "test") %>%
  add_race_ethnicity()
Before making predictions, it is helpful to compare the training and test sets.
# Descriptive table comparing the training and test sets (plus overall)
tbl1_train_test <- train_test_data %>%
  select(one_of("died", "set", tbl1_varnames)) %>%
  rename_with(get_var_labs, .cols = all_of(tbl1_varnames)) %>%
  rename(Died = died) %>%
  mutate(set = ifelse(set == "train", "Train set", "Test set")) %>%
  CreateTableOne(
    vars = c("Died", get_var_labs(tbl1_varnames)),
    factorVars = c("Died", get_var_labs(tbl1_cat_varnames)),
    strata = "set",
    addOverall = TRUE,
    data = .
  )

# Format the table as a matrix (without printing it), relabel the columns,
# and render
print(
  tbl1_train_test,
  nonnormal = get_var_labs(tbl1_non_normal_varnames),
  cramVars = c("Sex"),
  contDigits = 3,
  missing = TRUE,
  printToggle = FALSE
) %>%
  set_colnames(
    c("Overall", "Test set", "Train set", "p", "Test", "Missing")
  ) %>%
  kable() %>%
  kable_styling()
Overall | Test set | Train set | p | Test | Missing | |
---|---|---|---|---|---|---|
n | 17086 | 3428 | 13658 | |||
Died = 1 (%) | 2660 (15.6) | 497 (14.5) | 2163 (15.8) | 0.057 | 0.0 | |
Age (median [IQR]) | 62.000 [49.000, 75.000] | 62.000 [49.000, 74.000] | 62.000 [49.000, 75.000] | 0.416 | nonnorm | 0.0 |
Calendar time (median [IQR]) | 47.000 [38.000, 64.000] | 47.000 [37.000, 64.000] | 47.000 [38.000, 64.000] | 0.272 | nonnorm | 0.0 |
Geographic division (%) | <0.001 | 2.7 | ||||
East North Central | 5835 (35.1) | 1208 (36.3) | 4627 (34.8) | |||
Middle Atlantic | 5900 (35.5) | 1264 (38.0) | 4636 (34.9) | |||
New England | 1915 (11.5) | 332 (10.0) | 1583 (11.9) | |||
Other | 238 ( 1.4) | 47 ( 1.4) | 191 ( 1.4) | |||
Pacific | 591 ( 3.6) | 80 ( 2.4) | 511 ( 3.8) | |||
South Atl/West South Crl | 455 ( 2.7) | 91 ( 2.7) | 364 ( 2.7) | |||
West North Central | 1690 (10.2) | 306 ( 9.2) | 1384 (10.4) | |||
Race/Ethnicity (%) | 0.069 | 26.0 | ||||
Non-Hispanic white | 7071 (56.0) | 1424 (55.9) | 5647 (56.0) | |||
Asian | 444 ( 3.5) | 82 ( 3.2) | 362 ( 3.6) | |||
Hispanic | 700 ( 5.5) | 167 ( 6.6) | 533 ( 5.3) | |||
Non-Hispanic black | 4423 (35.0) | 876 (34.4) | 3547 (35.2) | |||
Sex = Female/Male (%) | 8223/8857 (48.1/51.9) | 1660/1766 (48.5/51.5) | 6563/7091 (48.1/51.9) | 0.700 | 0.0 | |
Smoking (%) | 0.367 | 25.7 | ||||
Current | 1078 ( 8.5) | 212 ( 8.4) | 866 ( 8.5) | |||
Never | 7714 (60.8) | 1507 (59.7) | 6207 (61.1) | |||
Previous | 3896 (30.7) | 804 (31.9) | 3092 (30.4) | |||
Acute myocardial infarction = Yes (%) | 1896 (11.1) | 361 (10.5) | 1535 (11.2) | 0.250 | 0.0 | |
AIDS/HIV = Yes (%) | 136 ( 0.8) | 35 ( 1.0) | 101 ( 0.7) | 0.121 | 0.0 | |
Cancer = Yes (%) | 2087 (12.2) | 409 (11.9) | 1678 (12.3) | 0.591 | 0.0 | |
Cerebrovascular disease = Yes (%) | 1807 (10.6) | 368 (10.7) | 1439 (10.5) | 0.758 | 0.0 | |
Chronic pulmonary disease = Yes (%) | 4623 (27.1) | 996 (29.1) | 3627 (26.6) | 0.003 | 0.0 | |
Congestive heart failure = Yes (%) | 2921 (17.1) | 596 (17.4) | 2325 (17.0) | 0.631 | 0.0 | |
Dementia = Yes (%) | 1706 (10.0) | 312 ( 9.1) | 1394 (10.2) | 0.058 | 0.0 | |
Diabetes = Yes (%) | 5800 (33.9) | 1188 (34.7) | 4612 (33.8) | 0.336 | 0.0 | |
Hemiplegia or paraplegia = Yes (%) | 419 ( 2.5) | 89 ( 2.6) | 330 ( 2.4) | 0.584 | 0.0 | |
Hypertension = Yes (%) | 10021 (58.7) | 2018 (58.9) | 8003 (58.6) | 0.787 | 0.0 | |
Metastatic cancer = Yes (%) | 355 ( 2.1) | 78 ( 2.3) | 277 ( 2.0) | 0.401 | 0.0 | |
Mild liver disease = Yes (%) | 1096 ( 6.4) | 217 ( 6.3) | 879 ( 6.4) | 0.852 | 0.0 | |
Moderate/severe liver disease = Yes (%) | 164 ( 1.0) | 36 ( 1.1) | 128 ( 0.9) | 0.611 | 0.0 | |
Peptic ulcer disease = Yes (%) | 252 ( 1.5) | 46 ( 1.3) | 206 ( 1.5) | 0.520 | 0.0 | |
Peripheral vascular disease = Yes (%) | 2103 (12.3) | 432 (12.6) | 1671 (12.2) | 0.578 | 0.0 | |
Renal disease = Yes (%) | 3541 (20.7) | 708 (20.7) | 2833 (20.7) | 0.927 | 0.0 | |
Rheumatoid disease = Yes (%) | 500 ( 2.9) | 102 ( 3.0) | 398 ( 2.9) | 0.893 | 0.0 | |
CCI (median [IQR]) | 1.000 [0.000, 3.000] | 1.000 [0.000, 3.000] | 1.000 [0.000, 3.000] | 0.124 | nonnorm | 0.0 |
Body Mass Index (BMI) (median [IQR]) | 29.660 [25.510, 35.010] | 29.560 [25.500, 34.623] | 29.690 [25.510, 35.130] | 0.339 | nonnorm | 12.0 |
Diastolic blood pressure (median [IQR]) | 73.000 [65.500, 80.500] | 73.000 [65.000, 80.500] | 73.000 [65.500, 80.500] | 0.803 | nonnorm | 3.5 |
Heart rate (median [IQR]) | 87.500 [77.500, 98.000] | 87.500 [78.000, 98.000] | 87.500 [77.500, 98.000] | 0.906 | nonnorm | 3.4 |
Oxygen saturation (median [IQR]) | 96.000 [94.000, 98.000] | 96.000 [94.000, 98.000] | 96.000 [94.000, 98.000] | 0.979 | nonnorm | 4.3 |
Respiration rate (median [IQR]) | 20.000 [18.000, 22.000] | 20.000 [18.000, 22.000] | 20.000 [18.000, 22.000] | 0.980 | nonnorm | 4.2 |
Systolic blood pressure (median [IQR]) | 126.000 [115.000, 139.000] | 126.000 [115.000, 139.000] | 126.000 [115.000, 139.000] | 0.663 | nonnorm | 3.5 |
Temperature (median [IQR]) | 37.000 [36.700, 37.400] | 37.000 [36.700, 37.400] | 37.000 [36.700, 37.400] | 0.200 | nonnorm | 3.5 |
Alanine aminotransferase (ALT) (median [IQR]) | 28.000 [18.000, 46.000] | 28.000 [18.000, 46.000] | 28.000 [18.000, 46.000] | 0.791 | nonnorm | 20.3 |
Aspartate aminotransferase (AST) (median [IQR]) | 37.000 [25.000, 58.000] | 38.000 [25.000, 59.000] | 37.000 [25.000, 58.000] | 0.383 | nonnorm | 21.3 |
C-reactive protein (CRP) (median [IQR]) | 79.100 [33.800, 141.360] | 79.900 [31.600, 147.083] | 79.100 [34.000, 140.000] | 0.615 | nonnorm | 39.3 |
Creatinine (median [IQR]) | 1.000 [0.800, 1.415] | 1.000 [0.800, 1.430] | 1.000 [0.800, 1.410] | 0.410 | nonnorm | 10.7 |
Ferritin (median [IQR]) | 509.000 [223.000, 1082.500] | 505.000 [220.000, 1085.000] | 510.000 [224.000, 1080.000] | 0.897 | nonnorm | 44.2 |
Fibrin D-Dimer (median [IQR]) | 762.000 [405.000, 1600.000] | 880.000 [460.500, 1810.000] | 750.000 [390.000, 1540.750] | 0.083 | nonnorm | 90.5 |
Lactate dehydrogenase (LDH) (median [IQR]) | 321.000 [238.000, 441.000] | 316.000 [237.000, 443.000] | 321.000 [238.000, 441.000] | 0.925 | nonnorm | 45.7 |
Lymphocyte count (median [IQR]) | 0.990 [0.700, 1.400] | 0.973 [0.700, 1.400] | 1.000 [0.700, 1.400] | 0.342 | nonnorm | 11.4 |
Neutrophil count (median [IQR]) | 4.900 [3.400, 7.160] | 4.940 [3.440, 7.330] | 4.900 [3.370, 7.100] | 0.083 | nonnorm | 11.5 |
Platelet count (PLT) (median [IQR]) | 202.000 [157.000, 261.000] | 203.000 [157.000, 264.000] | 202.000 [157.000, 260.000] | 0.438 | nonnorm | 10.1 |
Procalcitonin (median [IQR]) | 0.130 [0.070, 0.370] | 0.140 [0.070, 0.361] | 0.130 [0.070, 0.370] | 0.684 | nonnorm | 49.5 |
Troponin I (median [IQR]) | 0.010 [0.010, 0.050] | 0.010 [0.010, 0.050] | 0.010 [0.010, 0.050] | 0.243 | nonnorm | 41.2 |
White blood cell count (WBC) (median [IQR]) | 6.700 [4.900, 9.130] | 6.800 [5.000, 9.400] | 6.700 [4.900, 9.100] | 0.035 | nonnorm | 9.9 |
Our main predictions are based on the unpenalized logistic regression.
# Fitted unpenalized models, keyed by a short suffix
lrm_fits <- list(
  age = lrm_fit_age,
  c = lrm_fit_c,
  d = lrm_fit_d,
  dc = lrm_fit_dc,
  all = lrm_fit_all
)
# One prediction column per model: lrm_probs_<suffix>
lrm_prob_names <- paste0("lrm_probs_", names(lrm_fits))
test_newdata <- NULL
for (i in seq_along(lrm_fits)){
  mi_df_test[[lrm_prob_names[i]]] <- predict(
    lrm_fits[[i]],
    newdata = as.data.frame(mi_df_test),
    type = "fitted",
    se.fit = FALSE
  )
}
rm(test_newdata)
Let’s evaluate the models using the metrics returned by rms::val.prob()
.
# Test-set validation statistics for each model, labelled with lrm_names
lrm_val_test <- vector(mode = "list", length = length(lrm_fits))
names(lrm_val_test) <- lrm_names
for (i in seq_along(lrm_val_test)){
  lrm_val_test[[i]] <- val.prob(
    p = mi_df_test[[lrm_prob_names[i]]],
    y = mi_df_test$died,
    pl = FALSE  # statistics only, no plot
  )
}
# Tabulate a named list of named statistic vectors: one row per list element
# (its name goes in the "Model" column), one column per statistic.
summarize_val_probs <- function(object){
  # Turn one named vector into a single-row tibble that keeps the names
  as_one_row <- function (stats) {
    set_colnames(as_tibble(matrix(stats, nrow = 1)), names(stats))
  }
  metrics <- map_df(object, as_one_row, .id = "Model")
  scroll_box(
    kable_styling(kable(metrics, digits = 4)),
    width = "100%"
  )
}
# Test-set metrics for the unpenalized models
summarize_val_probs(object = lrm_val_test)
Model | Dxy | C (ROC) | R2 | D | D:Chi-sq | D:p | U | U:Chi-sq | U:p | Q | Brier | Intercept | Slope | Emax | E90 | Eavg | S:z | S:p |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Age only | 0.5116 | 0.7558 | 0.1728 | 0.1023 | 1754.349 | NA | 0.0016 | 30.1185 | 0e+00 | 0.1007 | 0.1111 | -0.2207 | 0.9038 | 0.0436 | 0.0315 | 0.0147 | -2.4299 | 0.0151 |
Comorbidities only | 0.4372 | 0.7186 | 0.1187 | 0.0691 | 1185.986 | NA | 0.0009 | 17.4343 | 2e-04 | 0.0682 | 0.1151 | -0.1358 | 0.9707 | 0.2426 | 0.0291 | 0.0216 | -2.4338 | 0.0149 |
All demographics | 0.5464 | 0.7732 | 0.2007 | 0.1198 | 2055.148 | NA | 0.0010 | 19.8757 | 0e+00 | 0.1188 | 0.1082 | -0.1611 | 0.9457 | 0.0272 | 0.0219 | 0.0142 | -3.0846 | 0.0020 |
Demographics and comorbidities | 0.5807 | 0.7904 | 0.2267 | 0.1365 | 2340.133 | NA | 0.0010 | 19.5298 | 1e-04 | 0.1354 | 0.1062 | -0.1531 | 0.9523 | 0.1353 | 0.0175 | 0.0119 | -2.2529 | 0.0243 |
All variables | 0.7474 | 0.8737 | 0.3994 | 0.2546 | 4365.199 | NA | 0.0020 | 36.2544 | 0e+00 | 0.2526 | 0.0879 | -0.1997 | 0.9304 | 0.1266 | 0.0261 | 0.0116 | 0.1136 | 0.9096 |
We will focus on the Brier score and C-Index and assess performance on both the training and test sets.
# Side-by-side table of C-index and Brier score on the training and test sets.
#
# val_train: list of validation matrices with rows "Dxy" and "B" and a column
#   "index.orig".
# val_test: named list of statistic vectors with elements "C (ROC)" and
#   "Brier"; its names supply the Model column.
summarize_performance <- function(val_train, val_test){
  # Get metrics; the training C-index is derived from Dxy as (Dxy + 1) / 2
  train_tbl <- map_df(val_train, function (val) {
    val[c("Dxy", "B"), "index.orig"]
  })
  train_tbl <- rename(mutate(train_tbl, Dxy = (Dxy + 1)/2), c_index = Dxy)
  test_tbl <- map_df(val_test, function(val) val[c("C (ROC)", "Brier")])

  # Make table
  perf <- bind_cols(tibble(Model = names(val_test)), train_tbl, test_tbl)
  perf_kbl <- kable(
    perf,
    col.names = c("Model", "C-Index", "Brier", "C-Index", "Brier"),
    digits = 4
  )
  add_header_above(kable_styling(perf_kbl),
                   c(" ", "Train" = 2, "Test" = 2))
}
# Train vs. test performance of the unpenalized models
summarize_performance(lrm_val_train, lrm_val_test)
Model | C-Index | Brier | C-Index | Brier |
---|---|---|---|---|
Age only | 0.7746 | 0.1159 | 0.7558 | 0.1111 |
Comorbidities only | 0.7310 | 0.1216 | 0.7186 | 0.1151 |
All demographics | 0.7848 | 0.1143 | 0.7732 | 0.1082 |
Demographics and comorbidities | 0.8018 | 0.1118 | 0.7904 | 0.1062 |
All variables | 0.8825 | 0.0897 | 0.8737 | 0.0879 |
Like we did with the training set, we will plot calibration curves.
# Reshape predicted probabilities to long format for calibration plots.
#
# data: data frame with columns "died", ".imp" and every column named in `p`.
# p: character vector of prediction column names.
# model_names: display labels, in the same order as `p`.
#
# Returns a long data frame with columns `actual` (renamed from died), `.imp`,
# `model` (factor with levels in the order of model_names) and `predicted`.
make_cal_test_data <- function(data, p , model_names){
  data[, c("died", ".imp", p)] %>%
    # all_of() marks `p` as an external character vector, so the selection is
    # unambiguous even if `data` has a column called "p"
    pivot_longer(cols = all_of(p), names_to = "model",
                 values_to = "predicted") %>%
    rename(actual = died) %>%
    mutate(model = model_names[match(model, p)],
           model = factor(model, levels = model_names))
}
# Calibration curves on the test set: loess smooth of the observed outcome
# against the predicted probability, one facet per model.
#
# cal_data: data frame with columns `predicted`, `actual` and `model`.
plot_calibration_test <- function(cal_data){
  prob_breaks <- seq(0, 1, .2)
  unit_range <- c(0, 1)
  ggplot(cal_data, aes(x = predicted, y = actual)) +
    facet_wrap(~model) +
    geom_smooth(method = "loess", formula = y ~ x, se = FALSE,
                color = "black", size = 1) +
    # 45-degree reference line
    geom_abline(intercept = 0, slope = 1, linetype = "dashed",
                color = "grey") +
    scale_x_continuous(breaks = prob_breaks, limits = unit_range) +
    scale_y_continuous(breaks = prob_breaks, limits = unit_range) +
    labs(x = "Predicted probability", y = "Actual probability") +
    theme(legend.position = "bottom")
}
# Long-format test-set predictions for the unpenalized models, then plot
cal_test_data <- make_cal_test_data(
  data = mi_df_test,
  p = lrm_prob_names,
  model_names = lrm_names
)
plot_calibration_test(cal_data = cal_test_data)
It is again useful to visualize the distribution of predicted probabilities since calibration at the tails might be based on a smaller number of patients.
# Distribution of the test-set predicted probabilities: histogram, then CDF
plot_predprobs_hist(data = cal_test_data)
plot_predprobs_cdf(data = cal_test_data)
We will run sensitivity analyses on the full model to see if penalization can improve calibration and/or discrimination. The three models that we will consider are (i) logistic regression with a ridge penalty, (ii) logistic regression with a lasso penalty, and (iii) the standard unpenalized logistic regression (i.e., from the results presented above).
The penalized logistic regression models can be fit using glmnet
.
# Design matrix for each imputed training dataset
x_train <- map(mi_list, function(data) make_x(data, model_rhs))

# Cross validation: one ridge (alpha = 0) and one lasso (alpha = 1) fit per
# imputed dataset. The ridge fit is run before the lasso fit within each
# iteration, as in the original, so the sequence of calls is unchanged.
lasso_cvfits <- ridge_cvfits <- vector(mode = "list", length = length(x_train))
# seq_along() so that an empty list runs no iterations (1:0 would not)
for (i in seq_along(lasso_cvfits)){
  ridge_cvfits[[i]] <- cv.glmnet(x = x_train[[i]], y = y[[i]],
                                 family = "binomial", alpha = 0)
  lasso_cvfits[[i]] <- cv.glmnet(x = x_train[[i]], y = y[[i]],
                                 family = "binomial", alpha = 1)
}
Now let’s make predictions for the penalized models. The prediction for a given patient is the average prediction across the imputed datasets.
# Predict with one penalized fit per imputed dataset and average the results.
#
# fits: list of fitted objects whose predict() method accepts `newx`,
#   s = "lambda.1se" and type = "response" (here, the cv.glmnet fits).
# newx: predictor matrix to predict on.
#
# Returns a numeric vector of length nrow(newx): for each row, the mean
# prediction across all elements of `fits`.
predict_glmnet_mi <- function(fits, newx){
  n_fits <- length(fits)
  pred <- matrix(NA_real_, nrow = nrow(newx), ncol = n_fits)
  for (j in seq_len(n_fits)){
    # Use the loop index j and the `newx` argument. Indexing with a global
    # `i` and predicting on a global `x_test` would reuse a single fit for
    # every column and ignore the data that was passed in.
    pred[, j] <- predict(fits[[j]], newx = newx,
                         s = "lambda.1se", type = "response")
  }
  rowMeans(pred)
}
# Design matrix for the stacked imputed test data, then averaged predictions
# from the ridge and lasso fits
x_test <- make_x(mi_df_test, model_rhs)
mi_df_test[["ridge_probs_all"]] <- predict_glmnet_mi(ridge_cvfits, newx = x_test)
mi_df_test[["lasso_probs_all"]] <- predict_glmnet_mi(lasso_cvfits, newx = x_test)
With predictions in hand, we can evaluate the model.
# Test-set validation statistics for the two penalized models; labels and
# prediction column names are listed in the same order (lasso, ridge)
penalized_prob_names <- paste0(c("lasso", "ridge"), "_probs_all")
penalized_val_test <- vector(mode = "list", length = 2)
names(penalized_val_test) <- c("Lasso", "Ridge")
for (i in seq_along(penalized_val_test)){
  penalized_val_test[[i]] <- val.prob(
    p = mi_df_test[[penalized_prob_names[i]]],
    y = mi_df_test$died,
    pl = FALSE  # statistics only, no plot
  )
}
We will first assess various evaluation metrics with an emphasis on the Brier score and C-Index.
# Test-set metrics for the penalized models
summarize_val_probs(object = penalized_val_test)
Model | Dxy | C (ROC) | R2 | D | D:Chi-sq | D:p | U | U:Chi-sq | U:p | Q | Brier | Intercept | Slope | Emax | E90 | Eavg | S:z | S:p |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Lasso | 0.7294 | 0.8647 | 0.3768 | 0.2383 | 4086.133 | NA | 0.0008 | 15.6733 | 4e-04 | 0.2375 | 0.0906 | -0.1357 | 0.9541 | 0.1335 | 0.0125 | 0.0100 | 0.1938 | 0.8464 |
Ridge | 0.7258 | 0.8629 | 0.3517 | 0.2206 | 3782.492 | NA | 0.0134 | 232.2089 | 0e+00 | 0.2072 | 0.0936 | -0.3918 | 0.9855 | 0.1767 | 0.0782 | 0.0367 | -7.5871 | 0.0000 |
Finally, we will compare calibration curves for each of the three models.
# Calibration curves for lasso, ridge and the unpenalized full model
plot_calibration_test(
  make_cal_test_data(
    mi_df_test,
    p = c(penalized_prob_names, "lrm_probs_all"),
    model_names = c(names(penalized_val_test), "No penalty")
  )
)
# Record the R version, platform and package versions used for this analysis
sessionInfo()
## R version 4.0.0 (2020-04-24)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 18.04.5 LTS
##
## Matrix products: default
## BLAS: /usr/lib/x86_64-linux-gnu/openblas/libblas.so.3
## LAPACK: /usr/lib/x86_64-linux-gnu/libopenblasp-r0.2.20.so
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## attached base packages:
## [1] splines stats graphics grDevices datasets utils methods
## [8] base
##
## other attached packages:
## [1] tidyr_1.1.2 tibble_3.0.4 tableone_0.12.0 rms_6.0-0
## [5] SparseM_1.78 Hmisc_4.4-0 Formula_1.2-3 survival_3.2-7
## [9] lattice_0.20-41 rcompanion_2.3.25 purrr_0.3.4 mice_3.13.0
## [13] magrittr_2.0.1 kableExtra_1.2.1 knitr_1.30 gridExtra_2.3
## [17] oem_2.0.10 bigmemory_4.5.36 glmnet_4.1 Matrix_1.3-0
## [21] ggplot2_3.3.3 DT_0.15 dplyr_1.0.2 corrr_0.4.2
##
## loaded via a namespace (and not attached):
## [1] TH.data_1.0-10 bigmemory.sri_0.1.3 colorspace_2.0-0
## [4] class_7.3-18 ellipsis_0.3.1 modeltools_0.2-23
## [7] htmlTable_2.0.0 base64enc_0.1-3 rstudioapi_0.13
## [10] farver_2.0.3 MatrixModels_0.4-1 fansi_0.4.2
## [13] mvtnorm_1.1-1 coin_1.3-1 xml2_1.3.2
## [16] codetools_0.2-18 libcoin_1.0-5 jsonlite_1.7.2
## [19] broom_0.7.3 cluster_2.1.1 png_0.1-7
## [22] compiler_4.0.0 httr_1.4.2 backports_1.2.1
## [25] assertthat_0.2.1 survey_4.0 cli_2.2.0
## [28] acepack_1.4.1 htmltools_0.5.0 quantreg_5.75
## [31] tools_4.0.0 gtable_0.3.0 glue_1.4.2
## [34] Rcpp_1.0.5 vctrs_0.3.6 nlme_3.1-151
## [37] conquer_1.0.2 crosstalk_1.1.1 iterators_1.0.13
## [40] lmtest_0.9-37 xfun_0.19 stringr_1.4.0
## [43] rvest_0.3.6 lifecycle_1.0.0 renv_0.12.3
## [46] polspline_1.1.19 MASS_7.3-53 zoo_1.8-8
## [49] scales_1.1.1 hms_0.5.3 parallel_4.0.0
## [52] sandwich_2.5-1 expm_0.999-4 RColorBrewer_1.1-2
## [55] yaml_2.2.1 Exact_2.0 labelled_2.5.0
## [58] EMT_1.1 rpart_4.1-15 latticeExtra_0.6-29
## [61] stringi_1.5.3 highr_0.8 foreach_1.5.1
## [64] nortest_1.0-4 e1071_1.7-4 checkmate_2.0.0
## [67] boot_1.3-27 shape_1.4.5 rlang_0.4.10
## [70] pkgconfig_2.0.3 matrixStats_0.57.0 evaluate_0.14
## [73] htmlwidgets_1.5.3 labeling_0.4.2 tidyselect_1.1.0
## [76] plyr_1.8.6 R6_2.5.0 DescTools_0.99.37
## [79] generics_0.1.0 multcompView_0.1-8 multcomp_1.4-13
## [82] DBI_1.1.1 haven_2.3.1 mgcv_1.8-33
## [85] pillar_1.4.7 foreign_0.8-81 withr_2.3.0
## [88] nnet_7.3-15 crayon_1.4.1 rmarkdown_2.4
## [91] jpeg_0.1-8.1 grid_4.0.0 data.table_1.13.6
## [94] forcats_0.5.1 digest_0.6.27 webshot_0.5.2
## [97] stats4_4.0.0 munsell_0.5.0 viridisLite_0.3.0
## [100] mitools_2.4