# Function_logistic.R

###################################################
### Refit the model with the selected variables ###
###################################################

### Inputs:
        # x: design matrix
        # y: response vector
        # beta: original coefficients (intercept first)
### Output:
        # refitted coefficients (same length as beta; unselected predictors are 0)
        
update.beta <- function(x, y, beta) {
    # Refit an unpenalized logistic regression (glm, binomial family) on the
    # predictors whose coefficients in beta[-1] exceed 1e-10 in absolute value.
    # beta[1] is the intercept. Unselected entries are returned as exactly 0.
    p <- length(beta) - 1
    # A single-column design (or a single-row subset) may arrive as a plain
    # vector; rebuild the n x p design matrix from the coefficient count.
    if (is.null(dim(x)) && p > 0) {
        x <- matrix(x, ncol = p)
    }
    pos <- which(abs(beta[-1]) > 1e-10)
    beta.refit <- rep(0, length(beta))
    if (length(pos) > 0) {
        # drop = FALSE keeps a one-column selection as a matrix.
        refit <- glm(y ~ x[, pos, drop = FALSE], family = binomial)
        coefs <- coef(refit)
        # glm() reports NA for aliased (linearly dependent) columns. Those
        # predictors are effectively not in the model, so store 0 rather than
        # letting NA propagate into later likelihood comparisons.
        coefs[is.na(coefs)] <- 0
        beta.refit[c(1, pos + 1)] <- coefs
    } else {
        # Nothing selected: intercept-only model.
        refit <- glm(y ~ 1, family = binomial)
        beta.refit[1] <- coef(refit)
    }
    beta.refit
}

################################
### Calculate log likelihood ###
################################

### Inputs:
        # as in update.beta
### Output:
        # average log likelihood over the rows of x (the sum is divided by nrow(x))
        
log.likelihood <- function(x, y, beta) {
    # Average Bernoulli log-likelihood (1/n) * sum(y * eta - log(1 + exp(eta))),
    # with linear predictor eta = [1, x] %*% beta (beta[1] is the intercept).
    p <- length(beta) - 1
    # A one-row (or one-column) subset may arrive as a plain vector; rebuild
    # the n x p design matrix from the coefficient count.
    if (is.null(dim(x)) && p > 0) {
        x <- matrix(x, ncol = p)
    }
    eta <- drop(cbind(1, x) %*% beta)
    # Overflow-safe log(1 + exp(eta)): the naive form evaluates to Inf once
    # exp(eta) overflows (eta > ~709), which made the likelihood -Inf.
    log.one.plus.exp <- pmax(eta, 0) + log1p(exp(-abs(eta)))
    sum(y * eta - log.one.plus.exp) / nrow(x)
}

#########################
### logistic function ###
#########################

logit.inv <- function(x) {
    # Inverse logit, 1 / (1 + exp(-x)). This is algebraically equal to
    # exp(x) / (1 + exp(x)), but it stays finite for large positive x, where
    # exp(x) overflows and Inf / Inf would give NaN. Elementwise arithmetic
    # keeps the dimensions of a matrix input.
    1 / (1 + exp(-x))
}

#####################################################
### Select the optimal lambda by cross validation ###
#####################################################

### This is an internal function.

### Inputs:
        # lambda.seq: the sequence of lambda to search
        # x: design matrix
        # y: response vector
        # eta: the value of eta, the parameter tuning the prior information [(6) in the paper]
        # tau: the value of tau, the parameter tuning the L1 penalty weights [(35) in the paper]
        # beta.prior: either the prior estimator or the initial estimator
        # index: a vector of indicators of which predictors are penalized (yes/no: 1/0)
        # cv.group: the index of cross validation groups
        # is.refit: is the estimator refitted (yes/no: 1/0)?
### Output:
        # the optimal lambda with cross validation

