-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathyandex_gpt.py
348 lines (311 loc) · 13.2 KB
/
yandex_gpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import asyncio
from typing import (
    Any,
    Dict,
    List,
    Optional,
    TypedDict,
    Union,
)

import aiohttp
import requests

from .config_manager import YandexGPTConfigManagerBase
class YandexGPTMessage(TypedDict):
    """
    A single chat message as placed in the ``messages`` list of a completion payload.

    Both keys are required.
    """

    # Author of the message (the value is passed to the API unchanged).
    role: str
    # Message content.
    text: str
class YandexGPTBase:
    """
    Low-level helpers for talking to the Yandex GPT completion API.

    All methods are static and stateless: the caller supplies the request headers,
    payload and endpoint URLs on every call.

    Methods
    -------
    send_async_completion_request(headers, payload, completion_url) -> str
        Starts an asynchronous completion operation and returns its ID.
    poll_async_completion(operation_id, headers, timeout, poll_url, poll_interval) -> Dict[str, Any]
        Polls an asynchronous completion operation until it is done or the timeout expires.
    send_sync_completion_request(headers, payload, completion_url, timeout) -> Dict[str, Any]
        Sends a blocking (synchronous) request to the completion API.
    """

    @staticmethod
    async def send_async_completion_request(
        headers: Dict[str, str],
        payload: Dict[str, Any],
        completion_url: str = "https://llm.api.cloud.yandex.net/foundationModels/v1/completionAsync"
    ) -> str:
        """
        Sends an asynchronous request to the Yandex GPT completion API.

        Parameters
        ----------
        headers : Dict[str, str]
            Dictionary containing the authorization credentials, content type, and
            x-folder-id (YandexCloud catalog ID).
        payload : Dict[str, Any]
            Dictionary with the model URI, completion options, and messages.
        completion_url : str
            URL of the asynchronous completion API.

        Returns
        -------
        str
            ID of the completion operation (the ``id`` field of the JSON response),
            to be passed to ``poll_async_completion``.

        Raises
        ------
        Exception
            If the server answers with a status other than 200. The message contains
            the status code and the response body.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(completion_url, headers=headers, json=payload) as resp:
                if resp.status != 200:
                    # Keep the body so callers can see any error details the server returned.
                    body = await resp.text()
                    raise Exception(
                        f"Failed to send async request, status code: {resp.status}, response: {body}"
                    )
                data = await resp.json()
                return data['id']

    @staticmethod
    async def poll_async_completion(
        operation_id: str,
        headers: Dict[str, str],
        timeout: int = 5,
        poll_url: str = 'https://llm.api.cloud.yandex.net/operations/',
        poll_interval: float = 1.0
    ) -> Dict[str, Any]:
        """
        Polls the status of an asynchronous completion operation until it completes or times out.

        The operation is always polled at least once, and a final poll is made when
        the deadline is reached. An operation that finishes during the last interval
        is therefore still returned instead of being reported as a timeout.

        Parameters
        ----------
        operation_id : str
            ID of the completion operation to poll.
        headers : Dict[str, str]
            Dictionary containing the authorization credentials.
        timeout : int
            Time in seconds after which the operation is considered timed out.
        poll_url : str
            Base poll URL; the operation ID is appended to it.
        poll_interval : float
            Seconds to wait between polls (should be positive).

        Returns
        -------
        Dict[str, Any]
            The operation JSON once its ``done`` field is truthy.

        Raises
        ------
        TimeoutError
            If the operation is still not done when ``timeout`` seconds have elapsed.
        Exception
            If a poll request answers with a status other than 200.
        """
        loop = asyncio.get_running_loop()
        end_time = loop.time() + timeout
        async with aiohttp.ClientSession() as session:
            while True:
                async with session.get(f"{poll_url}{operation_id}", headers=headers) as resp:
                    if resp.status != 200:
                        body = await resp.text()
                        raise Exception(
                            f"Failed to poll operation status, status code: {resp.status}, response: {body}"
                        )
                    data = await resp.json()
                if data.get('done', False):
                    return data
                remaining = end_time - loop.time()
                if remaining <= 0:
                    raise TimeoutError(f"Operation timed out after {timeout} seconds")
                # Never sleep past the deadline, so one last poll happens right at it.
                await asyncio.sleep(min(poll_interval, remaining))

    @staticmethod
    def send_sync_completion_request(
        headers: Dict[str, str],
        payload: Dict[str, Any],
        completion_url: str = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion",
        timeout: Optional[float] = None
    ) -> Dict[str, Any]:
        """
        Sends a synchronous request to the Yandex GPT completion API.

        Parameters
        ----------
        headers : Dict[str, str]
            Dictionary containing the authorization credentials, content type, and
            x-folder-id (YandexCloud catalog ID).
        payload : Dict[str, Any]
            Dictionary with the model URI, completion options, and messages.
        completion_url : str
            URL of the completion API.
        timeout : Optional[float]
            Network timeout in seconds forwarded to ``requests.post``. ``None`` (the
            default) waits indefinitely.

        Returns
        -------
        Dict[str, Any]
            Parsed JSON completion result.

        Raises
        ------
        Exception
            If the server answers with a status other than 200. The message contains
            the status code and the response body.
        requests.exceptions.Timeout
            If ``timeout`` is set and the server does not respond in time.
        """
        response = requests.post(completion_url, headers=headers, json=payload, timeout=timeout)
        if response.status_code != 200:
            raise Exception(
                f"Failed to send sync request, status code: {response.status_code}, response: {response.text}"
            )
        return response.json()
class YandexGPT(YandexGPTBase):
    """
    High-level Yandex GPT client that builds request headers and payloads from a configuration.

    The configuration can be a config-manager object or a plain dictionary with the
    same field names.

    Methods
    -------
    get_async_completion(messages, temperature, max_tokens, stream, completion_url, timeout) -> str
        Sends a completion request through the asynchronous API, polls for the result,
        and returns the text of the first alternative.
    get_sync_completion(messages, temperature, max_tokens, stream, completion_url) -> str
        Sends a completion request through the synchronous API and returns the text of
        the first alternative.
    """

    def __init__(
        self,
        config_manager: Union[YandexGPTConfigManagerBase, Dict[str, Any]]
    ) -> None:
        """
        Initializes the YandexGPT class with a configuration manager.

        Parameters
        ----------
        config_manager : Union[YandexGPTConfigManagerBase, Dict[str, Any]]
            Config manager (fields read as attributes) or a dictionary (fields read as keys)
            providing:
            1) completion_request_model_type_uri_field
               ("gpt://{catalog_id}/{model_type}/latest")
            2) completion_request_catalog_id_field (the YandexCloud catalog ID)
            3) completion_request_authorization_field ("Bearer {iam_token}" or "Api-Key {api_key}")
        """
        self.config_manager = config_manager

    def _get_config_field(self, field_name: str) -> Any:
        """
        Reads one configuration field from either a config manager or a plain dictionary.

        Parameters
        ----------
        field_name : str
            Name of the field, e.g. ``completion_request_authorization_field``.

        Returns
        -------
        Any
            The configured value.

        Raises
        ------
        KeyError
            If ``config_manager`` is a dictionary without this key.
        AttributeError
            If ``config_manager`` is an object without this attribute.
        """
        config = self.config_manager
        # Plain dicts (accepted by __init__) have no attribute access. The hasattr check
        # keeps attribute lookup for any dict subclass that defines the field itself.
        if isinstance(config, dict) and not hasattr(config, field_name):
            return config[field_name]
        return getattr(config, field_name)

    def _create_completion_request_headers(self) -> Dict[str, str]:
        """
        Creates headers for sending a completion request to the API.

        Returns
        -------
        Dict[str, str]
            Dictionary with authorization credentials, content type, and x-folder-id
            (YandexCloud catalog ID).
        """
        return {
            "Content-Type": "application/json",
            "Authorization": self._get_config_field("completion_request_authorization_field"),
            "x-folder-id": self._get_config_field("completion_request_catalog_id_field"),
        }

    def _create_completion_request_payload(
        self,
        messages: Union[List[YandexGPTMessage], List[Dict[str, str]]],
        temperature: float = 0.6,
        max_tokens: int = 1000,
        stream: bool = False
    ) -> Dict[str, Any]:
        """
        Creates the payload for sending a completion request.

        Parameters
        ----------
        messages : Union[List[YandexGPTMessage], List[Dict[str, str]]]
            List of messages with roles and texts; placed in the payload unchanged.
        temperature : float
            Controls the randomness of the completion, from 0 (deterministic) to 1 (random).
        max_tokens : int
            Maximum number of tokens to generate.
        stream : bool
            Stream option for the API, currently not supported in this implementation.

        Returns
        -------
        Dict[str, Any]
            Dictionary containing the model URI, completion options, and messages.
        """
        return {
            "modelUri": self._get_config_field("completion_request_model_type_uri_field"),
            "completionOptions": {
                "stream": stream,
                "temperature": temperature,
                "maxTokens": max_tokens,
            },
            "messages": messages,
        }

    async def get_async_completion(
        self,
        messages: Union[List[YandexGPTMessage], List[Dict[str, str]]],
        temperature: float = 0.6,
        max_tokens: int = 1000,
        stream: bool = False,
        completion_url: str = "https://llm.api.cloud.yandex.net/foundationModels/v1/completionAsync",
        timeout: int = 5
    ) -> str:
        """
        Sends an asynchronous completion request to the Yandex GPT API and polls for the result.

        Parameters
        ----------
        messages : Union[List[YandexGPTMessage], List[Dict[str, str]]]
            List of messages with roles and texts.
        temperature : float
            Randomness of the completion, from 0 (deterministic) to 1 (most random).
        max_tokens : int
            Maximum number of tokens to generate.
        stream : bool
            Indicates whether streaming is enabled; currently not supported in this implementation.
        completion_url : str
            URL to the Yandex GPT asynchronous completion API.
        timeout : int
            Time in seconds after which the polling is considered timed out.

        Returns
        -------
        str
            Text of the first alternative in ``response.alternatives``.

        Raises
        ------
        TimeoutError
            If the operation does not finish within ``timeout`` seconds.
        Exception
            If a request fails or the finished operation contains an ``error`` field.
        """
        headers: Dict[str, str] = self._create_completion_request_headers()
        payload: Dict[str, Any] = self._create_completion_request_payload(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream
        )
        operation_id: str = await self.send_async_completion_request(
            headers=headers,
            payload=payload,
            completion_url=completion_url
        )
        completion_response: Dict[str, Any] = await self.poll_async_completion(
            operation_id=operation_id,
            headers=headers,
            timeout=timeout
        )
        error = completion_response.get('error')
        if error:
            raise Exception(f"Failed to get completion: {error}")
        return completion_response['response']['alternatives'][0]['message']['text']

    def get_sync_completion(
        self,
        messages: Union[List[YandexGPTMessage], List[Dict[str, str]]],
        temperature: float = 0.6,
        max_tokens: int = 1000,
        stream: bool = False,
        completion_url: str = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion",
    ) -> str:
        """
        Sends a synchronous completion request to the Yandex GPT API and returns the result.

        Parameters
        ----------
        messages : Union[List[YandexGPTMessage], List[Dict[str, str]]]
            List of messages with roles and texts.
        temperature : float
            Randomness of the completion, from 0 (deterministic) to 1 (most random).
        max_tokens : int
            Maximum number of tokens to generate.
        stream : bool
            Indicates whether streaming is enabled; currently not supported in this implementation.
        completion_url : str
            URL to the Yandex GPT synchronous completion API.

        Returns
        -------
        str
            Text of the first alternative in ``result.alternatives``.

        Raises
        ------
        Exception
            If the request fails or the response contains an ``error`` field.
        """
        headers: Dict[str, str] = self._create_completion_request_headers()
        payload: Dict[str, Any] = self._create_completion_request_payload(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=stream
        )
        completion_response: Dict[str, Any] = self.send_sync_completion_request(
            headers=headers,
            payload=payload,
            completion_url=completion_url
        )
        error = completion_response.get('error')
        if error:
            raise Exception(f"Failed to get completion: {error}")
        return completion_response['result']['alternatives'][0]['message']['text']