// Copyright (C) 2009 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.gerrit.httpd.auth.openid;

import com.google.common.base.Objects;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.gerrit.common.Nullable;
import com.google.gerrit.common.PageLinks;
import com.google.gerrit.common.auth.openid.OpenIdUrls;
import com.google.gerrit.extensions.restapi.Url;
import com.google.gerrit.httpd.HtmlDomUtil;
import com.google.gerrit.httpd.template.SiteHeaderFooter;
import com.google.gerrit.reviewdb.client.AuthType;
import com.google.gerrit.server.config.AuthConfig;
import com.google.gerrit.server.config.CanonicalWebUrl;
import com.google.gerrit.server.config.GerritServerConfig;
import com.google.inject.Inject;
import com.google.inject.Provider;
import com.google.inject.Singleton;

import org.eclipse.jgit.lib.Config;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;

import java.io.IOException;
import java.util.Map;
import java.util.Set;

import javax.servlet.ServletOutputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

/**
 * Handles OpenID based login flow.
 *
 * <p>Renders the {@code LoginForm.html} sign-in page, checks the OpenID
 * identifier submitted by the user against site policy, and hands the
 * identifier to {@link OpenIdServiceImpl} for discovery. When the site uses
 * {@link AuthType#OPENID_SSO}, the form is skipped for GET requests and
 * discovery is started directly against the configured SSO provider.
 */
@SuppressWarnings("serial")
@Singleton
class LoginForm extends HttpServlet {
  private static final Logger log = LoggerFactory.getLogger(LoginForm.class);

  /** Well-known providers that may be offered as shortcuts, keyed by short name. */
  private static final ImmutableMap<String, String> ALL_PROVIDERS = ImmutableMap.of(
      "google", OpenIdUrls.URL_GOOGLE,
      "yahoo", OpenIdUrls.URL_YAHOO);

  /** Short names from {@link #ALL_PROVIDERS} permitted by site policy; empty in SSO mode. */
  private final ImmutableSet<String> suggestProviders;
  private final Provider<String> urlProvider;
  private final OpenIdServiceImpl impl;
  /**
   * Provider URLs up to this length are sent as a plain HTTP redirect; longer
   * ones are delivered through the {@code RedirectForm.html} page instead.
   */
  private final int maxRedirectUrlLength;
  /** The configured SSO provider URL, or {@code null} when not in OPENID_SSO mode. */
  private final String ssoUrl;
  private final SiteHeaderFooter header;

  @Inject
  LoginForm(
      @CanonicalWebUrl @Nullable Provider<String> urlProvider,
      @GerritServerConfig Config config,
      AuthConfig authConfig,
      OpenIdServiceImpl impl,
      SiteHeaderFooter header) {
    this.urlProvider = urlProvider;
    this.impl = impl;
    this.header = header;
    // openid.maxRedirectUrlLength in gerrit.config; defaults to 10.
    this.maxRedirectUrlLength = config.getInt(
        "openid", "maxRedirectUrlLength",
        10);

    if (urlProvider == null || Strings.isNullOrEmpty(urlProvider.get())) {
      log.error("gerrit.canonicalWebUrl must be set in gerrit.config");
    }

    if (authConfig.getAuthType() == AuthType.OPENID_SSO) {
      // A single mandatory provider: no shortcuts are offered on the form.
      suggestProviders = ImmutableSet.of();
      ssoUrl = authConfig.getOpenIdSsoUrl();
    } else {
      // Only suggest the well-known providers that site policy allows.
      Set<String> providers = Sets.newHashSet();
      for (Map.Entry<String, String> e : ALL_PROVIDERS.entrySet()) {
        if (impl.isAllowedOpenID(e.getValue())) {
          providers.add(e.getKey());
        }
      }
      suggestProviders = ImmutableSet.copyOf(providers);
      ssoUrl = null;
    }
  }

  /**
   * In SSO mode, starts discovery against the configured SSO provider.
   * Otherwise a request carrying a non-blank {@code id} parameter is handled
   * as a form submission, and any other request displays the form.
   */
  @Override
  protected void doGet(HttpServletRequest req, HttpServletResponse res)
      throws IOException {
    if (ssoUrl != null) {
      String token = getToken(req);
      SignInMode mode;
      if (PageLinks.REGISTER.equals(token)) {
        mode = SignInMode.REGISTER;
        token = PageLinks.MINE;
      } else {
        mode = SignInMode.SIGN_IN;
      }
      discover(req, res, false, ssoUrl, false, token, mode);
    } else {
      String id = Strings.nullToEmpty(req.getParameter("id")).trim();
      if (!id.isEmpty()) {
        doPost(req, res);
      } else {
        boolean link = req.getParameter("link") != null;
        sendForm(req, res, link, null);
      }
    }
  }

