[ Add ] support for log probs chat completion creation

This commit is contained in:
Anas Fikhi
2024-02-22 02:42:07 +01:00
parent e3973a7439
commit dfce769f27
7 changed files with 159 additions and 1 deletions

View File

@ -1,4 +1,4 @@
{ {
"cSpell.words": ["Epoches", "openai"], "cSpell.words": ["Epoches", "openai", "Probs"],
"editor.acceptSuggestionOnEnter": "off" "editor.acceptSuggestionOnEnter": "off"
} }

View File

@ -0,0 +1,45 @@
import 'package:dart_openai/dart_openai.dart';
import 'env/env.dart';
/// Demonstrates requesting log probabilities for a chat completion.
void main() async {
  // Set the OpenAI API key from the .env file.
  OpenAI.apiKey = Env.apiKey;

  // Instruction message for the model.
  //
  // NOTE(review): this was labeled a system message but carried the
  // assistant role, so the model never treated it as an instruction;
  // it now uses [OpenAIChatMessageRole.system].
  final systemMessage = OpenAIChatCompletionChoiceMessageModel(
    content: [
      OpenAIChatCompletionChoiceMessageContentItemModel.text(
        "return any message you are given as JSON.",
      ),
    ],
    role: OpenAIChatMessageRole.system,
  );

  // The user's message, with an optional participant name.
  final userMessage = OpenAIChatCompletionChoiceMessageModel(
    content: [
      OpenAIChatCompletionChoiceMessageContentItemModel.text(
        "Hello, I am a chatbot created by OpenAI. How are you today?",
      ),
    ],
    role: OpenAIChatMessageRole.user,
    name: "anas",
  );

  final requestMessages = [
    systemMessage,
    userMessage,
  ];

  // Request a completion with log probabilities enabled; `topLogprobs`
  // asks for the 2 most likely alternatives per token position.
  OpenAIChatCompletionModel chatCompletion = await OpenAI.instance.chat.create(
    model: "gpt-3.5-turbo-1106",
    responseFormat: {"type": "json_object"},
    seed: 6,
    messages: requestMessages,
    temperature: 0.2,
    maxTokens: 500,
    logprobs: true,
    topLogprobs: 2,
  );

  // Print the byte representation of the first token's log-prob entry.
  print(chatCompletion.choices.first.logprobs?.content.first.bytes);
}

View File

@ -1,3 +1,4 @@
import 'sub_models/log_probs/log_probs.dart';
import 'sub_models/message.dart'; import 'sub_models/message.dart';
/// {@template openai_chat_completion_choice} /// {@template openai_chat_completion_choice}
@ -15,6 +16,9 @@ final class OpenAIChatCompletionChoiceModel {
/// The [finishReason] of the choice. /// The [finishReason] of the choice.
final String? finishReason; final String? finishReason;
/// The log probability of the choice.
final OpenAIChatCompletionChoiceLogProbsModel? logprobs;
/// Whether the choice has a finish reason. /// Whether the choice has a finish reason.
bool get haveFinishReason => finishReason != null; bool get haveFinishReason => finishReason != null;
@ -28,6 +32,7 @@ final class OpenAIChatCompletionChoiceModel {
required this.index, required this.index,
required this.message, required this.message,
required this.finishReason, required this.finishReason,
required this.logprobs,
}); });
/// This is used to convert a [Map<String, dynamic>] object to a [OpenAIChatCompletionChoiceModel] object. /// This is used to convert a [Map<String, dynamic>] object to a [OpenAIChatCompletionChoiceModel] object.
@ -39,6 +44,9 @@ final class OpenAIChatCompletionChoiceModel {
: int.tryParse(json['index'].toString()) ?? json['index'], : int.tryParse(json['index'].toString()) ?? json['index'],
message: OpenAIChatCompletionChoiceMessageModel.fromMap(json['message']), message: OpenAIChatCompletionChoiceMessageModel.fromMap(json['message']),
finishReason: json['finish_reason'], finishReason: json['finish_reason'],
logprobs: json['logprobs'] != null
? OpenAIChatCompletionChoiceLogProbsModel.fromMap(json['logprobs'])
: null,
); );
} }
@ -48,6 +56,7 @@ final class OpenAIChatCompletionChoiceModel {
"index": index, "index": index,
"message": message.toMap(), "message": message.toMap(),
"finish_reason": finishReason, "finish_reason": finishReason,
"logprobs": logprobs?.toMap(),
}; };
} }

View File

@ -0,0 +1,31 @@
// ignore_for_file: public_member_api_docs, sort_constructors_first
import 'sub_models/content.dart';
/// Container for the log-probability information attached to a
/// chat-completion choice.
class OpenAIChatCompletionChoiceLogProbsModel {
  /// The per-token log-probability entries for the choice's content.
  final List<OpenAIChatCompletionChoiceLogProbsContentModel> content;

  OpenAIChatCompletionChoiceLogProbsModel({
    required this.content,
  });

