Skip to content

Commit

Permalink
Merge pull request #183 from stanford-oval/dev-google-search
Browse files Browse the repository at this point in the history
[New RM] Add `GoogleSearch`
  • Loading branch information
shaoyijia authored Sep 25, 2024
2 parents 7bfa67b + f1ea4f7 commit 3328c97
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ runner = STORMWikiRunner(engine_args, lm_configs, rm)

Currently, our package support:
- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, and `TavilySearchRM` as retrieval module components
- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch` as retrieval module components

:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!**

Expand Down
2 changes: 1 addition & 1 deletion knowledge_storm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
STORMWikiRunner,
)

__version__ = "0.2.7"
__version__ = "0.2.8"
124 changes: 124 additions & 0 deletions knowledge_storm/rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,127 @@ def forward(
print(f"Error occurs when searching query {query}: {e}")

return collected_results


class GoogleSearch(dspy.Retrieve):
def __init__(
self,
google_search_api_key=None,
google_cse_id=None,
k=3,
is_valid_source: Callable = None,
min_char_count: int = 150,
snippet_chunk_size: int = 1000,
webpage_helper_max_threads=10,
):
"""
Params:
google_search_api_key: Google API key. Check out https://developers.google.com/custom-search/v1/overview
"API key" section
google_cse_id: Custom search engine ID. Check out https://developers.google.com/custom-search/v1/overview
"Search engine ID" section
k: Number of top results to retrieve.
is_valid_source: Optional function to filter valid sources.
min_char_count: Minimum character count for the article to be considered valid.
snippet_chunk_size: Maximum character count for each snippet.
webpage_helper_max_threads: Maximum number of threads to use for webpage helper.
"""
super().__init__(k=k)
try:
from googleapiclient.discovery import build
except ImportError as err:
raise ImportError(
"GoogleSearch requires `pip install google-api-python-client`."
) from err
if not google_search_api_key and not os.environ.get("GOOGLE_SEARCH_API_KEY"):
raise RuntimeError(
"You must supply google_search_api_key or set the GOOGLE_SEARCH_API_KEY environment variable"
)
if not google_cse_id and not os.environ.get("GOOGLE_CSE_ID"):
raise RuntimeError(
"You must supply google_cse_id or set the GOOGLE_CSE_ID environment variable"
)

self.google_search_api_key = (
google_search_api_key or os.environ["GOOGLE_SEARCH_API_KEY"]
)
self.google_cse_id = google_cse_id or os.environ["GOOGLE_CSE_ID"]

if is_valid_source:
self.is_valid_source = is_valid_source
else:
self.is_valid_source = lambda x: True

self.service = build(
"customsearch", "v1", developerKey=self.google_search_api_key
)
self.webpage_helper = WebPageHelper(
min_char_count=min_char_count,
snippet_chunk_size=snippet_chunk_size,
max_thread_num=webpage_helper_max_threads,
)
self.usage = 0

def get_usage_and_reset(self):
usage = self.usage
self.usage = 0
return {"GoogleSearch": usage}

def forward(
self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = []
):
"""Search using Google Custom Search API for self.k top results for query or queries.
Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
exclude_urls (List[str]): A list of URLs to exclude from the search results.
Returns:
A list of dicts, each dict has keys: 'title', 'url', 'snippet', 'description'.
"""
queries = (
[query_or_queries]
if isinstance(query_or_queries, str)
else query_or_queries
)
self.usage += len(queries)

url_to_results = {}

for query in queries:
try:
response = (
self.service.cse()
.list(
q=query,
cx=self.google_cse_id,
num=self.k,
)
.execute()
)

for item in response.get("items", []):
if (
self.is_valid_source(item["link"])
and item["link"] not in exclude_urls
):
url_to_results[item["link"]] = {
"title": item["title"],
"url": item["link"],
# "snippet": item.get("snippet", ""), # Google search snippet is very short.
"description": item.get("snippet", ""),
}

except Exception as e:
logging.error(f"Error occurred while searching query {query}: {e}")

valid_url_to_snippets = self.webpage_helper.urls_to_snippets(
list(url_to_results.keys())
)
collected_results = []
for url in valid_url_to_snippets:
r = url_to_results[url]
r["snippets"] = valid_url_to_snippets[url]["snippets"]
collected_results.append(r)

return collected_results
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

setup(
name="knowledge-storm",
version="0.2.7",
version="0.2.8",
author="Yijia Shao, Yucheng Jiang",
author_email="[email protected], [email protected]",
description="STORM: A language model-powered knowledge curation engine.",
Expand Down

0 comments on commit 3328c97

Please sign in to comment.