| // Copyright (C) 2015 The Android Open Source Project |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package com.google.gerrit.httpd.auth.oauth; |
| |
| import com.google.common.base.MoreObjects; |
| import com.google.common.base.Strings; |
| import com.google.common.collect.Iterables; |
| import com.google.gerrit.common.Nullable; |
| import com.google.gerrit.extensions.auth.oauth.OAuthServiceProvider; |
| import com.google.gerrit.extensions.registration.DynamicMap; |
| import com.google.gerrit.httpd.HtmlDomUtil; |
| import com.google.gerrit.httpd.LoginUrlToken; |
| import com.google.gerrit.httpd.template.SiteHeaderFooter; |
| import com.google.gerrit.server.config.CanonicalWebUrl; |
| import com.google.inject.Inject; |
| import com.google.inject.Provider; |
| import com.google.inject.Singleton; |
| |
| import org.w3c.dom.Document; |
| import org.w3c.dom.Element; |
| |
| import java.io.IOException; |
| import java.nio.charset.StandardCharsets; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.SortedMap; |
| import java.util.SortedSet; |
| |
| import javax.servlet.Filter; |
| import javax.servlet.FilterChain; |
| import javax.servlet.FilterConfig; |
| import javax.servlet.ServletException; |
| import javax.servlet.ServletOutputStream; |
| import javax.servlet.ServletRequest; |
| import javax.servlet.ServletResponse; |
| import javax.servlet.http.HttpServletRequest; |
| import javax.servlet.http.HttpServletResponse; |
| |
| @Singleton |
| /* OAuth web filter uses active OAuth session to perform OAuth requests */ |
| class OAuthWebFilter implements Filter { |
| static final String GERRIT_LOGIN = "/login"; |
| |
| private final Provider<String> urlProvider; |
| private final Provider<OAuthSession> oauthSessionProvider; |
| private final DynamicMap<OAuthServiceProvider> oauthServiceProviders; |
| private final SiteHeaderFooter header; |
| private OAuthServiceProvider ssoProvider; |
| |
| @Inject |
| OAuthWebFilter(@CanonicalWebUrl @Nullable Provider<String> urlProvider, |
| DynamicMap<OAuthServiceProvider> oauthServiceProviders, |
| Provider<OAuthSession> oauthSessionProvider, |
| SiteHeaderFooter header) { |
| this.urlProvider = urlProvider; |
| this.oauthServiceProviders = oauthServiceProviders; |
| this.oauthSessionProvider = oauthSessionProvider; |
| this.header = header; |
| } |
| |
| @Override |
| public void init(FilterConfig filterConfig) throws ServletException { |
| pickSSOServiceProvider(); |
| } |
| |
| @Override |
| public void destroy() { |
| } |
| |
| @Override |
| public void doFilter(ServletRequest request, ServletResponse response, |
| FilterChain chain) throws IOException, ServletException { |
| HttpServletRequest httpRequest = (HttpServletRequest) request; |
| HttpServletResponse httpResponse = (HttpServletResponse) response; |
| |
| OAuthSession oauthSession = oauthSessionProvider.get(); |
| if (request.getParameter("link") != null) { |
| oauthSession.setLinkMode(true); |
| oauthSession.setServiceProvider(null); |
| } |
| |
| String provider = httpRequest.getParameter("provider"); |
| OAuthServiceProvider service = ssoProvider == null |
| ? oauthSession.getServiceProvider() |
| : ssoProvider; |
| |
| if (isGerritLogin(httpRequest) || oauthSession.isOAuthFinal(httpRequest)) { |
| if (service == null && Strings.isNullOrEmpty(provider)) { |
| selectProvider(httpRequest, httpResponse, null); |
| return; |
| } else { |
| if (service == null) { |
| service = findService(provider); |
| } |
| oauthSession.setServiceProvider(service); |
| oauthSession.login(httpRequest, httpResponse, service); |
| } |
| } else { |
| chain.doFilter(httpRequest, response); |
| } |
| } |
| |
| private OAuthServiceProvider findService(String providerId) |
| throws ServletException { |
| Set<String> plugins = oauthServiceProviders.plugins(); |
| for (String pluginName : plugins) { |
| Map<String, Provider<OAuthServiceProvider>> m = |
| oauthServiceProviders.byPlugin(pluginName); |
| for (Map.Entry<String, Provider<OAuthServiceProvider>> e |
| : m.entrySet()) { |
| if (providerId.equals( |
| String.format("%s_%s", pluginName, e.getKey()))) { |
| return e.getValue().get(); |
| } |
| } |
| } |
| throw new ServletException("No provider found for: " + providerId); |
| } |
| |
| private void selectProvider(HttpServletRequest req, HttpServletResponse res, |
| @Nullable String errorMessage) |
| throws IOException { |
| String self = req.getRequestURI(); |
| String cancel = MoreObjects.firstNonNull( |
| urlProvider != null ? urlProvider.get() : "/", "/"); |
| cancel += LoginUrlToken.getToken(req); |
| |
| Document doc = header.parse(OAuthWebFilter.class, "LoginForm.html"); |
| HtmlDomUtil.find(doc, "hostName").setTextContent(req.getServerName()); |
| HtmlDomUtil.find(doc, "login_form").setAttribute("action", self); |
| HtmlDomUtil.find(doc, "cancel_link").setAttribute("href", cancel); |
| |
| Element emsg = HtmlDomUtil.find(doc, "error_message"); |
| if (Strings.isNullOrEmpty(errorMessage)) { |
| emsg.getParentNode().removeChild(emsg); |
| } else { |
| emsg.setTextContent(errorMessage); |
| } |
| |
| Element providers = HtmlDomUtil.find(doc, "providers"); |
| |
| Set<String> plugins = oauthServiceProviders.plugins(); |
| for (String pluginName : plugins) { |
| Map<String, Provider<OAuthServiceProvider>> m = |
| oauthServiceProviders.byPlugin(pluginName); |
| for (Map.Entry<String, Provider<OAuthServiceProvider>> e |
| : m.entrySet()) { |
| addProvider(providers, pluginName, e.getKey(), |
| e.getValue().get().getName()); |
| } |
| } |
| |
| sendHtml(res, doc); |
| } |
| |
| private static void addProvider(Element form, String pluginName, |
| String id, String serviceName) { |
| Element div = form.getOwnerDocument().createElement("div"); |
| div.setAttribute("id", id); |
| Element hyperlink = form.getOwnerDocument().createElement("a"); |
| hyperlink.setAttribute("href", String.format("?provider=%s_%s", |
| pluginName, id)); |
| hyperlink.setTextContent(serviceName + |
| " (" + pluginName + " plugin)"); |
| div.appendChild(hyperlink); |
| form.appendChild(div); |
| } |
| |
| private static void sendHtml(HttpServletResponse res, Document doc) |
| throws IOException { |
| byte[] bin = HtmlDomUtil.toUTF8(doc); |
| res.setStatus(HttpServletResponse.SC_UNAUTHORIZED); |
| res.setContentType("text/html"); |
| res.setCharacterEncoding(StandardCharsets.UTF_8.name()); |
| res.setContentLength(bin.length); |
| try (ServletOutputStream out = res.getOutputStream()) { |
| out.write(bin); |
| } |
| } |
| |
| private void pickSSOServiceProvider() |
| throws ServletException { |
| SortedSet<String> plugins = oauthServiceProviders.plugins(); |
| if (plugins.isEmpty()) { |
| throw new ServletException( |
| "OAuth service provider wasn't installed"); |
| } |
| if (plugins.size() == 1) { |
| SortedMap<String, Provider<OAuthServiceProvider>> services = |
| oauthServiceProviders.byPlugin(Iterables.getOnlyElement(plugins)); |
| if (services.size() == 1) { |
| ssoProvider = Iterables.getOnlyElement(services.values()).get(); |
| } |
| } |
| } |
| |
| private static boolean isGerritLogin(HttpServletRequest request) { |
| return request.getRequestURI().indexOf(GERRIT_LOGIN) >= 0; |
| } |
| } |