-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy path__init__.py
434 lines (354 loc) · 12.9 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
"""
VoxelGPT plugin.
| Copyright 2017-2024, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
import json
import os
import traceback
from bson import json_util
import fiftyone as fo
from fiftyone.core.utils import add_sys_path
import fiftyone.operators as foo
import fiftyone.operators.types as types
class AskVoxelGPT(foo.Operator):
    """Operator that answers a natural-language query with VoxelGPT.

    Responses are rendered through the App's ``show_output`` operator, and
    any view produced by VoxelGPT is loaded via ``set_view``.
    """

    @property
    def config(self):
        return foo.OperatorConfig(
            name="ask_voxelgpt",
            label="Ask VoxelGPT",
            light_icon="/assets/icon-light.svg",
            dark_icon="/assets/icon-dark.svg",
            execute_as_generator=True,
        )

    def resolve_input(self, ctx):
        inputs = types.Object()
        inputs.str(
            "query",
            label="query",
            required=True,
            description="What would you like to view?",
        )
        return types.Property(inputs)

    def execute(self, ctx):
        """Streams VoxelGPT's responses to ``ctx.params["query"]``.

        Yields the triggers built by :meth:`view` and :meth:`message` for
        each generator response. If anything raises, yields the trigger
        built by :meth:`error` instead of propagating the exception.
        """
        query = ctx.params["query"]

        # Every message shown so far; the whole list is re-rendered on each
        # update because ``show_output`` replaces the previous output
        messages = []

        inject_voxelgpt_secrets(ctx)

        try:
            with add_sys_path(os.path.dirname(os.path.abspath(__file__))):
                # pylint: disable=no-name-in-module
                from voxelgpt import ask_voxelgpt_generator

            # Text accumulated so far for the message currently streaming
            streaming_message = None
            for response in ask_voxelgpt_generator(
                query,
                ctx=ctx,
                dialect="string",
                allow_streaming=True,
            ):
                # `kind` rather than `type` to avoid shadowing the builtin
                kind = response["type"]
                data = response["data"]

                if kind == "view":
                    # NOTE(review): view() returns None when the view is
                    # unchanged and that None is yielded as-is; confirm the
                    # operator framework ignores None values
                    yield self.view(ctx, data["view"])
                elif kind == "message":
                    kwargs = {}
                    if data["overwrite"]:
                        kwargs["overwrite_last"] = True

                    yield self.message(
                        ctx, data["message"], messages, **kwargs
                    )
                elif kind == "streaming":
                    kwargs = {}
                    if streaming_message is None:
                        # First chunk starts a new message
                        streaming_message = data["content"]
                    else:
                        # Later chunks replace that message with the
                        # accumulated text
                        streaming_message += data["content"]
                        kwargs["overwrite_last"] = True

                    yield self.message(
                        ctx, streaming_message, messages, **kwargs
                    )

                    if data["last"]:
                        streaming_message = None
        except Exception as e:
            yield self.error(ctx, e)

    def view(self, ctx, view):
        """Returns a ``set_view`` trigger for ``view``.

        Returns None if ``view`` equals ``ctx.view``.
        """
        if view != ctx.view:
            return ctx.trigger(
                "set_view",
                params=dict(view=serialize_view(view)),
            )

    def message(self, ctx, message, messages, overwrite_last=False):
        """Records ``message`` and re-renders the whole conversation.

        Args:
            ctx: the execution context
            message: the message text
            messages: the list of messages shown so far; mutated in place
            overwrite_last (False): whether ``message`` replaces the most
                recent message instead of being appended. If ``messages``
                is empty, ``message`` is appended

        Returns:
            the ``show_output`` trigger
        """
        # Guard on `messages` so an overwrite request before any message has
        # been shown appends instead of raising IndexError
        if overwrite_last and messages:
            messages[-1] = message
        else:
            messages.append(message)

        outputs = types.Object()
        outputs.str("query", label="You")
        results = dict(query=ctx.params["query"])
        for i, msg in enumerate(messages, 1):
            field = "message" + str(i)
            outputs.str(field, label="VoxelGPT")
            results[field] = msg

        return ctx.trigger(
            "show_output",
            params=dict(
                outputs=types.Property(outputs).to_json(),
                results=results,
            ),
        )

    def error(self, ctx, exception):
        """Returns a ``show_output`` trigger that displays ``exception``.

        Must be called while ``exception`` is being handled, so that
        ``traceback.format_exc()`` returns its traceback.
        """
        message = str(exception)
        trace = traceback.format_exc()
        view = types.Error(label=message, description=trace)
        outputs = types.Object()
        outputs.view("message", view)
        return ctx.trigger(
            "show_output",
            params=dict(outputs=types.Property(outputs).to_json()),
        )
class AskVoxelGPTPanel(foo.Operator):
    """Unlisted operator that backs the VoxelGPT panel.

    Every update is sent to this plugin's ``show_message`` operator.
    """

    @property
    def config(self):
        return foo.OperatorConfig(
            name="ask_voxelgpt_panel",
            label="Ask VoxelGPT Panel",
            execute_as_generator=True,
            unlisted=True,
        )

    def execute(self, ctx):
        """Answers ``ctx.params["query"]`` and then yields a "done" message.

        If anything raises, an error message is yielded before "done".
        """
        try:
            yield from self._ask(ctx)
        except Exception as e:
            yield self.error(ctx, e)

        # Deliberately not in a `finally` clause. If the consumer closes this
        # generator early, close() raises GeneratorExit here, and yielding in
        # response to it raises RuntimeError. On an early close, nobody is
        # listening for "done" anyway.
        yield self.done(ctx)

    def _ask(self, ctx):
        """Generator that runs the query and yields the panel triggers."""
        query = ctx.params["query"]
        history = ctx.params.get("history", [])

        # NOTE(review): the parsed sample collection (possibly a restored
        # previous view) is never passed to ask_voxelgpt_generator; confirm
        # whether follow-up questions are meant to use it
        chat_history, _sample_collection, orig_view = self._parse_history(
            ctx, history
        )

        inject_voxelgpt_secrets(ctx)

        with add_sys_path(os.path.dirname(os.path.abspath(__file__))):
            # pylint: disable=import-error,no-name-in-module
            import db
            from voxelgpt import ask_voxelgpt_generator

        # Log the user query; the id is echoed back with every message
        table = db.table(db.UserQueryTable)
        ctx.params["query_id"] = table.insert_query(query)

        # Text accumulated so far for the message currently streaming
        streaming_message = None
        for response in ask_voxelgpt_generator(
            query,
            ctx=ctx,
            chat_history=chat_history,
            dialect="markdown",
            allow_streaming=True,
        ):
            # `kind` rather than `type` to avoid shadowing the builtin
            kind = response["type"]
            data = response["data"]

            if kind == "view":
                if orig_view is not None:
                    message = (
                        "I'm remembering your previous view. Any "
                        "follow-up questions in this session will be "
                        "posed with respect to it"
                    )
                    yield self.message(ctx, message, orig_view=orig_view)

                yield self.view(ctx, data["view"])
            elif kind == "message":
                kwargs = {}
                if data["overwrite"]:
                    kwargs["overwrite_last"] = True

                kwargs["history"] = data["history"]
                yield self.message(ctx, data["message"], **kwargs)
            elif kind == "streaming":
                kwargs = {}
                if streaming_message is None:
                    streaming_message = data["content"]
                else:
                    streaming_message += data["content"]
                    kwargs["overwrite_last"] = True

                if data["last"]:
                    # The completed message is attached as `history`, which
                    # _parse_history() reads back on later queries
                    kwargs["history"] = streaming_message

                yield self.message(ctx, streaming_message, **kwargs)

                if data["last"]:
                    streaming_message = None
            elif kind == "warning":
                yield self.warning(ctx, data["message"])

    def view(self, ctx, view):
        """Returns a ``set_view`` trigger for ``view``.

        Returns None if ``view`` equals ``ctx.view``.
        """
        if view != ctx.view:
            return ctx.trigger(
                "set_view",
                params=dict(view=serialize_view(view)),
            )

    def message(self, ctx, message, **kwargs):
        """Shows ``message`` as markdown.

        Extra ``kwargs`` are forwarded in the message data.
        """
        return self.show_message(ctx, message, types.MarkdownView(), **kwargs)

    def warning(self, ctx, message):
        """Shows ``message`` as a warning."""
        view = types.Warning(label=message)
        return self.show_message(ctx, message, view)

    def error(self, ctx, exception):
        """Shows ``exception`` and its traceback.

        Must be called while ``exception`` is being handled, so that
        ``traceback.format_exc()`` returns its traceback.
        """
        message = str(exception)
        trace = traceback.format_exc()
        view = types.Error(label=message, description=trace)
        return self.show_message(ctx, message, view)

    def done(self, ctx):
        """Tells the panel that the current query has finished."""
        return ctx.trigger(
            f"{self.plugin_name}/show_message",
            params=dict(done=True),
        )

    def show_message(self, ctx, message, view_type, **kwargs):
        """Triggers this plugin's ``show_message`` operator.

        The message is rendered with ``view_type``, and ``kwargs`` are merged
        into the message data.
        """
        outputs = types.Object()
        outputs.str("message", view=view_type)
        return ctx.trigger(
            f"{self.plugin_name}/show_message",
            params=dict(
                query_id=ctx.params.get("query_id"),
                outputs=types.Property(outputs).to_json(),
                data=dict(message=message, **kwargs),
            ),
        )

    def _parse_history(self, ctx, history):
        """Extracts the chat history and any recorded view from ``history``.

        Args:
            ctx: the execution context
            history: a list of message dicts sent by the panel, or None.
                Items whose ``type`` is ``"outgoing"`` contribute their
                ``content``. All other items contribute ``data["history"]``
                and may carry a ``data["orig_view"]``

        Returns:
            a ``(chat_history, sample_collection, orig_view)`` tuple. If the
            most recently recorded ``orig_view`` targets the current dataset
            and can be rebuilt, ``sample_collection`` is that view and
            ``orig_view`` is None. Otherwise ``sample_collection`` is
            ``ctx.view`` and ``orig_view`` is a serialized snapshot of it
        """
        if history is None:
            history = []

        chat_history = []
        orig_view = None
        for item in history:
            if item["type"] == "outgoing":
                entry = item.get("content", None)
            else:
                # `or {}` also tolerates an explicit `"data": None`
                data = item.get("data") or {}
                entry = data.get("history", None)
                if data.get("orig_view") is not None:
                    # The most recently recorded view wins
                    orig_view = data["orig_view"]

            if entry:
                chat_history.append(entry)

        dataset_name = ctx.dataset.name

        # If we have an `orig_view` into the same dataset, start from it
        if orig_view is not None and orig_view.get("dataset") == dataset_name:
            try:
                view = deserialize_view(ctx.dataset, orig_view["stages"])
                return chat_history, view, None
            except Exception:
                # Best-effort: if the recorded view cannot be rebuilt, fall
                # back to the current view below
                pass

        orig_view = dict(
            dataset=dataset_name,
            stages=serialize_view(ctx.view),
        )
        return chat_history, ctx.view, orig_view
class OpenVoxelGPTPanel(foo.Operator):
    """Unlisted operator behind the "Open VoxelGPT" samples-grid button.

    Clicking the button opens the ``voxelgpt`` panel.
    """

    @property
    def config(self):
        return foo.OperatorConfig(
            name="open_voxelgpt_panel",
            label="Open VoxelGPT Panel",
            unlisted=True,
        )

    def resolve_placement(self, ctx):
        """Places a prompt-less button in the samples grid actions."""
        button = types.Button(
            label="Open VoxelGPT",
            icon="/assets/icon-dark.svg",
            prompt=False,
        )
        return types.Placement(types.Places.SAMPLES_GRID_ACTIONS, button)

    def execute(self, ctx):
        """Opens the VoxelGPT panel in a horizontal layout."""
        panel_params = {
            "name": "voxelgpt",
            "isActive": True,
            "layout": "horizontal",
        }
        ctx.trigger("open_panel", params=panel_params)
class OpenVoxelGPTPanelOnStartup(foo.Operator):
    """Unlisted operator that runs whenever a dataset is opened.

    It opens the VoxelGPT panel when the ``open_on_startup`` plugin setting
    is truthy.
    """

    @property
    def config(self):
        return foo.OperatorConfig(
            name="open_voxelgpt_panel_on_startup",
            label="Open VoxelGPT Panel",
            on_dataset_open=True,
            unlisted=True,
        )

    def execute(self, ctx):
        """Opens the panel if enabled for the current dataset.

        Does nothing if there is no current dataset.
        """
        if ctx.dataset is None:
            return

        should_open = get_plugin_setting(
            ctx.dataset, self.plugin_name, "open_on_startup", default=False
        )
        if not should_open:
            return

        panel_params = {
            "name": "voxelgpt",
            "isActive": True,
            "layout": "horizontal",
        }
        ctx.trigger("open_panel", params=panel_params)
class VoteForQuery(foo.Operator):
    """Unlisted operator that records an up- or downvote on a logged query."""

    @property
    def config(self):
        return foo.OperatorConfig(
            name="vote_for_query",
            label="Vote For Query",
            unlisted=True,
        )

    def resolve_input(self, ctx):
        inputs = types.Object()
        inputs.str(
            "query_id",
            label="query_id",
            required=True,
            description="User Query to Vote For",
        )
        inputs.enum(
            "vote",
            ["upvote", "downvote"],
            label="Vote",
            required=True,
        )
        return types.Property(inputs)

    def execute(self, ctx):
        """Applies ``ctx.params["vote"]`` to ``ctx.params["query_id"]``.

        Raises:
            ValueError: if the vote is neither ``"upvote"`` nor ``"downvote"``
        """
        params = ctx.params
        query_id, vote = params["query_id"], params["vote"]

        plugin_dir = os.path.dirname(os.path.abspath(__file__))
        with add_sys_path(plugin_dir):
            # pylint: disable=import-error,no-name-in-module
            import db

        queries = db.table(db.UserQueryTable)

        if vote not in ("upvote", "downvote"):
            raise ValueError(f"Invalid vote '{vote}'")

        # Dispatch to `upvote_query` or `downvote_query`
        record_vote = getattr(queries, vote + "_query")
        record_vote(query_id)
def get_plugin_setting(dataset, plugin_name, key, default=None):
    """Looks up a plugin setting, preferring the dataset's App config.

    ``dataset.app_config.plugins[plugin_name][key]`` is consulted first,
    then the global ``fo.app_config``. A value of None in either place is
    treated as unset.

    Args:
        dataset: the dataset whose ``app_config.plugins`` is checked first
        plugin_name: the name of the plugin whose settings are read
        key: the setting name
        default (None): the value returned when neither config sets ``key``

    Returns:
        the setting value, or ``default``
    """
    # Callables, so the global App config is only touched when the dataset
    # does not provide the setting
    plugin_sources = (
        lambda: dataset.app_config.plugins,
        lambda: fo.app_config.plugins,
    )
    for get_plugins in plugin_sources:
        value = get_plugins().get(plugin_name, {}).get(key, None)
        if value is not None:
            return value

    return default
def serialize_view(view):
    """Serializes ``view`` into plain JSON-compatible Python objects.

    The view's serialized stages are dumped with ``bson.json_util``, so
    BSON-specific values are encoded as extended JSON. The result is then
    re-parsed with the standard ``json`` module.
    """
    extended_json = json_util.dumps(view._serialize())
    return json.loads(extended_json)
def deserialize_view(dataset, stages):
    """Rebuilds a view of ``dataset`` from ``stages``.

    ``stages`` are expected in the form produced by ``serialize_view``. They
    are round-tripped through ``bson.json_util`` to restore BSON-specific
    values.
    """
    bson_stages = json_util.loads(json.dumps(stages))
    return fo.DatasetView._build(dataset, bson_stages)
# Names that inject_voxelgpt_secrets() copies from the operator's secrets
# into os.environ (in this order)
secrets = (
    # OpenAI / Azure OpenAI connection settings
    "OPENAI_API_KEY",
    "OPENAI_API_TYPE",
    "AZURE_OPENAI_GPT35_DEPLOYMENT_NAME",
    "AZURE_OPENAI_GPT4O_DEPLOYMENT_NAME",
    "AZURE_OPENAI_TEXT_EMBEDDING_3_LARGE_DEPLOYMENT_NAME",
    "AZURE_OPENAI_ENDPOINT",
    "AZURE_OPENAI_KEY",
) + (
    # VoxelGPT-specific settings
    "VOXELGPT_ALLOW_COMPUTATIONS",
    "VOXELGPT_APPROVAL_THRESHOLD",
)
def inject_voxelgpt_secrets(ctx):
    """Copies the plugin's secrets from ``ctx.secrets`` into ``os.environ``.

    Each name in the module-level ``secrets`` tuple is looked up in
    ``ctx.secrets``. Names that are missing, or whose value is falsy, leave
    the environment untouched.
    """
    for name in secrets:
        try:
            value = ctx.secrets[name]
        except KeyError:
            continue  # secret not configured

        if not value:
            continue

        os.environ[name] = value
def register(p):
    """Registers all of this plugin's operators with ``p``."""
    operator_classes = (
        AskVoxelGPT,
        AskVoxelGPTPanel,
        OpenVoxelGPTPanel,
        OpenVoxelGPTPanelOnStartup,
        VoteForQuery,
    )
    for operator_cls in operator_classes:
        p.register(operator_cls)