package it.inaf.ia2.transfer.auth;

import it.inaf.ia2.aa.jwt.InvalidTokenException;
import it.inaf.ia2.aa.jwt.TokenParser;
import java.io.IOException;
import java.security.Principal;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Servlet filter that resolves the caller identity from a bearer token and
 * exposes it to downstream components through
 * {@link HttpServletRequest#getUserPrincipal()}.
 * <p>
 * The token is read from the {@code Authorization: Bearer <token>} header or,
 * when that header does not carry a bearer token, from the {@code token}
 * request parameter. Requests without any token continue down the chain with
 * an anonymous principal; requests whose token fails validation with an
 * {@link InvalidTokenException} receive a 401 response and are not passed on.
 */
public class TokenFilter implements Filter {

    private static final Logger LOG = LoggerFactory.getLogger(TokenFilter.class);

    /** Authentication scheme name for bearer tokens (RFC 6750). */
    private static final String BEARER_SCHEME = "Bearer";

    private final TokenParser tokenParser;

    /**
     * @param tokenParser parser used to extract the claims of incoming tokens
     */
    public TokenFilter(TokenParser tokenParser) {
        this.tokenParser = tokenParser;
    }

    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {

        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;

        String token = getToken(request);

        TokenPrincipal tokenPrincipal;
        try {
            tokenPrincipal = getTokenPrincipal(token);
        } catch (InvalidTokenException ex) {
            // RFC 7235 requires a WWW-Authenticate challenge on every 401 response.
            response.setHeader("WWW-Authenticate", BEARER_SCHEME);
            response.sendError(HttpServletResponse.SC_UNAUTHORIZED, "Unauthorized: " + ex.getMessage());
            return;
        }

        RequestWithPrincipal requestWrapper = new RequestWithPrincipal(request, tokenPrincipal);

        chain.doFilter(requestWrapper, response);
    }

    /**
     * Returns the raw token of the request, preferring the Authorization header
     * over the {@code token} request parameter.
     *
     * @return the token, or {@code null} if neither source provides one
     */
    private String getToken(HttpServletRequest request) {
        String token = getTokenFromHeader(request);
        if (token == null) {
            // Fall back to the "token" request parameter.
            // NOTE(review): getParameter also parses form-encoded POST bodies,
            // not only the query string — confirm this is intended.
            token = request.getParameter("token");
        }
        return token;
    }

    /**
     * Extracts the credentials of an {@code Authorization: Bearer ...} header.
     * The scheme name is matched case-insensitively (RFC 7235 section 2.1) and
     * must be followed by whitespace or the end of the header value.
     *
     * @return the trimmed token (empty if the header carries the scheme alone),
     *         or {@code null} if the header is absent or uses another scheme
     */
    private String getTokenFromHeader(HttpServletRequest request) {

        LOG.trace("Loading token from header");

        String authHeader = request.getHeader("Authorization");
        if (authHeader == null) {
            return null;
        }

        String value = authHeader.trim();
        int schemeLength = BEARER_SCHEME.length();
        // Require a separator after the scheme so that e.g. "BearerXYZ" is not
        // mistaken for a bearer header; regionMatches is false for short values.
        boolean isBearer = value.regionMatches(true, 0, BEARER_SCHEME, 0, schemeLength)
                && (value.length() == schemeLength || Character.isWhitespace(value.charAt(schemeLength)));
        if (!isBearer) {
            LOG.warn("Detected non-Bearer Authorization header");
            return null;
        }

        // substring(schemeLength) is always in range, even for a bare "Bearer".
        return value.substring(schemeLength).trim();
    }

    /**
     * Builds the principal associated with the given token.
     *
     * @param token raw token, or {@code null} for an anonymous request
     * @return an anonymous principal when {@code token} is {@code null},
     *         otherwise a principal for the token's {@code sub} claim
     * @throws InvalidTokenException if the token is empty or its {@code sub}
     *         claim is missing or not a string (presumably also thrown by
     *         {@code TokenParser#getClaims} for tokens it rejects — verify)
     */
    private TokenPrincipal getTokenPrincipal(String token) {
        if (token == null) {
            // anonymous
            return new TokenPrincipal();
        }
        if (token.isEmpty()) {
            // e.g. "Authorization: Bearer" without credentials, or "?token="
            throw new InvalidTokenException("Empty token");
        }

        Map<String, Object> claims = tokenParser.getClaims(token);
        Object sub = claims.get("sub");
        if (sub == null) {
            throw new InvalidTokenException("Missing sub claim");
        }
        if (!(sub instanceof String)) {
            // Reject instead of letting a blind cast fail with ClassCastException (HTTP 500).
            throw new InvalidTokenException("Invalid sub claim");
        }

        return new TokenPrincipal((String) sub, token);
    }

    /**
     * Request wrapper whose {@link #getUserPrincipal()} returns the principal
     * resolved by this filter; every other call is delegated to the wrapped
     * request.
     */
    private static final class RequestWithPrincipal extends HttpServletRequestWrapper {

        private final TokenPrincipal tokenPrincipal;

        public RequestWithPrincipal(HttpServletRequest request, TokenPrincipal tokenPrincipal) {
            super(request);
            this.tokenPrincipal = tokenPrincipal;
        }

        @Override
        public Principal getUserPrincipal() {
            return tokenPrincipal;
        }
    }
}