select.lambda <- function(lambda.seq, x, y, eta, tau, beta.prior, index, cv.group, is.refit) {
    # Cross-validate the grplasso path over lambda.seq and pick a lambda by
    # the one-standard-error rule on the held-out log-likelihood.

    # Pseudo-response: y shrunk toward the prior fitted probabilities by eta.
    y.prior <- logit.inv(cbind(1, x) %*% beta.prior)
    y.tilde <- (y + eta * y.prior)/(1 + eta)
    # Column j of x is scaled by Winv[j]. A plain penalty on the scaled design
    # acts as a weighted penalty on the original one.
    Winv <- 1 + tau * abs(beta.prior[-1])
    x.Winv <- x * rep(Winv, each = nrow(x))
    x.Winv.1 <- cbind(1, x.Winv)

    # Map coefficients fitted on the scaled design back to the scale of x.
    # When is.refit is set, refit on the selected support first.
    to.original.scale <- function(beta, x.fit, y.fit) {
        if (is.refit) {
            beta <- update.beta(x.fit, y.fit, beta)
        }
        unname(c(beta[1], Winv * beta[-1]))
    }

    # Coefficient path on the full data; the column selected below is returned.
    betas <- grplasso(x.Winv.1, y.tilde, index = index, lambda = lambda.seq,
                      standardize = FALSE, control = grpl.control(trace = 0))$coef

    # Held-out log-likelihood for every (lambda, fold) pair. Fold labels are
    # taken from cv.group itself, so they need not be 1, ..., K.
    folds <- sort(unique(cv.group))
    log.liklhd <- matrix(0, length(lambda.seq), length(folds))
    for (j in seq_along(folds)) {
        train <- cv.group != folds[j]
        # drop = FALSE: a single-row fold or a single predictor must stay a matrix.
        x.Winv.train <- x.Winv[train, , drop = FALSE]
        y.tilde.train <- y.tilde[train]
        x.tune <- x[!train, , drop = FALSE]
        y.tune <- y[!train]
        fits <- grplasso(x.Winv.1[train, , drop = FALSE], y.tilde.train, index = index,
                         lambda = lambda.seq, standardize = FALSE,
                         control = grpl.control(trace = 0))
        for (i in seq_along(lambda.seq)) {
            beta.orig <- to.original.scale(fits$coef[, i], x.Winv.train, y.tilde.train)
            log.liklhd[i, j] <- log.likelihood(x.tune, y.tune, beta.orig)
        }
    }

    # Failed fits can produce NA/NaN. Score them as -Inf, like other failures,
    # so they cannot break the comparisons below.
    log.liklhd[is.na(log.liklhd)] <- -Inf
    # Drop folds whose held-out log-likelihood is -Inf for every lambda.
    log.liklhd <- log.liklhd[, apply(log.liklhd, 2, max) > -Inf, drop = FALSE]
    n.lambda <- nrow(log.liklhd)
    n.fold <- ncol(log.liklhd)

    if (n.fold == 0) {
        # No usable fold: fall back to the second lambda (or the only one).
        pos.1se <- min(2, n.lambda)
        mean.log.liklhd <- rep(-Inf, n.lambda)
    } else {
        mean.log.liklhd <- apply(log.liklhd, 1, mean)
        if (max(mean.log.liklhd) == -Inf) {
            pos.1se <- min(2, n.lambda)
        } else if (n.fold == 1) {
            pos.1se <- which.max(mean.log.liklhd)
        } else {
            pos.max <- which.max(mean.log.liklhd)
            se.log.liklhd <- apply(log.liklhd, 1, sd) / sqrt(n.fold)
            # One-standard-error rule: the first lambda in lambda.seq whose
            # mean is within one SE of the best mean.
            pos.1se <- which(mean.log.liklhd >=
                             mean.log.liklhd[pos.max] - se.log.liklhd[pos.max])[1]
        }
    }

    beta.refit.1se <- to.original.scale(betas[, pos.1se], x.Winv, y.tilde)
    list(position = pos.1se, lambda = lambda.seq[pos.1se],
         num.nzero = sum(abs(beta.refit.1se[-1]) > 1e-10),
         beta.refit.1se = beta.refit.1se,
         log.liklhd.1se = mean.log.liklhd[pos.1se])
}

#########################################################
### pLASSO in logistic regression by cross validation ###
#########################################################

### This is the main function for logistic regression with cross validation.
### It can compute each of the unweighted penalized estimators: LASSO, p, pLASSO.
### It can compute each of the weighted penalized estimators: LASSO-A, p-A, pLASSO-A.
### For details of the above terms, refer to the paper.
### Note: to increase computational efficiency, the algorithm performs two search steps:
        # Step 1: search from lambda.max to lambda.min over a grid of length.1 values
        # Step 2: if step 1 selects a lambda strictly between lambda.max and lambda.min, refine the search between its two neighboring grid values (length.2 values)
        #       : if step 1 selects lambda = lambda.min, search between lambda and lambda/2 (length.2 values), repeating while the smallest lambda of the grid is selected
        # If lambda.max <= lambda.min, step 1 is skipped and the halving search starts from lambda.min
        # The halving search stops as soon as the selected lambda is not the smallest in its grid or #nonzero of the estimator > n.threshold

