blob: e1ba2c0b1a8b46e1c2127fecba038c7f905cf934 [file] [log] [blame]
// Copyright (C) 2020 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.googlesource.gerrit.plugins.saml;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import com.google.gerrit.entities.Account;
import com.google.gerrit.entities.AccountGroup;
import com.google.gerrit.server.GerritPersonIdent;
import com.google.gerrit.server.IdentifiedUser;
import com.google.gerrit.server.ServerInitiated;
import com.google.gerrit.server.account.*;
import com.google.gerrit.server.group.InternalGroup;
import com.google.gerrit.server.group.db.GroupsUpdate;
import com.google.gerrit.server.group.db.InternalGroupCreation;
import com.google.gerrit.server.group.db.InternalGroupUpdate;
import com.google.gerrit.server.notedb.Sequences;
import com.google.inject.Inject;
import com.google.inject.Provider;
import com.google.inject.Singleton;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.*;
import java.util.stream.Collectors;
import org.eclipse.jgit.lib.PersonIdent;
import org.pac4j.saml.profile.SAML2Profile;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Maps the membership attribute in the SAML document onto internal Gerrit groups whose names carry
 * the {@code saml/} prefix.
 *
 * <p>Missing groups are created on demand. On each {@link #sync} the user is added to every
 * {@code saml/} group listed in the SAML profile. The user is removed from every other {@code
 * saml/}-prefixed group they are currently known to belong to.
 */
@Singleton
public class SamlMembership {
  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  /** Name prefix that marks an internal group as managed by this class. */
  private static final String GROUP_PREFIX = "saml/";

  /** Name of the SAML attribute listing the user's groups; null or empty disables syncing. */
  private final String memberAttr;

  private final PersonIdent serverIdent;
  private final AccountManager accountManager;
  private final GroupCache groupCache;
  private final IdentifiedUser.GenericFactory userFactory;
  private final Provider<GroupsUpdate> groupsUpdateProvider;
  private final Sequences sequences;

  @Inject
  SamlMembership(
      SamlConfig samlConfig,
      @GerritPersonIdent PersonIdent serverIdent,
      AccountManager accountManager,
      GroupCache groupCache,
      IdentifiedUser.GenericFactory userFactory,
      @ServerInitiated Provider<GroupsUpdate> groupsUpdateProvider,
      Sequences sequences) {
    this.memberAttr = samlConfig.getMemberOfAttr();
    this.serverIdent = serverIdent;
    this.accountManager = accountManager;
    this.groupCache = groupCache;
    this.userFactory = userFactory;
    this.groupsUpdateProvider = groupsUpdateProvider;
    this.sequences = sequences;
  }

  /**
   * Synchronises the SAML-managed group memberships of a user with the groups in the SAML profile.
   *
   * <p>Each non-null, non-empty entry of the configured membership attribute is mapped to a {@code
   * saml/} group, which is created if missing. The user is added to those groups. The user is
   * removed from any other {@code saml/}-prefixed group among their known effective groups.
   *
   * @param user gerrit user
   * @param profile SAML profile
   * @throws IOException propagated from resolving the user's account via the account manager
   */
  public void sync(AuthenticatedUser user, SAML2Profile profile) throws IOException {
    Set<AccountGroup.UUID> samlMembership =
        Optional.ofNullable((List<?>) profile.getAttribute(memberAttr, List.class))
            .orElse(Collections.emptyList()).stream()
            // A null entry would make toString() throw and abort the whole sync; skip it.
            .filter(Objects::nonNull)
            .map(m -> getOrCreateGroup(m.toString()))
            .filter(Optional::isPresent)
            .map(Optional::get)
            .collect(Collectors.toSet());

    IdentifiedUser identifiedUser = userFactory.create(getOrCreateAccountId(user));

    // Only groups whose name carries the SAML prefix are managed here; others are left alone.
    Set<AccountGroup.UUID> userMembership =
        identifiedUser.getEffectiveGroups().getKnownGroups().stream()
            .filter(
                uuid ->
                    groupCache
                        .get(uuid)
                        .filter(g -> g.getName().startsWith(GROUP_PREFIX))
                        .isPresent())
            .collect(Collectors.toSet());

    log.debug(
        "User {} is member of {} in saml and {} in gerrit",
        user.getUsername(),
        samlMembership,
        userMembership);

    Set<Account.Id> accountIdSet = ImmutableSet.of(identifiedUser.getAccountId());

    // Add the user to SAML groups they are not yet a member of.
    samlMembership.stream()
        .filter(g -> !userMembership.contains(g))
        .forEach(g -> this.updateMembers(g, members -> Sets.union(members, accountIdSet)));

    // Remove the user from saml/ groups that the SAML profile no longer lists.
    userMembership.stream()
        .filter(g -> !samlMembership.contains(g))
        .forEach(g -> this.updateMembers(g, members -> Sets.difference(members, accountIdSet)));
  }

  /**
   * Tests whether membership syncing is enabled.
   *
   * @return true when a membership attribute name is configured (non-null and non-empty).
   */
  public boolean isEnabled() {
    return !Strings.isNullOrEmpty(memberAttr);
  }

  /**
   * Applies a member modification to a single group.
   *
   * @param group UUID of the group to update
   * @param memberModification function computing the new member set from the current one
   * @throws RuntimeException wrapping any failure reported by the group update
   */
  private void updateMembers(
      AccountGroup.UUID group, InternalGroupUpdate.MemberModification memberModification) {
    InternalGroupUpdate update =
        InternalGroupUpdate.builder().setMemberModification(memberModification).build();
    try {
      groupsUpdateProvider.get().updateGroup(group, update);
    } catch (Exception e) {
      throw new RuntimeException("Failed to update members of group " + group, e);
    }
  }

  /**
   * Resolves a SAML group name to the UUID of the corresponding {@code saml/} internal group.
   *
   * <p>If no such group exists yet, it is created.
   *
   * @param samlGroup group name as it appears in the SAML profile
   * @return the group UUID, or empty when {@code samlGroup} is the empty string
   */
  private Optional<AccountGroup.UUID> getOrCreateGroup(String samlGroup) {
    return samlGroupToName(samlGroup)
        .map(name -> groupCache.get(name).orElseGet(() -> createGroup(name, samlGroup)))
        .map(InternalGroup::getGroupUUID);
  }

  /**
   * Creates a new internal group that is not visible to all users.
   *
   * <p>The group's description records the originating SAML group name.
   *
   * <p>NOTE(review): two concurrent logins could both miss the group cache and both try to create
   * the same group. The second attempt presumably fails here; confirm whether that case needs
   * handling.
   *
   * @param name prefixed name of the group to create
   * @param samlGroup original SAML group name, used in the description
   * @return the created group
   * @throws RuntimeException wrapping any failure reported by id allocation or group creation
   */
  private InternalGroup createGroup(AccountGroup.NameKey name, String samlGroup) {
    try {
      AccountGroup.Id groupId = AccountGroup.id(sequences.nextGroupId());
      AccountGroup.UUID uuid = GroupUuid.make(name.get(), serverIdent);
      InternalGroupCreation groupCreation =
          InternalGroupCreation.builder()
              .setGroupUUID(uuid)
              .setNameKey(name)
              .setId(groupId)
              .build();
      InternalGroupUpdate.Builder groupUpdateBuilder =
          InternalGroupUpdate.builder()
              .setVisibleToAll(false)
              .setDescription(samlGroup + " (imported by the SAML plugin)");
      return groupsUpdateProvider.get().createGroup(groupCreation, groupUpdateBuilder.build());
    } catch (Exception e) {
      throw new RuntimeException("Failed to create group " + name.get(), e);
    }
  }

  /**
   * Builds the internal group name for a SAML group by prepending {@link #GROUP_PREFIX}.
   *
   * @param samlGroup group name from the SAML profile
   * @return the prefixed name key, or empty when {@code samlGroup} is the empty string
   */
  private Optional<AccountGroup.NameKey> samlGroupToName(String samlGroup) {
    return Optional.of(samlGroup)
        .filter(s -> !s.isEmpty())
        .map(GROUP_PREFIX::concat)
        .map(AccountGroup::nameKey);
  }

  /**
   * Resolves the account id of the user by authenticating them through the account manager.
   *
   * <p>NOTE(review): presumably this creates the account on first login; confirm against
   * AccountManager.
   *
   * @param user the SAML-authenticated user
   * @return the id of the user's Gerrit account
   * @throws IOException propagated from {@code AccountManager.authenticate}
   * @throws RuntimeException wrapping an {@code AccountException} raised during authentication
   */
  private Account.Id getOrCreateAccountId(AuthenticatedUser user) throws IOException {
    AuthRequest authRequest = AuthRequest.forUser(user.getUsername());
    authRequest.setUserName(user.getUsername());
    authRequest.setEmailAddress(user.getEmail());
    authRequest.setDisplayName(user.getDisplayName());
    try {
      return accountManager.authenticate(authRequest).getAccountId();
    } catch (AccountException e) {
      throw new RuntimeException("Failed to authenticate account of user " + user.getUsername(), e);
    }
  }
}