-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgemini_trained.py
116 lines (95 loc) · 3.29 KB
/
gemini_trained.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Install the required package once from the command line with pip:
$ pip install google-generativeai
"""
import os
import pickle

import google
import google.generativeai as genai

from gemini_untrained import get_response, get_reply, get_ques
from recipeDB import recipeInfo
# Prefer an API key supplied through the environment so the secret does not
# have to live in source control. An empty GOOGLE_API_KEY is treated as unset.
# NOTE(security): the literal fallback below is a hard-coded credential; if this
# file has been shared or pushed anywhere, treat the key as compromised —
# rotate it and delete the fallback once every deployment sets GOOGLE_API_KEY.
genai.configure(
    api_key=os.environ.get("GOOGLE_API_KEY") or "AIzaSyDlPcLWQdCWkaVol1kEncwAk8rx66rplmI"
)

# Set up the model: sampling/length parameters handed to GenerativeModel.
# NOTE(review): with top_k=1 only the single most likely token is ever a
# candidate, so temperature/top_p likely have little effect — confirm intent.
generation_config = {
    "temperature": 0.9,
    "top_p": 1,
    "top_k": 1,
    "max_output_tokens": 2048,
}
def setHistory(fileName):
    """Load the pickled chat history stored at *fileName*.

    The unpickled object is bound to the module-level name ``history``;
    nothing is returned.

    NOTE(security): ``pickle.load`` can execute arbitrary code while
    deserializing, so only call this on files this application wrote itself.
    """
    global history
    with open(fileName, mode='rb') as stream:
        restored = pickle.load(stream)
    history = restored
# Every harm category gets the same "BLOCK_ONLY_HIGH" threshold.
safety_settings = [
    {"category": category, "threshold": "BLOCK_ONLY_HIGH"}
    for category in (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
]

model = genai.GenerativeModel(
    model_name="gemini-1.0-pro",
    generation_config=generation_config,
    safety_settings=safety_settings,
)

# Restore persisted state. Both paths are relative, so they resolve against
# the current working directory at import time.
# NOTE(security): unpickling runs arbitrary code — these files must be trusted.
setHistory('state.pickle')  # binds the module-level ``history``
with open('idset.pickle', 'rb') as f:
    # Known recipe IDs; verification tests membership with str(id), so they
    # are presumably stored as strings — confirm.
    idset = pickle.load(f)

# Resume the chat from the saved history; sendQuery sends every query here.
convo = model.start_chat(history=history)
# Running transcript of successful exchanges, appended to by sendQuery.
context = []
def sendQuery(query: str):
    """Send *query* to the shared chat session ``convo`` and return the reply text.

    A successful, non-empty reply is also recorded in the module-level
    ``context`` list as
    ``"user input: <query> and its expected response: <reply>"``.

    Failures are reported as return values rather than raised, so callers
    always receive a ``str``:

    * ``"None type query returned"`` if the reply text is empty or ``None``;
    * ``"Something with google's api went wrong. Please try again"`` on a
      Google API ``InternalServerError``;
    * ``"There was an error while fetching the query"`` on any other
      ``Exception``.
    """
    try:
        convo.send_message(query)
        response = convo.last.text
    # NOTE(review): ``google.api_core`` is only reachable through the bare
    # ``import google`` if another import (presumably google.generativeai)
    # has already loaded that submodule.
    except google.api_core.exceptions.InternalServerError:
        return "Something with google's api went wrong. Please try again"
    except Exception:
        # ``Exception`` rather than a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        return "There was an error while fetching the query"

    if not response:
        return "None type query returned"
    context.append(f"user input: {query} and its expected response: {response}")
    return response
def verification(last_text: str, query: str):
    """Decode the trained model's reply and dispatch it to the right responder.

    ``last_text`` is expected to hold comma- and/or whitespace-separated
    integers: recipe IDs, or one of the sentinel codes -5 / -10. Presumably the
    trained chat history teaches the model this format — TODO confirm.

    Dispatch rules:

    * Not a list of integers (ordinary chat text, an error string from
      ``sendQuery``, or a blank reply) -> ``[get_reply(str(context), query)]``.
    * Any number that is neither in ``idset`` nor a sentinel -> the string
      ``"AI just hallucinated"``.
    * First number is -5 -> ``[get_ques(str(context), query)]``.
    * First number is -10 -> ``[get_reply(str(context), query)]``.
    * Otherwise every ID is looked up with ``recipeInfo`` and the result is
      ``[img_url, get_response(dish_list, context, query)]``, where ``img_url``
      comes from the LAST recipe in the list.

    A ``ValueError`` raised anywhere while handling a numeric reply (including
    from the helpers called above) also falls back to ``get_reply``.

    Returns:
        A one- or two-element list, or the string ``"AI just hallucinated"``.
    """
    print(last_text)
    try:
        # split() with no argument splits on any whitespace run (spaces, tabs,
        # newlines) and drops empty tokens, so "1, 2\n3" parses as [1, 2, 3].
        food_list: list[int] = [int(tok) for tok in last_text.replace(",", " ").split()]
        if not food_list:
            # Nothing to index: treat a blank reply like chat text instead of
            # raising IndexError on food_list[0].
            return [get_reply(str(context), query)]

        # -5 and -10 are sentinel codes, not recipe IDs. idset is queried with
        # str(food), so it presumably stores IDs as strings.
        if any(str(food) not in idset and food not in (-5, -10) for food in food_list):
            return "AI just hallucinated"

        if food_list[0] == -5:
            return [get_ques(str(context), query)]
        if food_list[0] == -10:
            return [get_reply(str(context), query)]

        # NOTE(review): a -5/-10 appearing after a real ID is still passed to
        # recipeInfo below — confirm whether the model can emit such mixes.
        dish_list = []
        for food_index in food_list:
            food_data = recipeInfo(food_index)
            img_url = food_data["img_url"]  # overwritten each pass; last one wins
            dish_list.append(
                str({
                    "Recipe Title": food_data["Recipe_title"],
                    "ingredients": str(food_data["ingredients"]),
                    "instructions": str(food_data["instructions"]),
                })
            )
        return [img_url, get_response(dish_list, context, query)]
    except ValueError:
        # Chiefly int() rejecting non-numeric text: answer conversationally.
        return [get_reply(str(context), query)]
def do_it_all(query: str) -> "str | list":
    """Run *query* through the trained model and post-process the answer.

    The query is sent with ``sendQuery`` and the raw reply text is handed to
    ``verification`` together with the original query.

    Returns:
        Whatever ``verification`` produces: a list (``[reply]`` or
        ``[img_url, reply]``) or the string ``"AI just hallucinated"``. The
        placeholder string ``"Return"`` is returned if ``verification`` ever
        yields ``None``.
    """
    ret_val = verification(sendQuery(query), query)
    if ret_val is None:
        return "Return"
    return ret_val