  /// Builds an instance from the decoded API [json] map.
  ///
  /// A missing or `null` `"content"` field yields an empty list.
  factory OpenAIChatCompletionChoiceLogProbsModel.fromMap(
    Map<String, dynamic> json,
  ) {
    final rawContent = json["content"];

    final entries = rawContent == null
        ? <OpenAIChatCompletionChoiceLogProbsContentModel>[]
        : <OpenAIChatCompletionChoiceLogProbsContentModel>[
            for (final item in rawContent)
              OpenAIChatCompletionChoiceLogProbsContentModel.fromMap(item),
          ];

    return OpenAIChatCompletionChoiceLogProbsModel(content: entries);
  }

  /// Serializes this model back into a JSON-compatible map.
  Map<String, dynamic> toMap() {
    return {
      "content": [for (final x in content) x.toMap()],
    };
  }
}

View File

@ -0,0 +1,41 @@
import 'top_prob.dart';
/// A single token's log-probability entry within a choice's content.
class OpenAIChatCompletionChoiceLogProbsContentModel {
  /// The token this entry describes.
  final String? token;

  /// The log probability of [token].
  final double? logprob;

  /// The byte representation of [token], when provided by the API.
  final List<int>? bytes;

  /// The most likely alternative tokens at this position, present only
  /// when `top_logprobs` was requested.
  final List<OpenAIChatCompletionChoiceTopLogProbsContentModel>? topLogprobs;

  OpenAIChatCompletionChoiceLogProbsContentModel({
    this.token,
    this.logprob,
    this.bytes,
    this.topLogprobs,
  });

  /// Builds an instance from the decoded API [map].
  ///
  /// All fields are nullable in the API response; absent or `null`
  /// values are preserved as `null` instead of throwing (the previous
  /// implementation crashed when `bytes` or `top_logprobs` was `null`).
  factory OpenAIChatCompletionChoiceLogProbsContentModel.fromMap(
    Map<String, dynamic> map,
  ) {
    return OpenAIChatCompletionChoiceLogProbsContentModel(
      token: map['token'],
      // JSON numbers may decode as int (e.g. a logprob of 0); normalize
      // to double so the assignment never throws.
      logprob: (map['logprob'] as num?)?.toDouble(),
      bytes: map['bytes'] != null ? List<int>.from(map['bytes']) : null,
      topLogprobs: map['top_logprobs'] != null
          ? List<OpenAIChatCompletionChoiceTopLogProbsContentModel>.from(
              map['top_logprobs'].map(
                (x) =>
                    OpenAIChatCompletionChoiceTopLogProbsContentModel.fromMap(
                  x,
                ),
              ),
            )
          : null,
    );
  }

  /// Serializes this model back into a JSON-compatible map.
  Map<String, dynamic> toMap() {
    return {
      'token': token,
      'logprob': logprob,
      'bytes': bytes,
      // Previously omitted, which lost data on fromMap/toMap round trips.
      'top_logprobs': topLogprobs?.map((x) => x.toMap()).toList(),
    };
  }
}

View File

@ -0,0 +1,28 @@
import 'content.dart';
/// One of the most likely alternative tokens at a given position,
/// returned when `top_logprobs` is requested.
class OpenAIChatCompletionChoiceTopLogProbsContentModel
    extends OpenAIChatCompletionChoiceLogProbsContentModel {
  OpenAIChatCompletionChoiceTopLogProbsContentModel({
    super.token,
    super.logprob,
    super.bytes,
  });

  /// Builds an instance from the decoded API [map].
  ///
  /// `bytes` may be `null` in the response; it is preserved as `null`
  /// instead of throwing (the previous implementation crashed on it).
  factory OpenAIChatCompletionChoiceTopLogProbsContentModel.fromMap(
    Map<String, dynamic> map,
  ) {
    return OpenAIChatCompletionChoiceTopLogProbsContentModel(
      token: map['token'],
      // JSON numbers may decode as int; normalize to double.
      logprob: (map['logprob'] as num?)?.toDouble(),
      bytes: map['bytes'] != null ? List<int>.from(map['bytes']) : null,
    );
  }

  /// Serializes this model back into a JSON-compatible map.
  ///
  /// Top-logprob entries never carry nested alternatives, so only the
  /// three scalar fields are emitted.
  @override
  Map<String, dynamic> toMap() {
    return {
      'token': token,
      'logprob': logprob,
      'bytes': bytes,
    };
  }
}

View File

@ -82,6 +82,8 @@ interface class OpenAIChat implements OpenAIChatBase {
String? user, String? user,
Map<String, String>? responseFormat, Map<String, String>? responseFormat,
int? seed, int? seed,
bool? logprobs,
int? topLogprobs,
http.Client? client, http.Client? client,
}) async { }) async {
return await OpenAINetworkingClient.post( return await OpenAINetworkingClient.post(
@ -103,6 +105,8 @@ interface class OpenAIChat implements OpenAIChatBase {
if (user != null) "user": user, if (user != null) "user": user,
if (seed != null) "seed": seed, if (seed != null) "seed": seed,
if (responseFormat != null) "response_format": responseFormat, if (responseFormat != null) "response_format": responseFormat,
if (logprobs != null) "logprobs": logprobs,
if (topLogprobs != null) "top_logprobs": topLogprobs,
}, },
onSuccess: (Map<String, dynamic> response) { onSuccess: (Map<String, dynamic> response) {
return OpenAIChatCompletionModel.fromMap(response); return OpenAIChatCompletionModel.fromMap(response);