blob: 6202cfc65d5d7f7f714be44ae65fe0036bd71611 [file] [log] [blame]
// Copyright (C) 2009 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.gerrit.httpd.auth.openid;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.MoreObjects;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.gerrit.common.Nullable;
import com.google.gerrit.common.PageLinks;
import com.google.gerrit.common.auth.openid.OpenIdUrls;
import com.google.gerrit.extensions.auth.oauth.OAuthServiceProvider;
import com.google.gerrit.extensions.client.AuthType;
import com.google.gerrit.extensions.registration.DynamicMap;
import com.google.gerrit.extensions.restapi.Url;
import com.google.gerrit.httpd.HtmlDomUtil;
import com.google.gerrit.httpd.LoginUrlToken;
import com.google.gerrit.httpd.template.SiteHeaderFooter;
import com.google.gerrit.server.CurrentUser;
import com.google.gerrit.server.config.AuthConfig;
import com.google.gerrit.server.config.CanonicalWebUrl;
import com.google.gerrit.server.config.GerritServerConfig;
import com.google.inject.Inject;
import com.google.inject.Provider;
import com.google.inject.Singleton;
import java.io.IOException;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.servlet.ServletOutputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jgit.lib.Config;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
/** Handles OpenID based login flow. */
@SuppressWarnings("serial")
@Singleton
class LoginForm extends HttpServlet {
private static final Logger log = LoggerFactory.getLogger(LoginForm.class);
private static final ImmutableMap<String, String> ALL_PROVIDERS =
ImmutableMap.of(
"launchpad", OpenIdUrls.URL_LAUNCHPAD,
"yahoo", OpenIdUrls.URL_YAHOO);
private final ImmutableSet<String> suggestProviders;
private final Provider<String> urlProvider;
private final Provider<OAuthSessionOverOpenID> oauthSessionProvider;
private final OpenIdServiceImpl impl;
private final int maxRedirectUrlLength;
private final String ssoUrl;
private final SiteHeaderFooter header;
private final Provider<CurrentUser> currentUserProvider;
private final DynamicMap<OAuthServiceProvider> oauthServiceProviders;
@Inject
LoginForm(
@CanonicalWebUrl @Nullable Provider<String> urlProvider,
@GerritServerConfig Config config,
AuthConfig authConfig,
OpenIdServiceImpl impl,
SiteHeaderFooter header,
Provider<OAuthSessionOverOpenID> oauthSessionProvider,
Provider<CurrentUser> currentUserProvider,
DynamicMap<OAuthServiceProvider> oauthServiceProviders) {
this.urlProvider = urlProvider;
this.impl = impl;
this.header = header;
this.maxRedirectUrlLength = config.getInt("openid", "maxRedirectUrlLength", 10);
this.oauthSessionProvider = oauthSessionProvider;
this.currentUserProvider = currentUserProvider;
this.oauthServiceProviders = oauthServiceProviders;
if (urlProvider == null || Strings.isNullOrEmpty(urlProvider.get())) {
log.error("gerrit.canonicalWebUrl must be set in gerrit.config");
}
if (authConfig.getAuthType() == AuthType.OPENID_SSO) {
suggestProviders = ImmutableSet.of();
ssoUrl = authConfig.getOpenIdSsoUrl();
} else {
Set<String> providers = new HashSet<>();
for (Map.Entry<String, String> e : ALL_PROVIDERS.entrySet()) {
if (impl.isAllowedOpenID(e.getValue())) {
providers.add(e.getKey());
}
}
suggestProviders = ImmutableSet.copyOf(providers);
ssoUrl = null;
}
}
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse res) throws IOException {
if (ssoUrl != null) {
String token = LoginUrlToken.getToken(req);
SignInMode mode;
if (PageLinks.REGISTER.equals(token)) {
mode = SignInMode.REGISTER;
token = PageLinks.MINE;
} else {
mode = SignInMode.SIGN_IN;
}
discover(req, res, false, ssoUrl, false, token, mode);
} else {
String id = Strings.nullToEmpty(req.getParameter("id")).trim();
if (!id.isEmpty()) {
doPost(req, res);
} else {
boolean link = req.getParameter("link") != null;
sendForm(req, res, link, null);
}
}
}
@Override
protected void doPost(HttpServletRequest req, HttpServletResponse res) throws IOException {
boolean link = req.getParameter("link") != null;
String id = Strings.nullToEmpty(req.getParameter("id")).trim();
if (id.isEmpty()) {
sendForm(req, res, link, null);
return;
}
if (!id.startsWith("http://") && !id.startsWith("https://")) {
id = "http://" + id;
}
if ((ssoUrl != null && !ssoUrl.equals(id)) || !impl.isAllowedOpenID(id)) {
sendForm(req, res, link, "OpenID provider not permitted by site policy.");
return;
}
boolean remember = "1".equals(req.getParameter("rememberme"));
String token = LoginUrlToken.getToken(req);
SignInMode mode;
if (link) {
mode = SignInMode.LINK_IDENTIY;
} else if (PageLinks.REGISTER.equals(token)) {
mode = SignInMode.REGISTER;
token = PageLinks.MINE;
} else {
mode = SignInMode.SIGN_IN;
}
log.debug("mode \"{}\"", mode);
OAuthServiceProvider oauthProvider = lookupOAuthServiceProvider(id);
if (oauthProvider == null) {
log.debug("OpenId provider \"{}\"", id);
discover(req, res, link, id, remember, token, mode);
} else {
log.debug("OAuth provider \"{}\"", id);
OAuthSessionOverOpenID oauthSession = oauthSessionProvider.get();
if (!currentUserProvider.get().isIdentifiedUser() && oauthSession.isLoggedIn()) {
oauthSession.logout();
}
if ((isGerritLogin(req) || oauthSession.isOAuthFinal(req))) {
oauthSession.setServiceProvider(oauthProvider);
oauthSession.setLinkMode(link);
oauthSession.login(req, res, oauthProvider);
}
}
}
private void discover(
HttpServletRequest req,
HttpServletResponse res,
boolean link,
String id,
boolean remember,
String token,
SignInMode mode)
throws IOException {
if (ssoUrl != null) {
remember = false;
}
DiscoveryResult r = impl.discover(req, id, mode, remember, token);
switch (r.status) {
case VALID:
redirect(r, res);
break;
case NO_PROVIDER:
sendForm(req, res, link, "Provider is not supported, or was incorrectly entered.");
break;
case ERROR:
sendForm(req, res, link, "Unable to connect with OpenID provider.");
break;
}
}
private void redirect(DiscoveryResult r, HttpServletResponse res) throws IOException {
StringBuilder url = new StringBuilder();
url.append(r.providerUrl);
if (r.providerArgs != null && !r.providerArgs.isEmpty()) {
boolean first = true;
for (Map.Entry<String, String> arg : r.providerArgs.entrySet()) {
if (first) {
url.append('?');
first = false;
} else {
url.append('&');
}
url.append(Url.encode(arg.getKey())).append('=').append(Url.encode(arg.getValue()));
}
}
if (url.length() <= maxRedirectUrlLength) {
res.sendRedirect(url.toString());
return;
}
Document doc = HtmlDomUtil.parseFile(LoginForm.class, "RedirectForm.html");
Element form = HtmlDomUtil.find(doc, "redirect_form");
form.setAttribute("action", r.providerUrl);
if (r.providerArgs != null && !r.providerArgs.isEmpty()) {
for (Map.Entry<String, String> arg : r.providerArgs.entrySet()) {
Element in = doc.createElement("input");
in.setAttribute("type", "hidden");
in.setAttribute("name", arg.getKey());
in.setAttribute("value", arg.getValue());
form.appendChild(in);
}
}
sendHtml(res, doc);
}
private void sendForm(
HttpServletRequest req, HttpServletResponse res, boolean link, @Nullable String errorMessage)
throws IOException {
String self = req.getRequestURI();
String cancel = MoreObjects.firstNonNull(urlProvider != null ? urlProvider.get() : "/", "/");
cancel += LoginUrlToken.getToken(req);
Document doc = header.parse(LoginForm.class, "LoginForm.html");
HtmlDomUtil.find(doc, "hostName").setTextContent(req.getServerName());
HtmlDomUtil.find(doc, "login_form").setAttribute("action", self);
HtmlDomUtil.find(doc, "cancel_link").setAttribute("href", cancel);
if (!link || ssoUrl != null) {
Element input = HtmlDomUtil.find(doc, "f_link");
input.getParentNode().removeChild(input);
}
String last = getLastId(req);
if (last != null) {
HtmlDomUtil.find(doc, "f_openid").setAttribute("value", last);
}
Element emsg = HtmlDomUtil.find(doc, "error_message");
if (Strings.isNullOrEmpty(errorMessage)) {
emsg.getParentNode().removeChild(emsg);
} else {
emsg.setTextContent(errorMessage);
}
for (String name : ALL_PROVIDERS.keySet()) {
Element div = HtmlDomUtil.find(doc, "provider_" + name);
if (div == null) {
continue;
}
if (!suggestProviders.contains(name)) {
div.getParentNode().removeChild(div);
continue;
}
Element a = HtmlDomUtil.find(div, "id_" + name);
if (a == null) {
div.getParentNode().removeChild(div);
continue;
}
StringBuilder u = new StringBuilder();
u.append(self).append(a.getAttribute("href"));
if (link) {
u.append("&link");
}
a.setAttribute("href", u.toString());
}
// OAuth: Add plugin based providers
Element providers = HtmlDomUtil.find(doc, "providers");
Set<String> plugins = oauthServiceProviders.plugins();
for (String pluginName : plugins) {
Map<String, Provider<OAuthServiceProvider>> m = oauthServiceProviders.byPlugin(pluginName);
for (Map.Entry<String, Provider<OAuthServiceProvider>> e : m.entrySet()) {
addProvider(providers, link, pluginName, e.getKey(), e.getValue().get().getName());
}
}
sendHtml(res, doc);
}
private void sendHtml(HttpServletResponse res, Document doc) throws IOException {
byte[] bin = HtmlDomUtil.toUTF8(doc);
res.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
res.setContentType("text/html");
res.setCharacterEncoding(UTF_8.name());
res.setContentLength(bin.length);
try (ServletOutputStream out = res.getOutputStream()) {
out.write(bin);
}
}
private static void addProvider(
Element form, boolean link, String pluginName, String id, String serviceName) {
Element div = form.getOwnerDocument().createElement("div");
div.setAttribute("id", id);
Element hyperlink = form.getOwnerDocument().createElement("a");
StringBuilder u = new StringBuilder(String.format("?id=%s_%s", pluginName, id));
if (link) {
u.append("&link");
}
hyperlink.setAttribute("href", u.toString());
hyperlink.setTextContent(serviceName + " (" + pluginName + " plugin)");
div.appendChild(hyperlink);
form.appendChild(div);
}
private OAuthServiceProvider lookupOAuthServiceProvider(String providerId) {
if (providerId.startsWith("http://")) {
providerId = providerId.substring("http://".length());
}
Set<String> plugins = oauthServiceProviders.plugins();
for (String pluginName : plugins) {
Map<String, Provider<OAuthServiceProvider>> m = oauthServiceProviders.byPlugin(pluginName);
for (Map.Entry<String, Provider<OAuthServiceProvider>> e : m.entrySet()) {
if (providerId.equals(String.format("%s_%s", pluginName, e.getKey()))) {
return e.getValue().get();
}
}
}
return null;
}
private static String getLastId(HttpServletRequest req) {
Cookie[] cookies = req.getCookies();
if (cookies != null) {
for (Cookie c : cookies) {
if (OpenIdUrls.LASTID_COOKIE.equals(c.getName())) {
return c.getValue();
}
}
}
return null;
}
private static boolean isGerritLogin(HttpServletRequest request) {
return request.getRequestURI().indexOf(OAuthSessionOverOpenID.GERRIT_LOGIN) >= 0;
}
}