// Copyright (C) 2015 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.gerrit.httpd.auth.oauth;

import com.google.common.base.Objects;
import com.google.common.base.Strings;
import com.google.common.collect.Iterables;
import com.google.gerrit.common.Nullable;
import com.google.gerrit.extensions.auth.oauth.OAuthServiceProvider;
import com.google.gerrit.extensions.registration.DynamicMap;
import com.google.gerrit.httpd.HtmlDomUtil;
import com.google.gerrit.httpd.LoginUrlToken;
import com.google.gerrit.httpd.template.SiteHeaderFooter;
import com.google.gerrit.server.CurrentUser;
import com.google.gerrit.server.config.CanonicalWebUrl;
import com.google.inject.Inject;
import com.google.inject.Provider;
import com.google.inject.Singleton;

import org.w3c.dom.Document;
import org.w3c.dom.Element;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpSession;

/**
 * OAuth web filter that uses the active {@link OAuthSession} to perform OAuth
 * requests.
 *
 * <p>Requests from an already identified user are passed down the chain after
 * any existing HTTP session is invalidated. Otherwise, a request for the Gerrit
 * login URL (or one that {@link OAuthSession#isOAuthFinal} recognizes) made
 * while the OAuth session is not yet logged in starts the login flow. If
 * exactly one provider is installed (SSO), or the session already remembers a
 * provider, login starts directly. Otherwise the user is asked to pick one of
 * the installed providers.
 */
@Singleton
class OAuthWebFilter implements Filter {
  static final String GERRIT_LOGIN = "/login";

  /** Query parameter carrying the provider ID chosen on the selection page. */
  private static final String PROVIDER_PARAM = "provider";

  private final Provider<String> urlProvider;
  private final Provider<CurrentUser> currentUserProvider;
  private final Provider<OAuthSession> oauthSessionProvider;
  private final DynamicMap<OAuthServiceProvider> oauthServiceProviders;
  private final SiteHeaderFooter header;
  /** The only installed provider, or null when several (or none) exist. */
  private OAuthServiceProvider ssoProvider;

  @Inject
  OAuthWebFilter(@CanonicalWebUrl @Nullable Provider<String> urlProvider,
      Provider<CurrentUser> currentUserProvider,
      DynamicMap<OAuthServiceProvider> oauthServiceProviders,
      Provider<OAuthSession> oauthSessionProvider,
      SiteHeaderFooter header) {
    this.urlProvider = urlProvider;
    this.currentUserProvider = currentUserProvider;
    this.oauthServiceProviders = oauthServiceProviders;
    this.oauthSessionProvider = oauthSessionProvider;
    this.header = header;
  }

  /**
   * Selects the single-sign-on provider, if there is exactly one.
   *
   * @throws ServletException if no OAuth service provider is installed.
   */
  @Override
  public void init(FilterConfig filterConfig) throws ServletException {
    pickSSOServiceProvider();
  }

  @Override
  public void destroy() {
    // Nothing to release.
  }

  @Override
  public void doFilter(ServletRequest request, ServletResponse response,
      FilterChain chain) throws IOException, ServletException {
    HttpServletRequest httpRequest = (HttpServletRequest) request;
    HttpSession httpSession = httpRequest.getSession(false);
    if (currentUserProvider.get().isIdentifiedUser()) {
      if (httpSession != null) {
        httpSession.invalidate();
      }
      chain.doFilter(request, response);
      return;
    }

    HttpServletResponse httpResponse = (HttpServletResponse) response;

    String provider = httpRequest.getParameter(PROVIDER_PARAM);
    OAuthSession oauthSession = oauthSessionProvider.get();
    OAuthServiceProvider service = ssoProvider == null
        ? oauthSession.getServiceProvider()
        : ssoProvider;

    boolean loginRequest = isGerritLogin(httpRequest)
        || oauthSession.isOAuthFinal(httpRequest);
    if (!loginRequest || oauthSession.isLoggedIn()) {
      chain.doFilter(httpRequest, response);
      return;
    }

    if (service == null) {
      if (Strings.isNullOrEmpty(provider)) {
        selectProvider(httpRequest, httpResponse, null);
        return;
      }
      service = findService(provider);
      if (service == null) {
        // The provider ID comes from the query string, so an unknown value is
        // a user-facing condition: show the choices again instead of a 500.
        selectProvider(httpRequest, httpResponse,
            "No provider found for: " + provider);
        return;
      }
    }
    oauthSession.setServiceProvider(service);
    oauthSession.login(httpRequest, httpResponse, service);
  }

