Add support for GPT-4o models
Support has been added for responses from ChatGPT 4o models, which may
involve multiple tool calls.
Change-Id: I45b8988572898f2bad901e0902cc8b4e4f109a37
Signed-off-by: Patrizio <patrizio.gelosi@amarulasolutions.com>
diff --git a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/PatchSetReviewer.java b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/PatchSetReviewer.java
index fcbafd9..cea8b4c 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/PatchSetReviewer.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/PatchSetReviewer.java
@@ -21,8 +21,6 @@
import java.util.*;
-import static com.googlesource.gerrit.plugins.chatgpt.utils.GsonUtils.getGson;
-
@Slf4j
public class PatchSetReviewer {
private static final String SPLIT_REVIEW_MSG = "Too many changes. Please consider splitting into patches smaller " +
@@ -66,7 +64,7 @@
}
ChangeSetDataHandler.update(config, change, gerritClient, changeSetData);
- String reviewReply = getReviewReply(change, patchSet);
+ ChatGptResponseContent reviewReply = getReviewReply(change, patchSet);
log.debug("ChatGPT response: {}", reviewReply);
retrieveReviewBatches(reviewReply, change);
@@ -101,9 +99,8 @@
}
}
- private void retrieveReviewBatches(String reviewReply, GerritChange change) {
- ChatGptResponseContent reviewJson = getGson().fromJson(reviewReply, ChatGptResponseContent.class);
- for (ChatGptReplyItem replyItem : reviewJson.getReplies()) {
+ private void retrieveReviewBatches(ChatGptResponseContent reviewReply, GerritChange change) {
+ for (ChatGptReplyItem replyItem : reviewReply.getReplies()) {
String reply = replyItem.getReply();
Integer score = replyItem.getScore();
boolean isNotNegative = isNotNegativeReply(score);
@@ -130,11 +127,11 @@
}
}
- private String getReviewReply(GerritChange change, String patchSet) throws Exception {
+ private ChatGptResponseContent getReviewReply(GerritChange change, String patchSet) throws Exception {
List<String> patchLines = Arrays.asList(patchSet.split("\n"));
if (patchLines.size() > config.getMaxReviewLines()) {
log.warn("Patch set too large. Skipping review. changeId: {}", change.getFullChangeId());
- return String.format(SPLIT_REVIEW_MSG, config.getMaxReviewLines());
+ return new ChatGptResponseContent(String.format(SPLIT_REVIEW_MSG, config.getMaxReviewLines()));
}
return chatGptClient.ask(config, changeSetData, change, patchSet);
diff --git a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/common/client/api/chatgpt/ChatGptClient.java b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/common/client/api/chatgpt/ChatGptClient.java
index 13b96da..f3b8f1d 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/common/client/api/chatgpt/ChatGptClient.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/common/client/api/chatgpt/ChatGptClient.java
@@ -19,7 +19,7 @@
@Getter
protected String requestBody;
- protected String extractContent(Configuration config, String body) throws Exception {
+ protected ChatGptResponseContent extractContent(Configuration config, String body) throws Exception {
if (config.getGptStreamOutput() && !isCommentEvent) {
StringBuilder finalContent = new StringBuilder();
try (BufferedReader reader = new BufferedReader(new StringReader(body))) {
@@ -28,7 +28,7 @@
extractContentFromLine(line).ifPresent(finalContent::append);
}
}
- return finalContent.toString();
+ return convertResponseContentFromJson(finalContent.toString());
}
else {
ChatGptResponseUnstreamed chatGptResponseUnstreamed =
@@ -37,9 +37,7 @@
}
}
- protected boolean validateResponse(String contentExtracted, String changeId, int attemptInd) {
- ChatGptResponseContent chatGptResponseContent =
- getGson().fromJson(contentExtracted, ChatGptResponseContent.class);
+ protected boolean validateResponse(ChatGptResponseContent chatGptResponseContent, String changeId, int attemptInd) {
String returnedChangeId = chatGptResponseContent.getChangeId();
// A response is considered valid if either no changeId is returned or the changeId returned matches the one
// provided in the request
@@ -51,8 +49,12 @@
return isValidated;
}
- protected String getResponseContent(List<ChatGptToolCall> toolCalls) {
- return toolCalls.get(0).getFunction().getArguments();
+ protected ChatGptResponseContent getResponseContent(List<ChatGptToolCall> toolCalls) {
+ if (toolCalls.size() > 1) {
+ return mergeToolCalls(toolCalls);
+ } else {
+ return getArgumentAsResponse(toolCalls, 0);
+ }
}
protected Optional<String> extractContentFromLine(String line) {
@@ -67,8 +69,30 @@
if (delta == null || delta.getToolCalls() == null) {
return Optional.empty();
}
- String content = getResponseContent(delta.getToolCalls());
+ String content = getArgumentAsString(delta.getToolCalls(), 0);
return Optional.ofNullable(content);
}
+ private ChatGptResponseContent convertResponseContentFromJson(String content) {
+ return getGson().fromJson(content, ChatGptResponseContent.class);
+ }
+
+ private String getArgumentAsString(List<ChatGptToolCall> toolCalls, int ind) {
+ return toolCalls.get(ind).getFunction().getArguments();
+ }
+
+ private ChatGptResponseContent getArgumentAsResponse(List<ChatGptToolCall> toolCalls, int ind) {
+ return convertResponseContentFromJson(getArgumentAsString(toolCalls, ind));
+ }
+
+ private ChatGptResponseContent mergeToolCalls(List<ChatGptToolCall> toolCalls) {
+ ChatGptResponseContent responseContent = getArgumentAsResponse(toolCalls, 0);
+ for (int ind = 1; ind < toolCalls.size(); ind++) {
+ responseContent.getReplies().addAll(
+ getArgumentAsResponse(toolCalls, ind).getReplies()
+ );
+ }
+ return responseContent;
+ }
+
}
diff --git a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/common/model/api/chatgpt/ChatGptResponseContent.java b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/common/model/api/chatgpt/ChatGptResponseContent.java
index 892d50f..2662155 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/common/model/api/chatgpt/ChatGptResponseContent.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/common/model/api/chatgpt/ChatGptResponseContent.java
@@ -1,11 +1,16 @@
package com.googlesource.gerrit.plugins.chatgpt.mode.common.model.api.chatgpt;
import lombok.Data;
+import lombok.NonNull;
+import lombok.RequiredArgsConstructor;
import java.util.List;
@Data
+@RequiredArgsConstructor
public class ChatGptResponseContent {
private List<ChatGptReplyItem> replies;
private String changeId;
+ @NonNull
+ private String errorMessage;
}
diff --git a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/interfaces/client/api/chatgpt/IChatGptClient.java b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/interfaces/client/api/chatgpt/IChatGptClient.java
index 9a3f6fd..50f8517 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/interfaces/client/api/chatgpt/IChatGptClient.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/interfaces/client/api/chatgpt/IChatGptClient.java
@@ -2,9 +2,11 @@
import com.googlesource.gerrit.plugins.chatgpt.config.Configuration;
import com.googlesource.gerrit.plugins.chatgpt.mode.common.client.api.gerrit.GerritChange;
+import com.googlesource.gerrit.plugins.chatgpt.mode.common.model.api.chatgpt.ChatGptResponseContent;
import com.googlesource.gerrit.plugins.chatgpt.mode.common.model.data.ChangeSetData;
public interface IChatGptClient {
- String ask(Configuration config, ChangeSetData changeSetData, GerritChange change, String patchSet) throws Exception;
+ ChatGptResponseContent ask(Configuration config, ChangeSetData changeSetData, GerritChange change, String patchSet)
+ throws Exception;
String getRequestBody();
}
diff --git a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/stateful/client/api/chatgpt/ChatGptClientStateful.java b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/stateful/client/api/chatgpt/ChatGptClientStateful.java
index c401cd0..fe6401e 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/stateful/client/api/chatgpt/ChatGptClientStateful.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/stateful/client/api/chatgpt/ChatGptClientStateful.java
@@ -7,6 +7,7 @@
import com.googlesource.gerrit.plugins.chatgpt.data.PluginDataHandler;
import com.googlesource.gerrit.plugins.chatgpt.mode.common.client.api.chatgpt.ChatGptClient;
import com.googlesource.gerrit.plugins.chatgpt.mode.common.client.api.gerrit.GerritChange;
+import com.googlesource.gerrit.plugins.chatgpt.mode.common.model.api.chatgpt.ChatGptResponseContent;
import com.googlesource.gerrit.plugins.chatgpt.mode.common.model.data.ChangeSetData;
import com.googlesource.gerrit.plugins.chatgpt.mode.interfaces.client.api.chatgpt.IChatGptClient;
import lombok.extern.slf4j.Slf4j;
@@ -23,7 +24,7 @@
this.pluginDataHandler = pluginDataHandler;
}
- public String ask(Configuration config, ChangeSetData changeSetData, GerritChange change, String patchSet) {
+ public ChatGptResponseContent ask(Configuration config, ChangeSetData changeSetData, GerritChange change, String patchSet) {
isCommentEvent = change.getIsCommentEvent();
String changeId = change.getFullChangeId();
log.info("Processing STATEFUL ChatGPT Request with changeId: {}, Patch Set: {}", changeId, patchSet);
diff --git a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/stateless/client/api/chatgpt/ChatGptClientStateless.java b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/stateless/client/api/chatgpt/ChatGptClientStateless.java
index e1b34ac..de37d13 100644
--- a/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/stateless/client/api/chatgpt/ChatGptClientStateless.java
+++ b/src/main/java/com/googlesource/gerrit/plugins/chatgpt/mode/stateless/client/api/chatgpt/ChatGptClientStateless.java
@@ -32,8 +32,8 @@
private final HttpClientWithRetry httpClientWithRetry = new HttpClientWithRetry();
- public String ask(Configuration config, ChangeSetData changeSetData, GerritChange change, String patchSet)
- throws Exception {
+ public ChatGptResponseContent ask(Configuration config, ChangeSetData changeSetData, GerritChange change,
+ String patchSet) throws Exception {
isCommentEvent = change.getIsCommentEvent();
String changeId = change.getFullChangeId();
log.info("Processing STATELESS ChatGPT Request with changeId: {}, Patch Set: {}", changeId, patchSet);
@@ -49,7 +49,7 @@
throw new IOException("ChatGPT response body is null");
}
- String contentExtracted = extractContent(config, body);
+ ChatGptResponseContent contentExtracted = extractContent(config, body);
if (validateResponse(contentExtracted, changeId, attemptInd)) {
return contentExtracted;
}
diff --git a/src/test/java/com/googlesource/gerrit/plugins/chatgpt/ChatGptReviewStatefulTest.java b/src/test/java/com/googlesource/gerrit/plugins/chatgpt/ChatGptReviewStatefulTest.java
index 79e15ea..d452cec 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/chatgpt/ChatGptReviewStatefulTest.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/chatgpt/ChatGptReviewStatefulTest.java
@@ -44,7 +44,8 @@
private static final String CHAT_GPT_RUN_ID = "run_TEST_RUN_ID";
private String formattedPatchContent;
- private String reviewMessage;
+ private String reviewMessageCode;
+ private String reviewMessageCommitMessage;
private ChatGptPromptStateful chatGptPromptStateful;
private JsonObject threadMessage;
@@ -147,7 +148,8 @@
protected void initComparisonContent() {
super.initComparisonContent();
- reviewMessage = getReviewMessage();
+ reviewMessageCode = getReviewMessage(0);
+ reviewMessageCommitMessage = getReviewMessage(1);
}
protected ArgumentCaptor<ReviewInput> testRequestSent() throws RestApiException {
@@ -156,12 +158,12 @@
return reviewInputCaptor;
}
- private String getReviewMessage() {
+ private String getReviewMessage(int tollCallId) {
ChatGptListResponse reviewResponse = getGson().fromJson(readTestFile(
"__files/chatGptRunStepsResponse.json"
), ChatGptListResponse.class);
- String reviewJsonResponse = reviewResponse.getData().get(0).getStepDetails().getToolCalls().get(0).getFunction()
- .getArguments();
+ String reviewJsonResponse = reviewResponse.getData().get(0).getStepDetails().getToolCalls().get(tollCallId)
+ .getFunction().getArguments();
return getGson().fromJson(reviewJsonResponse, ChatGptResponseContent.class).getReplies().get(0).getReply();
}
@@ -175,7 +177,11 @@
String userPrompt = threadMessage.get("content").getAsString();
Assert.assertEquals(reviewUserPrompt, userPrompt);
Assert.assertEquals(
- reviewMessage,
+ reviewMessageCode,
+ captor.getAllValues().get(0).comments.get("test_file_1.py").get(0).message
+ );
+ Assert.assertEquals(
+ reviewMessageCommitMessage,
captor.getAllValues().get(0).comments.get(GERRIT_PATCH_SET_FILENAME).get(0).message
);
}
diff --git a/src/test/java/com/googlesource/gerrit/plugins/chatgpt/integration/CodeReviewPluginIT.java b/src/test/java/com/googlesource/gerrit/plugins/chatgpt/integration/CodeReviewPluginIT.java
index d6f873b..2c1535f 100644
--- a/src/test/java/com/googlesource/gerrit/plugins/chatgpt/integration/CodeReviewPluginIT.java
+++ b/src/test/java/com/googlesource/gerrit/plugins/chatgpt/integration/CodeReviewPluginIT.java
@@ -5,6 +5,7 @@
import com.googlesource.gerrit.plugins.chatgpt.mode.common.client.api.gerrit.GerritChange;
import com.googlesource.gerrit.plugins.chatgpt.mode.common.client.api.gerrit.GerritClient;
import com.googlesource.gerrit.plugins.chatgpt.mode.common.client.api.gerrit.GerritClientReview;
+import com.googlesource.gerrit.plugins.chatgpt.mode.common.model.api.chatgpt.ChatGptResponseContent;
import com.googlesource.gerrit.plugins.chatgpt.mode.common.model.data.ChangeSetData;
import com.googlesource.gerrit.plugins.chatgpt.mode.common.model.review.ReviewBatch;
import com.googlesource.gerrit.plugins.chatgpt.mode.interfaces.client.api.chatgpt.IChatGptClient;
@@ -49,7 +50,7 @@
when(config.getGptModel()).thenReturn(Configuration.DEFAULT_GPT_MODEL);
when(chatGptPromptStateless.getGptSystemPrompt()).thenReturn(ChatGptPromptStateless.DEFAULT_GPT_SYSTEM_PROMPT);
- String answer = chatGptClient.ask(config, changeSetData, new GerritChange(""), "hello");
+ ChatGptResponseContent answer = chatGptClient.ask(config, changeSetData, new GerritChange(""), "hello");
log.info("answer: {}", answer);
assertNotNull(answer);
}
diff --git a/src/test/resources/__files/chatGptRunStepsResponse.json b/src/test/resources/__files/chatGptRunStepsResponse.json
index 30fc5e8..b9b6044 100644
--- a/src/test/resources/__files/chatGptRunStepsResponse.json
+++ b/src/test/resources/__files/chatGptRunStepsResponse.json
@@ -17,7 +17,15 @@
"type": "function",
"function": {
"name": "format_replies",
- "arguments": "{\n \"replies\": [\n {\n \"reply\": \"The change in the `rsplit` function call from `rsplit('.', 1)` to `rsplit('.', 2)` might lead to a `ValueError` if the `module_name` does not contain any dots. This change assumes that there is always at least one dot in the `module_name`. Ensure that the module naming convention enforces this or add error handling for the case where `module_name` does not contain a dot.\",\n \n\"score\n\": -1,\n \n\"relevance\n\": 0.9,\n \n\"repeated\n\": false,\n \n\"conflicting\n\": false,\n \n\"filename\n\": \n\"test_file_1.py\n\",\n \n\"lineNumber\n\": 18,\n \n\"codeSnippet\n\": \n\"module_name, class_name = module_name.rsplit('.', 2)\n\"\n }\n ],\n \n\"changeId\n\": \n\"myProject~myBranchName~myChangeId\n\"\n }"
+ "arguments": "{\n \"replies\": [\n {\n \"reply\": \"The change in the `rsplit` function call from `rsplit('.', 1)` to `rsplit('.', 2)` might lead to a `ValueError` if the `module_name` does not contain any dots. This change assumes that there is always at least one dot in the `module_name`. Ensure that the module naming convention enforces this or add error handling for the case where `module_name` does not contain a dot.\",\n \"score\": -1,\n \"relevance\": 0.9,\n \"repeated\": false,\n \"conflicting\": false,\n \"filename\": \"test_file_1.py\",\n \"lineNumber\": 18,\n \"codeSnippet\": \"module_name, class_name = module_name.rsplit('.', 2)\"\n }\n ],\n \"changeId\": \"myProject~myBranchName~myChangeId\"\n }"
+ }
+ },
+ {
+ "id": "call_M1ioOSrg4sDZ3Q8Tps6dcAJk",
+ "type": "function",
+ "function": {
+ "name": "format_replies",
+ "arguments": "{\n \"replies\": [\n {\n \"reply\": \"The commit message 'Minor Fixes' is too vague and does not provide adequate information about the changes made in the PatchSet. A more descriptive message would help understand the nature and objective of the changes. Please consider revising it to include details of what is being fixed and why.\",\n \"score\": -1,\n \"relevance\": 0.8,\n \"repeated\": false,\n \"conflicting\": false\n }\n ],\n \"changeId\": \"myProject~myBranchName~myChangeId\"\n }"
}
}
]
diff --git a/src/test/resources/__files/stateful/gerritFormattedPatch.txt b/src/test/resources/__files/stateful/gerritFormattedPatch.txt
index b7875ed..5a0f07e 100644
--- a/src/test/resources/__files/stateful/gerritFormattedPatch.txt
+++ b/src/test/resources/__files/stateful/gerritFormattedPatch.txt
@@ -1,3 +1,6 @@
+Subject: Minor fixes
+
+Change-Id: myChangeId
---
diff --git a/test_file_1.py b/test_file_1.py