blob: 2642a54361bda9f4e7b29746eb7597f67cf97ff3 [file] [log] [blame]
// Copyright (C) 2015 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.gerrit.httpd.auth.oauth;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.base.MoreObjects;
import com.google.common.base.Strings;
import com.google.common.collect.Iterables;
import com.google.gerrit.common.Nullable;
import com.google.gerrit.extensions.auth.oauth.OAuthServiceProvider;
import com.google.gerrit.extensions.registration.DynamicMap;
import com.google.gerrit.httpd.HtmlDomUtil;
import com.google.gerrit.httpd.LoginUrlToken;
import com.google.gerrit.httpd.template.SiteHeaderFooter;
import com.google.gerrit.server.config.CanonicalWebUrl;
import com.google.inject.Inject;
import com.google.inject.Provider;
import com.google.inject.Singleton;
import java.io.IOException;
import java.util.Map;
import java.util.Set;
import java.util.SortedMap;
import java.util.SortedSet;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
/**
 * Servlet filter that drives OAuth login through the {@link OAuthSession} obtained from the
 * injected provider on each request.
 *
 * <p>Requests whose URI contains {@link #GERRIT_LOGIN}, or that the session reports as the final
 * OAuth step, are handled here. The filter either renders a provider-selection page or hands the
 * request to {@code OAuthSession.login}. All other requests continue down the filter chain.
 */
@Singleton
class OAuthWebFilter implements Filter {
  /** URI fragment whose presence marks a request as a Gerrit login request. */
  static final String GERRIT_LOGIN = "/login";

  private final Provider<String> urlProvider;
  private final Provider<OAuthSession> oauthSessionProvider;
  private final DynamicMap<OAuthServiceProvider> oauthServiceProviders;
  private final SiteHeaderFooter header;

  /**
   * The sole installed provider when exactly one plugin exports exactly one provider. It is
   * assigned by {@code init}; otherwise it stays {@code null}.
   */
  private OAuthServiceProvider ssoProvider;

  @Inject
  OAuthWebFilter(
      @CanonicalWebUrl @Nullable Provider<String> urlProvider,
      DynamicMap<OAuthServiceProvider> oauthServiceProviders,
      Provider<OAuthSession> oauthSessionProvider,
      SiteHeaderFooter header) {
    this.header = header;
    this.urlProvider = urlProvider;
    this.oauthSessionProvider = oauthSessionProvider;
    this.oauthServiceProviders = oauthServiceProviders;
  }

  /**
   * Resolves the single-sign-on provider once, so requests can skip provider selection when only
   * one provider exists.
   *
   * @throws ServletException if no OAuth provider plugin is installed
   */
  @Override
  public void init(FilterConfig filterConfig) throws ServletException {
    pickSSOServiceProvider();
  }

  @Override
  public void destroy() {
    // Nothing to release.
  }

  /**
   * Sends login and final-OAuth-step requests into the OAuth flow and passes everything else on.
   *
   * <p>A non-null {@code link} parameter puts the session into link mode and clears its chosen
   * provider. When no provider is known and the request does not name one via {@code provider}, a
   * selection page is rendered instead of starting a login.
   *
   * @throws ServletException if {@code provider} names no installed OAuth service
   */
  @Override
  public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
      throws IOException, ServletException {
    HttpServletRequest req = (HttpServletRequest) request;
    HttpServletResponse res = (HttpServletResponse) response;
    OAuthSession session = oauthSessionProvider.get();

    if (request.getParameter("link") != null) {
      session.setLinkMode(true);
      session.setServiceProvider(null);
    }

    String requestedProvider = req.getParameter("provider");

    // The single-sign-on provider, when present, overrides whatever the session remembered.
    OAuthServiceProvider service = ssoProvider;
    if (service == null) {
      service = session.getServiceProvider();
    }

    if (!isGerritLogin(req) && !session.isOAuthFinal(req)) {
      chain.doFilter(req, response);
      return;
    }

    if (service == null) {
      if (Strings.isNullOrEmpty(requestedProvider)) {
        selectProvider(req, res, null);
        return;
      }
      service = findService(requestedProvider);
    }
    session.setServiceProvider(service);
    session.login(req, res, service);
  }