  /**
   * Looks up the installed provider whose ID, as built by
   * {@link #toProviderId}, equals {@code providerId}.
   *
   * @param providerId ID taken from the request; must not be null.
   * @return the matching provider, or {@code null} if none matches.
   */
  private OAuthServiceProvider findService(String providerId) {
    Set<String> plugins = oauthServiceProviders.plugins();
    for (String pluginName : plugins) {
      Map<String, Provider<OAuthServiceProvider>> exports =
          oauthServiceProviders.byPlugin(pluginName);
      for (Map.Entry<String, Provider<OAuthServiceProvider>> e
          : exports.entrySet()) {
        if (providerId.equals(toProviderId(pluginName, e.getKey()))) {
          return e.getValue().get();
        }
      }
    }
    return null;
  }

  /**
   * Builds the provider ID used both in the selection page links and by
   * {@link #findService}: {@code <pluginName>_<exportName>}.
   */
  private static String toProviderId(String pluginName, String exportName) {
    return String.format("%s_%s", pluginName, exportName);
  }

  /**
   * Renders the {@code LoginForm.html} provider selection page.
   *
   * <p>The form posts back to the current request URI. The cancel link points
   * to the canonical web URL (or {@code "/"}) plus the login URL token. The
   * page then lists one link per exported provider of every plugin.
   *
   * @param errorMessage message to display; when null or empty the error
   *     element is removed from the page.
   */
  private void selectProvider(HttpServletRequest req, HttpServletResponse res,
      @Nullable String errorMessage)
      throws IOException {
    String self = req.getRequestURI();
    String cancel = Objects.firstNonNull(
        urlProvider != null ? urlProvider.get() : "/", "/");
    cancel += LoginUrlToken.getToken(req);

    Document doc = header.parse(OAuthWebFilter.class, "LoginForm.html");
    HtmlDomUtil.find(doc, "hostName").setTextContent(req.getServerName());
    HtmlDomUtil.find(doc, "login_form").setAttribute("action", self);
    HtmlDomUtil.find(doc, "cancel_link").setAttribute("href", cancel);

    Element emsg = HtmlDomUtil.find(doc, "error_message");
    if (Strings.isNullOrEmpty(errorMessage)) {
      emsg.getParentNode().removeChild(emsg);
    } else {
      // setTextContent stores plain text; the DOM serializer escapes it.
      emsg.setTextContent(errorMessage);
    }

    Element providers = HtmlDomUtil.find(doc, "providers");

    Set<String> plugins = oauthServiceProviders.plugins();
    for (String pluginName : plugins) {
      Map<String, Provider<OAuthServiceProvider>> exports =
          oauthServiceProviders.byPlugin(pluginName);
      for (Map.Entry<String, Provider<OAuthServiceProvider>> e
          : exports.entrySet()) {
        addProvider(providers, pluginName, e.getKey(),
            e.getValue().get().getName());
      }
    }

    sendHtml(res, doc);
  }

  /**
   * Appends a {@code <div id="id"><a href="?provider=...">...</a></div>} entry
   * for one provider to {@code form}.
   */
  private static void addProvider(Element form, String pluginName,
      String id, String serviceName) {
    Document owner = form.getOwnerDocument();
    Element div = owner.createElement("div");
    div.setAttribute("id", id);
    Element hyperlink = owner.createElement("a");
    hyperlink.setAttribute("href",
        "?" + PROVIDER_PARAM + "=" + toProviderId(pluginName, id));
    hyperlink.setTextContent(serviceName +
        " (" + pluginName + " plugin)");
    div.appendChild(hyperlink);
    form.appendChild(div);
  }

  /**
   * Writes {@code doc} as a UTF-8 encoded {@code text/html} response with
   * status 401 (SC_UNAUTHORIZED), closing the output stream afterwards.
   */
  private static void sendHtml(HttpServletResponse res, Document doc)
      throws IOException {
    byte[] bin = HtmlDomUtil.toUTF8(doc);
    res.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
    res.setContentType("text/html");
    res.setCharacterEncoding(StandardCharsets.UTF_8.name());
    res.setContentLength(bin.length);
    try (ServletOutputStream out = res.getOutputStream()) {
      out.write(bin);
    }
  }

  /**
   * Sets {@link #ssoProvider} when exactly one plugin exports exactly one
   * OAuth service provider; otherwise leaves it null so users must choose.
   *
   * @throws ServletException if no OAuth service provider is installed.
   */
  private void pickSSOServiceProvider()
      throws ServletException {
    SortedSet<String> plugins = oauthServiceProviders.plugins();
    if (plugins.isEmpty()) {
      throw new ServletException(
          "OAuth service provider wasn't installed");
    }
    if (plugins.size() == 1) {
      SortedMap<String, Provider<OAuthServiceProvider>> services =
          oauthServiceProviders.byPlugin(Iterables.getOnlyElement(plugins));
      if (services.size() == 1) {
        ssoProvider = Iterables.getOnlyElement(services.values()).get();
      }
    }
  }

  /**
   * Returns true if the request URI contains {@link #GERRIT_LOGIN} anywhere
   * (a substring match, so any context path prefix is accepted).
   */
  private static boolean isGerritLogin(HttpServletRequest request) {
    return request.getRequestURI().contains(GERRIT_LOGIN);
  }
}