  /**
   * Validates the submitted OpenID identifier and starts discovery.
   *
   * <p>An identifier without an {@code http://} or {@code https://} scheme is
   * prefixed with {@code http://}. Identifiers rejected by site policy (or not
   * equal to the SSO URL in SSO mode) re-display the form with an error.
   */
  @Override
  protected void doPost(HttpServletRequest req, HttpServletResponse res)
      throws IOException {
    boolean link = req.getParameter("link") != null;
    String id = Strings.nullToEmpty(req.getParameter("id")).trim();
    if (id.isEmpty()) {
      sendForm(req, res, link, null);
      return;
    }
    if (!id.startsWith("http://") && !id.startsWith("https://")) {
      id = "http://" + id;
    }
    if ((ssoUrl != null && !ssoUrl.equals(id)) || !impl.isAllowedOpenID(id)) {
      sendForm(req, res, link, "OpenID provider not permitted by site policy.");
      return;
    }

    boolean remember = "1".equals(req.getParameter("rememberme"));
    String token = getToken(req);
    SignInMode mode;
    if (link) {
      // Constant name spelled as declared on SignInMode (sic).
      mode = SignInMode.LINK_IDENTIY;
    } else if (PageLinks.REGISTER.equals(token)) {
      mode = SignInMode.REGISTER;
      token = PageLinks.MINE;
    } else {
      mode = SignInMode.SIGN_IN;
    }

    discover(req, res, link, id, remember, token, mode);
  }

  /**
   * Runs OpenID discovery for {@code id} and either redirects the browser to
   * the provider or re-displays the form with an error message.
   *
   * @param link whether the form, if re-displayed, keeps its "link" input
   * @param remember whether the "remember me" option was chosen; always
   *     forced to {@code false} in SSO mode
   * @param token page token to return to after sign-in
   */
  private void discover(HttpServletRequest req, HttpServletResponse res,
      boolean link, String id, boolean remember, String token, SignInMode mode)
      throws IOException {
    if (ssoUrl != null) {
      remember = false;
    }

    DiscoveryResult r = impl.discover(req, id, mode, remember, token);
    switch (r.status) {
      case VALID:
        redirect(r, res);
        break;

      case NO_PROVIDER:
        sendForm(req, res, link,
            "Provider is not supported, or was incorrectly entered.");
        break;

      case ERROR:
        sendForm(req, res, link, "Unable to connect with OpenID provider.");
        break;
    }
  }

  /**
   * Sends the browser to the OpenID provider.
   *
   * <p>If the provider URL plus URL-encoded arguments fits within
   * {@link #maxRedirectUrlLength}, an HTTP redirect is issued. Otherwise the
   * arguments are emitted as hidden inputs of the {@code redirect_form}
   * element in {@code RedirectForm.html}, whose action is the provider URL.
   */
  private void redirect(DiscoveryResult r, HttpServletResponse res)
      throws IOException {
    StringBuilder url = new StringBuilder();
    url.append(r.providerUrl);
    if (r.providerArgs != null && !r.providerArgs.isEmpty()) {
      // The provider endpoint may already carry a query string; extend it
      // instead of starting a second one with another '?'.
      String sep = querySeparator(r.providerUrl);
      for (Map.Entry<String, String> arg : r.providerArgs.entrySet()) {
        url.append(sep)
           .append(Url.encode(arg.getKey()))
           .append('=')
           .append(Url.encode(arg.getValue()));
        sep = "&";
      }
    }
    if (url.length() <= maxRedirectUrlLength) {
      res.sendRedirect(url.toString());
      return;
    }

    Document doc = HtmlDomUtil.parseFile(LoginForm.class, "RedirectForm.html");
    Element form = HtmlDomUtil.find(doc, "redirect_form");
    form.setAttribute("action", r.providerUrl);
    if (r.providerArgs != null && !r.providerArgs.isEmpty()) {
      for (Map.Entry<String, String> arg : r.providerArgs.entrySet()) {
        Element in = doc.createElement("input");
        in.setAttribute("type", "hidden");
        in.setAttribute("name", arg.getKey());
        in.setAttribute("value", arg.getValue());
        form.appendChild(in);
      }
    }
    sendHtml(res, doc);
  }