  /**
   * Looks up an installed provider by its {@code <plugin>_<export>} identifier.
   *
   * @param providerId identifier in the form used by the selection page's links
   * @return the matching provider instance
   * @throws ServletException if no installed provider has that identifier
   */
  private OAuthServiceProvider findService(String providerId) throws ServletException {
    for (String plugin : oauthServiceProviders.plugins()) {
      Map<String, Provider<OAuthServiceProvider>> byExport = oauthServiceProviders.byPlugin(plugin);
      for (Map.Entry<String, Provider<OAuthServiceProvider>> export : byExport.entrySet()) {
        String candidate = String.format("%s_%s", plugin, export.getKey());
        if (providerId.equals(candidate)) {
          return export.getValue().get();
        }
      }
    }
    throw new ServletException("No provider found for: " + providerId);
  }

  /**
   * Renders the provider-selection page (LoginForm.html) and sends it with status 401.
   *
   * @param errorMessage text for the page's error element; when null or empty, that element is
   *     removed from the page
   */
  private void selectProvider(
      HttpServletRequest req, HttpServletResponse res, @Nullable String errorMessage)
      throws IOException {
    String formAction = req.getRequestURI();
    String canonicalUrl = urlProvider == null ? "/" : urlProvider.get();
    String cancelHref = MoreObjects.firstNonNull(canonicalUrl, "/") + LoginUrlToken.getToken(req);

    Document page = header.parse(OAuthWebFilter.class, "LoginForm.html");
    HtmlDomUtil.find(page, "hostName").setTextContent(req.getServerName());
    HtmlDomUtil.find(page, "login_form").setAttribute("action", formAction);
    HtmlDomUtil.find(page, "cancel_link").setAttribute("href", cancelHref);

    Element errorElement = HtmlDomUtil.find(page, "error_message");
    if (!Strings.isNullOrEmpty(errorMessage)) {
      errorElement.setTextContent(errorMessage);
    } else {
      errorElement.getParentNode().removeChild(errorElement);
    }

    Element providerList = HtmlDomUtil.find(page, "providers");
    Set<String> installed = oauthServiceProviders.plugins();
    for (String plugin : installed) {
      for (Map.Entry<String, Provider<OAuthServiceProvider>> export :
          oauthServiceProviders.byPlugin(plugin).entrySet()) {
        String displayName = export.getValue().get().getName();
        addProvider(providerList, plugin, export.getKey(), displayName);
      }
    }
    sendHtml(res, page);
  }

  /**
   * Appends one provider entry to {@code form}. The entry is a div whose id is the export name,
   * wrapping a link that re-requests the page with {@code ?provider=<plugin>_<export>}.
   */
  private static void addProvider(Element form, String pluginName, String id, String serviceName) {
    Document doc = form.getOwnerDocument();
    Element entry = doc.createElement("div");
    entry.setAttribute("id", id);
    Element link = doc.createElement("a");
    link.setAttribute("href", String.format("?provider=%s_%s", pluginName, id));
    link.setTextContent(serviceName + " (" + pluginName + " plugin)");
    entry.appendChild(link);
    form.appendChild(entry);
  }

  /**
   * Serializes {@code doc} to UTF-8 and writes it as a text/html response with status 401. The
   * response output stream is closed afterwards.
   */
  private static void sendHtml(HttpServletResponse res, Document doc) throws IOException {
    byte[] body = HtmlDomUtil.toUTF8(doc);
    res.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
    res.setContentType("text/html");
    res.setCharacterEncoding(UTF_8.name());
    res.setContentLength(body.length);
    try (ServletOutputStream out = res.getOutputStream()) {
      out.write(body);
    }
  }

  /**
   * Stores the sole installed provider in {@code ssoProvider} when exactly one plugin exports
   * exactly one provider. Otherwise {@code ssoProvider} is left unchanged.
   *
   * @throws ServletException if no OAuth provider plugin is installed
   */
  private void pickSSOServiceProvider() throws ServletException {
    SortedSet<String> installed = oauthServiceProviders.plugins();
    if (installed.isEmpty()) {
      throw new ServletException("OAuth service provider wasn't installed");
    }
    if (installed.size() != 1) {
      return;
    }
    SortedMap<String, Provider<OAuthServiceProvider>> exports =
        oauthServiceProviders.byPlugin(Iterables.getOnlyElement(installed));
    if (exports.size() != 1) {
      return;
    }
    ssoProvider = Iterables.getOnlyElement(exports.values()).get();
  }

  /**
   * Returns whether the request URI contains {@link #GERRIT_LOGIN} anywhere, not only as a
   * prefix.
   */
  private static boolean isGerritLogin(HttpServletRequest request) {
    return request.getRequestURI().indexOf(GERRIT_LOGIN) >= 0;
  }
}