blob: 6202b2e2bfad45b1e0f2dd9c32de5a27c54a1299 [file] [log] [blame]
// Copyright (C) 2024 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.googlesource.gerrit.plugins.aicodereview;
import static com.googlesource.gerrit.plugins.aicodereview.config.Configuration.KEY_AI_CHAT_ENDPOINT;
import static com.googlesource.gerrit.plugins.aicodereview.config.Configuration.KEY_AI_TYPE;
import static com.googlesource.gerrit.plugins.aicodereview.config.Configuration.KEY_STREAM_OUTPUT;
import static com.googlesource.gerrit.plugins.aicodereview.utils.TextUtils.joinWithNewLine;
import static java.net.HttpURLConnection.HTTP_OK;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.google.common.net.HttpHeaders;
import com.google.gerrit.extensions.api.changes.FileApi;
import com.google.gerrit.extensions.api.changes.ReviewInput;
import com.google.gerrit.extensions.common.DiffInfo;
import com.google.gerrit.extensions.common.FileInfo;
import com.google.gerrit.extensions.restapi.RestApiException;
import com.google.gerrit.json.OutputFormat;
import com.google.gson.Gson;
import com.google.gson.JsonArray;
import com.googlesource.gerrit.plugins.aicodereview.listener.EventHandlerTask;
import com.googlesource.gerrit.plugins.aicodereview.listener.EventHandlerTask.SupportedEvents;
import com.googlesource.gerrit.plugins.aicodereview.mode.stateless.client.api.UriResourceLocatorStateless;
import com.googlesource.gerrit.plugins.aicodereview.mode.stateless.client.prompt.AIChatPromptStateless;
import com.googlesource.gerrit.plugins.aicodereview.settings.Settings;
import java.net.URI;
import java.util.Arrays;
import java.util.Map;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.reflect.TypeLiteral;
import org.apache.http.entity.ContentType;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
@Slf4j
@RunWith(MockitoJUnitRunner.class)
public class AIChatReviewStatelessTest extends AIChatReviewTestBase {
private ReviewInput expectedResponseStreamed;
private String expectedSystemPromptReview;
private String promptTagReview;
private String diffContent;
private ReviewInput gerritPatchSetReview;
private JsonArray prompts;
private AIChatPromptStateless AIChatPromptStateless;
protected void initConfig() {
super.initGlobalAndProjectConfig();
when(globalConfig.getBoolean(Mockito.eq(KEY_STREAM_OUTPUT), Mockito.anyBoolean()))
.thenReturn(GPT_STREAM_OUTPUT);
when(globalConfig.getBoolean(Mockito.eq("aiReviewCommitMessages"), Mockito.anyBoolean()))
.thenReturn(true);
super.initConfig();
// Load the prompts
AIChatPromptStateless = new AIChatPromptStateless(config);
}
protected void setupMockRequests() throws RestApiException {
super.setupMockRequests();
// Mock the behavior of the gerritPatchSetFiles request
Map<String, FileInfo> files =
readTestFileToType(
"__files/stateless/gerritPatchSetFiles.json",
new TypeLiteral<Map<String, FileInfo>>() {}.getType());
when(revisionApiMock.files(0)).thenReturn(files);
// Mock the behavior of the gerritPatchSet diff requests
FileApi commitMsgFileMock = mock(FileApi.class);
when(revisionApiMock.file("/COMMIT_MSG")).thenReturn(commitMsgFileMock);
DiffInfo commitMsgFileDiff =
readTestFileToClass("__files/stateless/gerritPatchSetDiffCommitMsg.json", DiffInfo.class);
when(commitMsgFileMock.diff(0)).thenReturn(commitMsgFileDiff);
FileApi testFileMock = mock(FileApi.class);
when(revisionApiMock.file("test_file.py")).thenReturn(testFileMock);
DiffInfo testFileDiff =
readTestFileToClass("__files/stateless/gerritPatchSetDiffTestFile.json", DiffInfo.class);
when(testFileMock.diff(0)).thenReturn(testFileDiff);
// Mock the behavior of the askGpt request
WireMock.stubFor(
WireMock.post(
WireMock.urlEqualTo(
URI.create(
config.getAIDomain() + UriResourceLocatorStateless.chatCompletionsUri())
.getPath()))
.willReturn(
WireMock.aResponse()
.withStatus(HTTP_OK)
.withHeader(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.toString())
.withBodyFile("aiChatResponseStreamed.txt")));
}
protected void initComparisonContent() {
super.initComparisonContent();
diffContent = readTestFile("reducePatchSet/patchSetDiffOutput.json");
gerritPatchSetReview =
readTestFileToClass("__files/stateless/gerritPatchSetReview.json", ReviewInput.class);
expectedResponseStreamed =
readTestFileToClass(
"__files/stateless/aiChatExpectedResponseStreamed.json", ReviewInput.class);
promptTagReview = readTestFile("__files/stateless/aiChatPromptTagReview.json");
promptTagComments = readTestFile("__files/stateless/aiChatPromptTagRequests.json");
expectedSystemPromptReview = AIChatPromptStateless.getDefaultGptReviewSystemPrompt();
}
protected ArgumentCaptor<ReviewInput> testRequestSent() throws RestApiException {
ArgumentCaptor<ReviewInput> reviewInputCaptor = super.testRequestSent();
prompts = gptRequestBody.get("messages").getAsJsonArray();
return reviewInputCaptor;
}
private String getReviewUserPrompt() {
return joinWithNewLine(
Arrays.asList(
AIChatPromptStateless.DEFAULT_AI_CHAT_REVIEW_PROMPT,
AIChatPromptStateless.DEFAULT_AI_CHAT_REVIEW_PROMPT_REVIEW
+ " "
+ AIChatPromptStateless.DEFAULT_AI_CHAT_PROMPT_FORCE_JSON_FORMAT
+ " "
+ AIChatPromptStateless.getPatchSetReviewPrompt(),
AIChatPromptStateless.getReviewPromptCommitMessages(),
AIChatPromptStateless.DEFAULT_AI_CHAT_REVIEW_PROMPT_DIFF,
diffContent,
AIChatPromptStateless.DEFAULT_AI_CHAT_REVIEW_PROMPT_MESSAGE_HISTORY,
promptTagReview));
}
@Test
public void patchSetCreatedOrUpdatedStreamed() throws Exception {
String reviewUserPrompt = getReviewUserPrompt();
AIChatPromptStateless.setCommentEvent(false);
handleEventBasedOnType(SupportedEvents.PATCH_SET_CREATED);
ArgumentCaptor<ReviewInput> captor = testRequestSent();
String systemPrompt = prompts.get(0).getAsJsonObject().get("content").getAsString();
Assert.assertEquals(expectedSystemPromptReview, systemPrompt);
String userPrompt = prompts.get(1).getAsJsonObject().get("content").getAsString();
Assert.assertEquals(reviewUserPrompt, userPrompt);
Gson gson = OutputFormat.JSON_COMPACT.newGson();
Assert.assertEquals(
gson.toJson(expectedResponseStreamed), gson.toJson(captor.getAllValues().get(0)));
}
@Test
public void patchSetCreatedOrUpdatedUnstreamed() throws Exception {
when(globalConfig.getBoolean(Mockito.eq("aiStreamOutput"), Mockito.anyBoolean()))
.thenReturn(false);
when(globalConfig.getBoolean(Mockito.eq("enabledVoting"), Mockito.anyBoolean()))
.thenReturn(true);
String reviewUserPrompt = getReviewUserPrompt();
AIChatPromptStateless.setCommentEvent(false);
WireMock.stubFor(
WireMock.post(
WireMock.urlEqualTo(
URI.create(
config.getAIDomain() + UriResourceLocatorStateless.chatCompletionsUri())
.getPath()))
.willReturn(
WireMock.aResponse()
.withStatus(HTTP_OK)
.withHeader(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.toString())
.withBodyFile("aiChatResponseReview.json")));
handleEventBasedOnType(SupportedEvents.PATCH_SET_CREATED);
ArgumentCaptor<ReviewInput> captor = testRequestSent();
String userPrompt = prompts.get(1).getAsJsonObject().get("content").getAsString();
Assert.assertEquals(reviewUserPrompt, userPrompt);
Gson gson = OutputFormat.JSON_COMPACT.newGson();
Assert.assertEquals(
gson.toJson(gerritPatchSetReview), gson.toJson(captor.getAllValues().get(0)));
}
@Test
public void patchSetDisableUserGroup() {
when(globalConfig.getString(Mockito.eq("disabledGroups"), Mockito.anyString()))
.thenReturn(GERRIT_USER_GROUP);
Assert.assertEquals(
EventHandlerTask.Result.NOT_SUPPORTED,
handleEventBasedOnType(SupportedEvents.PATCH_SET_CREATED));
}
@Test
public void gptMentionedInComment() throws RestApiException {
when(config.getGerritUserName()).thenReturn(GERRIT_GPT_USERNAME);
AIChatPromptStateless.setCommentEvent(true);
WireMock.stubFor(
WireMock.post(
WireMock.urlEqualTo(
URI.create(
config.getAIDomain() + UriResourceLocatorStateless.chatCompletionsUri())
.getPath()))
.willReturn(
WireMock.aResponse()
.withStatus(HTTP_OK)
.withHeader(HttpHeaders.CONTENT_TYPE, ContentType.APPLICATION_JSON.toString())
.withBodyFile("aiChatResponseRequestStateless.json")));
handleEventBasedOnType(SupportedEvents.COMMENT_ADDED);
int commentPropertiesSize =
gerritClient.getClientData(getGerritChange()).getCommentProperties().size();
String commentUserPrompt =
joinWithNewLine(
Arrays.asList(
AIChatPromptStateless.DEFAULT_AI_CHAT_REQUEST_PROMPT_DIFF,
diffContent,
AIChatPromptStateless.DEFAULT_AI_CHAT_REQUEST_PROMPT_REQUESTS,
readTestFile("__files/stateless/aiChatExpectedRequestMessage.json"),
AIChatPromptStateless.getCommentRequestPrompt(commentPropertiesSize)));
testRequestSent();
String userPrompt = prompts.get(1).getAsJsonObject().get("content").getAsString();
Assert.assertEquals(commentUserPrompt, userPrompt);
}
@Test
public void testAITypeValidOptions() {
when(globalConfig.getString(Mockito.eq("aiType"), Mockito.anyString())).thenReturn("CHATGPT");
// check default for aiType is chatGPT.
Assert.assertEquals(config.getAIType(), Settings.AIType.CHATGPT);
when(globalConfig.getString(Mockito.eq("aiType"), Mockito.anyString())).thenReturn("OLLAMA");
Assert.assertEquals(config.getAIType(), Settings.AIType.OLLAMA);
}
@Test
public void testAITypeControlsEndpoint() {
when(globalConfig.getString(Mockito.eq("aiType"), Mockito.anyString())).thenReturn("CHATGPT");
// check default for aiType is chatGPT.
Assert.assertEquals(config.getChatEndpoint(), "");
Assert.assertEquals(
UriResourceLocatorStateless.chatCompletionsUri(),
UriResourceLocatorStateless.getChatResourceUri(config));
// swap it to ollama, check we still get the chatCompletionsUri, as its the openai
// compat endpoint we use.
when(globalConfig.getString(Mockito.eq("aiType"), Mockito.anyString())).thenReturn("OLLAMA");
Assert.assertEquals(
UriResourceLocatorStateless.chatCompletionsUri(),
UriResourceLocatorStateless.getChatResourceUri(config));
// finally change to GENERIC, and check that we can specify any endpoint
when(globalConfig.getString(Mockito.eq(KEY_AI_TYPE), Mockito.anyString()))
.thenReturn("GENERIC");
final String expectedValueForEndpoint = "/someendpoint/someapi/chat";
when(globalConfig.getString(Mockito.eq(KEY_AI_CHAT_ENDPOINT), Mockito.anyString()))
.thenReturn(expectedValueForEndpoint);
Assert.assertEquals(
expectedValueForEndpoint, UriResourceLocatorStateless.getChatResourceUri(config));
}
@Test
public void testAITypeControlsAuthHeader() {
when(globalConfig.getString(Mockito.eq("aiType"), Mockito.anyString())).thenReturn("CHATGPT");
// check default for aiType is chatGPT.
Assert.assertEquals("Authorization", config.getAuthorizationHeaderInfo().getName());
Assert.assertEquals(
"Bearer " + config.getAIToken(), config.getAuthorizationHeaderInfo().getValue());
// swap it to ollama, check we still get the chatCompletionsUri, as its the openai
// compat endpoint we use.
when(globalConfig.getString(Mockito.eq("aiType"), Mockito.anyString())).thenReturn("OLLAMA");
Assert.assertNull(
"No expected value for auth header for ollama", config.getAuthorizationHeaderInfo());
}
}