### Inputs:
        # x: design matrix
        # y: response vector
        # eta.seq: the sequence of eta, the parameter tuning the prior information [(6) in the paper]
                # for LASSO/p/pLASSO, use eta sequence
                # for LASSO-A/p-A/pLASSO-A, use 0
        # tau.seq: the sequence of tau, the parameter tuning the L1 penalty weights [(35) in the paper]
                # for LASSO/p/pLASSO, use 0
                # for LASSO-A/p-A/pLASSO-A, use tau sequence
        # beta.prior: either the prior estimator or the initial estimator
                # for LASSO/p: use a vector with all 0
                # for pLASSO: use prior estimator
                # for LASSO-A: use LASSO estimator
                # for p-A: use prior estimator
                # for pLASSO-A: use pLASSO estimator
        # index: the indicators of which predictors are penalized
        # cv.group: the index of cross validation groups
        # lambda.min: the minimum lambda value of search
        # length.1: length of search for lambda in step 1
        # length.2: length of search for lambda in step 2
        # n.threshold: stops the search whenever #nonzero > n.threshold
        # is.refit: is the estimator refitted (yes/no: 1/0)?
### Output:
        # the optimal estimator with cross validation

find.beta <- function(x, y, eta.seq, tau.seq, beta.prior, index, cv.group, lambda.min, length.1, length.2, n.threshold, is.refit) {
    # For every (eta, tau) pair, choose lambda by cross-validation with
    # select.lambda(). Return the coefficient vector (intercept first) of the
    # pair with the highest CV log-likelihood.

    # Halving search: starting at `lambda`, search seq(lambda, lambda/2) and
    # keep moving down while the smallest lambda of the grid is selected and
    # #nonzero <= n.threshold. Returns the last select.lambda() result.
    search.down <- function(lambda, eta, tau) {
        repeat {
            lambda.seq <- seq(lambda, lambda/2, length.out = length.2)
            select <- select.lambda(lambda.seq, x, y, eta, tau, beta.prior, index, cv.group, is.refit)
            if (select$position != length.2 || select$num.nzero > n.threshold) {
                return(select)
            }
            lambda <- select$lambda
        }
    }

    n.coef <- ncol(x) + 1
    log.liklhd.eta.tau <- matrix(0, length(eta.seq), length(tau.seq))
    betas.eta.tau <- array(0, dim = c(n.coef, length(eta.seq), length(tau.seq)))
    y.prior <- logit.inv(cbind(1, x) %*% beta.prior)

    for (k in seq_along(eta.seq)) {
        eta <- eta.seq[k]
        y.tilde <- (y + eta * y.prior)/(1 + eta)

        for (l in seq_along(tau.seq)) {
            tau <- tau.seq[l]
            Winv <- 1 + tau * abs(beta.prior[-1])
            x.Winv.1 <- cbind(1, x * rep(Winv, each = nrow(x)))
            lambda.max <- lambdamax(x = x.Winv.1, y = y.tilde, index = index, standardize = FALSE)

            if (is.na(lambda.max)) {
                # lambdamax() failed: mark this (eta, tau) pair as unusable.
                log.liklhd.eta.tau[k, l] <- -Inf
                betas.eta.tau[, k, l] <- NA
                next
            }

            if (lambda.max <= lambda.min) {
                # Very small lambda.max: skip step 1 and search down from
                # lambda.min directly (added to accelerate the code).
                # NOTE(review): every lambda >= lambda.max yields the same fit,
                # so if lambda.min/2 >= lambda.max the first grid is flat and
                # the search stops at its first lambda -- confirm this is intended.
                select <- search.down(lambda.min, eta, tau)
            } else {
                # Step 1: coarse grid from lambda.max down to lambda.min.
                lambda.seq <- seq(lambda.max, lambda.min, length.out = length.1)
                select <- select.lambda(lambda.seq, x, y, eta, tau, beta.prior, index, cv.group, is.refit)
                position <- select$position
                if (position == length.1 && select$num.nzero <= n.threshold) {
                    # Step 2a: lambda.min was selected, so keep halving. The
                    # "<=" matches the stopping rule in search.down()
                    # (continue until #nonzero > n.threshold).
                    select <- search.down(select$lambda, eta, tau)
                } else if (position != 1 && position != length.1) {
                    # Step 2b: refine between the two neighbors of the selected lambda.
                    lambda.seq <- seq(lambda.seq[position - 1], lambda.seq[position + 1], length.out = length.2)
                    select <- select.lambda(lambda.seq, x, y, eta, tau, beta.prior, index, cv.group, is.refit)
                }
            }
            log.liklhd.eta.tau[k, l] <- select$log.liklhd.1se
            betas.eta.tau[, k, l] <- select$beta.refit.1se
        }
    }

    # Best (eta, tau) by CV log-likelihood, ignoring NA entries. On ties the
    # last match in column-major order (largest tau index, then largest eta
    # index) wins.
    best <- max(log.liklhd.eta.tau, na.rm = TRUE)
    pos.max.eta.tau <- which(log.liklhd.eta.tau == best, arr.ind = TRUE)
    pos.max.eta.tau <- pos.max.eta.tau[nrow(pos.max.eta.tau), ]

    betas.eta.tau[, pos.max.eta.tau[1], pos.max.eta.tau[2]]
}