  /**
   * Returns the separator to place before the first argument appended to
   * {@code baseUrl}: {@code "?"} when it has no query string yet, {@code ""}
   * when it already ends in {@code '?'} or {@code '&'}, and {@code "&"}
   * otherwise.
   */
  private static String querySeparator(@Nullable String baseUrl) {
    if (baseUrl == null || baseUrl.indexOf('?') < 0) {
      return "?";
    }
    if (baseUrl.endsWith("?") || baseUrl.endsWith("&")) {
      return "";
    }
    return "&";
  }

  /**
   * Returns the page token from the request path info, normalized to start
   * with {@code '/'}; defaults to {@link PageLinks#MINE} when absent.
   */
  private static String getToken(HttpServletRequest req) {
    String token = req.getPathInfo();
    if (token == null || token.isEmpty()) {
      token = PageLinks.MINE;
    } else if (!token.startsWith("/")) {
      token = "/" + token;
    }
    return token;
  }

  /**
   * Renders {@code LoginForm.html} with site header/footer.
   *
   * @param link keep the form's "link" input and propagate {@code &link} to
   *     the provider shortcut links; the input is always removed in SSO mode
   * @param errorMessage message to show, or {@code null}/empty to hide the
   *     error element
   */
  private void sendForm(HttpServletRequest req, HttpServletResponse res,
      boolean link, @Nullable String errorMessage) throws IOException {
    String self = req.getRequestURI();
    String canonical = urlProvider != null ? urlProvider.get() : null;
    // Fall back to the site root when canonicalWebUrl is unset OR empty: an
    // empty href would resolve back to this login page instead of cancelling.
    String cancel = Objects.firstNonNull(Strings.emptyToNull(canonical), "/");
    String token = getToken(req);
    if (!token.equals("/")) {
      cancel += "#" + token;
    }

    Document doc = header.parse(LoginForm.class, "LoginForm.html");
    HtmlDomUtil.find(doc, "hostName").setTextContent(req.getServerName());
    HtmlDomUtil.find(doc, "login_form").setAttribute("action", self);
    HtmlDomUtil.find(doc, "cancel_link").setAttribute("href", cancel);

    if (!link || ssoUrl != null) {
      Element input = HtmlDomUtil.find(doc, "f_link");
      input.getParentNode().removeChild(input);
    }

    // Pre-fill the identifier field from the last-used-id cookie, if any.
    String last = getLastId(req);
    if (last != null) {
      HtmlDomUtil.find(doc, "f_openid").setAttribute("value", last);
    }

    Element emsg = HtmlDomUtil.find(doc, "error_message");
    if (Strings.isNullOrEmpty(errorMessage)) {
      emsg.getParentNode().removeChild(emsg);
    } else {
      emsg.setTextContent(errorMessage);
    }

    // Keep only the shortcut links for providers permitted by site policy;
    // their template href is appended to this servlet's own URI.
    for (String name : ALL_PROVIDERS.keySet()) {
      Element div = HtmlDomUtil.find(doc, "provider_" + name);
      if (div == null) {
        continue;
      }
      if (!suggestProviders.contains(name)) {
        div.getParentNode().removeChild(div);
        continue;
      }
      Element a = HtmlDomUtil.find(div, "id_" + name);
      if (a == null) {
        div.getParentNode().removeChild(div);
        continue;
      }
      StringBuilder u = new StringBuilder();
      u.append(self).append(a.getAttribute("href"));
      if (link) {
        u.append("&link");
      }
      a.setAttribute("href", u.toString());
    }
    sendHtml(res, doc);
  }

  /**
   * Serializes {@code doc} as UTF-8 HTML and writes it with status 401
   * (Unauthorized), closing the output stream afterwards.
   */
  private void sendHtml(HttpServletResponse res, Document doc)
      throws IOException {
    byte[] bin = HtmlDomUtil.toUTF8(doc);
    res.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
    res.setContentType("text/html");
    res.setCharacterEncoding("UTF-8");
    res.setContentLength(bin.length);
    ServletOutputStream out = res.getOutputStream();
    try {
      out.write(bin);
    } finally {
      out.close();
    }
  }

  /** Returns the value of the last-used OpenID cookie, or {@code null} if absent. */
  private static String getLastId(HttpServletRequest req) {
    Cookie[] cookies = req.getCookies();
    if (cookies != null) {
      for (Cookie c : cookies) {
        if (OpenIdUrls.LASTID_COOKIE.equals(c.getName())) {
          return c.getValue();
        }
      }
    }
    return null;
  }
}
