diff --git a/src/test/java/com/googlesource/gerrit/plugins/websession/broker/BrokerBasedWebSessionCacheCleanerTest.java b/src/test/java/com/googlesource/gerrit/plugins/websession/broker/BrokerBasedWebSessionCacheCleanerTest.java
index dbeed88..489e74e 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/websession/broker/BrokerBasedWebSessionCacheCleanerTest.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/websession/broker/BrokerBasedWebSessionCacheCleanerTest.java
@@ -18,7 +18,6 @@
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.isA;
 import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -66,17 +65,6 @@
   }
 
   @Test
-  public void testCleanupTaskRun() {
-    BrokerBasedWebSessionCache cacheMock = mock(BrokerBasedWebSessionCache.class);
-    CleanupTask task = new CleanupTask(cacheMock, null);
-    int numberOfRuns = 5;
-    for (int i = 0; i < numberOfRuns; i++) {
-      task.run();
-    }
-    verify(cacheMock, times(numberOfRuns)).cleanUp();
-  }
-
-  @Test
   public void testCleanupTaskIsScheduledOnStart() {
     objectUnderTest.start();
     verify(executorMock, times(1))
diff --git a/src/test/java/com/googlesource/gerrit/plugins/websession/broker/BrokerBasedWebSessionCacheTest.java b/src/test/java/com/googlesource/gerrit/plugins/websession/broker/BrokerBasedWebSessionCacheTest.java
index e6cb20e..2832405 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/websession/broker/BrokerBasedWebSessionCacheTest.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/websession/broker/BrokerBasedWebSessionCacheTest.java
@@ -17,20 +17,20 @@
 import static com.google.common.truth.Truth.assertThat;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyString;
-import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 
 import com.gerritforge.gerrit.eventbroker.BrokerApi;
 import com.gerritforge.gerrit.eventbroker.EventMessage;
 import com.gerritforge.gerrit.eventbroker.EventMessage.Header;
 import com.google.common.cache.Cache;
-import com.google.common.collect.Maps;
+import com.google.common.cache.CacheBuilder;
 import com.google.common.util.concurrent.MoreExecutors;
+import com.google.gerrit.entities.Account;
 import com.google.gerrit.extensions.registration.DynamicItem;
 import com.google.gerrit.httpd.WebSessionManager.Val;
+import com.google.gerrit.server.account.externalids.ExternalId;
 import com.google.gerrit.server.config.PluginConfig;
 import com.google.gerrit.server.config.PluginConfigFactory;
 import com.google.gerrit.server.events.Event;
@@ -40,7 +40,6 @@
 import com.googlesource.gerrit.plugins.websession.broker.util.TimeMachine;
 import java.time.Instant;
 import java.util.UUID;
-import java.util.concurrent.ConcurrentMap;
 import java.util.concurrent.ExecutorService;
 import org.junit.Before;
 import org.junit.Test;
@@ -55,6 +54,8 @@
 
   private static final int DEFAULT_ACCOUNT_ID = 1000000;
   private static final String KEY = "aSceprtma6B0qZ0hKxXHvQ5iyfUhCcFXxG";
+  private static Val VAL =
+      new Val(Account.id(1), 0, false, ExternalId.Key.parse("foo:bar"), 0, "", "");
   private static final String PLUGIN_NAME = "websession-broker";
 
   private byte[] emptyPayload = new byte[] {-84, -19, 0, 5, 112};
@@ -73,7 +74,7 @@
   ExecutorService executorServce = MoreExecutors.newDirectExecutorService();
 
   @Mock BrokerApi brokerApi;
-  @Mock Cache<String, Val> cache;
+  Cache<String, Val> cache;
   @Mock TimeMachine timeMachine;
   @Mock PluginConfigFactory cfg;
   @Mock PluginConfig pluginConfig;
@@ -85,6 +86,7 @@
 
   @Before
   public void setup() {
+    cache = CacheBuilder.newBuilder().build();
     when(pluginConfig.getString("webSessionTopic", "gerrit_web_session"))
         .thenReturn("gerrit_web_session");
     when(cfg.getFromGerritConfig(PLUGIN_NAME)).thenReturn(pluginConfig);
@@ -133,50 +135,19 @@
 
     objectUnderTest.processMessage(eventMessage);
 
-    verify(cache, times(1)).put(anyString(), valCaptor.capture());
-
-    assertThat(valCaptor.getValue()).isNotNull();
-    Val val = valCaptor.getValue();
+    Val val = cache.getIfPresent(eventMessageKey(eventMessage));
+    assertThat(val).isNotNull();
     assertThat(val.getAccountId().get()).isEqualTo(DEFAULT_ACCOUNT_ID);
   }
 
   @Test
   public void shouldUpdateCacheWhenLogoutMessageReceived() {
     EventMessage eventMessage = createEventMessage(emptyPayload, Operation.REMOVE);
+    cache.put(KEY, VAL);
 
     objectUnderTest.processMessage(eventMessage);
 
-    verify(cache, times(1)).invalidate(KEY);
-  }
-
-  @Test
-  public void shouldSkipCacheUpdateWhenUnknownEventType() {
-    Header header =
-        new Header(
-            UUID.fromString("7cb80dbe-65c4-4f2c-84de-580d98199d4a"),
-            UUID.fromString("97711495-1013-414e-bfd2-44776787520d"));
-    Event event = new Event("sample-event") {};
-    EventMessage eventMessage = new EventMessage(header, event);
-    objectUnderTest.processMessage(eventMessage);
-
-    verifyZeroInteractions(cache);
-  }
-
-  @Test
-  public void shouldSkipCacheUpdateWhenInvalidPayload() {
-    EventMessage eventMessage = createEventMessage(new byte[] {1, 2, 3, 4}, Operation.ADD);
-    objectUnderTest.processMessage(eventMessage);
-
-    verifyZeroInteractions(cache);
-  }
-
-  @Test
-  public void shouldSkipCacheUpdateWhenSessionExpired() {
-    when(timeMachine.now()).thenReturn(Instant.MAX);
-    EventMessage eventMessage = createEventMessage();
-    objectUnderTest.processMessage(eventMessage);
-
-    verifyZeroInteractions(cache);
+    assertThat(cache.getIfPresent(KEY)).isNull();
   }
 
   @Test
@@ -186,26 +157,20 @@
     EventMessage eventMessage = createEventMessage();
 
     objectUnderTest.processMessage(eventMessage);
-    verify(cache, times(1)).put(anyString(), valCaptor.capture());
-    assertThat(valCaptor.getValue()).isNotNull();
 
-    ConcurrentMap<String, Val> cacheMap = Maps.newConcurrentMap();
-    cacheMap.put(KEY, valCaptor.getValue());
-    when(cache.asMap()).thenReturn(cacheMap);
+    Val val = cache.getIfPresent(eventMessageKey(eventMessage));
+    assertThat(val).isNotNull();
 
     objectUnderTest.cleanUp();
 
-    verify(cache, times(1)).invalidate(KEY);
+    assertThat(cache.getIfPresent(eventMessageKey(eventMessage))).isNull();
   }
 
-  @SuppressWarnings("unchecked")
   private Val createVal(EventMessage message) {
-    ArgumentCaptor<Val> valArgumentCaptor = ArgumentCaptor.forClass(Val.class);
+    WebSessionEvent event = (WebSessionEvent) message.getEvent();
 
     objectUnderTest.processMessage(message);
-    verify(cache).put(anyString(), valArgumentCaptor.capture());
-    reset(cache);
-    return valArgumentCaptor.getValue();
+    return cache.getIfPresent(event.key);
   }
 
   private EventMessage createEventMessage() {
@@ -213,6 +178,11 @@
     return createEventMessage(defaultPayload, Operation.ADD);
   }
 
+  private String eventMessageKey(EventMessage eventMessage) {
+    WebSessionEvent sessionEvent = (WebSessionEvent) eventMessage.getEvent();
+    return sessionEvent.key;
+  }
+
   private EventMessage createEventMessage(byte[] payload, Operation operation) {
 
     Header header =
