[Task 5] Research and Implementation of Multi-Label Classification Methods for GitHub Repositories #79

Open · wants to merge 1 commit into `main`
Binary file added Task5/README.assets/image-20240729224307842.png
63 changes: 63 additions & 0 deletions Task5/README.md
@@ -0,0 +1,63 @@
# README

## Overview

Task 5 builds a multi-label classification model for GitHub repositories: given a repository's metadata and code quality, it tags and classifies the repository to better support common open-source community tasks. Our observation is that both GitHub metadata (labels) and code can be treated as text, and large language models are naturally zero-shot capable on such language tasks, so our chosen angle is to build a complete LLM-driven workflow that classifies repositories automatically. On the technical side, we use langchain + ZhipuAI to build the automated classification pipeline: the project can crawl the metadata and code of any specified GitHub repository URL and, without any training or fine-tuning, analyze the repository in detail through a langchain workflow and classify it against several preset criteria.



This project consists of the following main components:

- `get_info.py`: functions that fetch repository information from the GitHub API.
- `main.py`: the main entry point; wires the other modules together and runs the project's core logic.
- `code_loader.py`: a function that clones a GitHub repository's code to a local folder.
- `config.py`: project configuration, such as the GitHub access token.
- `requirements.txt`: the Python dependencies required by the project.
- `llm_classification.py`: the logic that classifies repositories and analyzes code quality with a language model.

## Installation

To run this project locally, make sure Python >= 3.8 is installed on your system, then install the required dependencies:

```shell
pip install -r requirements.txt
```

## Usage

1. Fill in your GitHub access token and your LLM API key (this project uses ZhipuAI) in `config.py`.
2. Run `main.py` to start the project.
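Step 1 amounts to filling in the two constants in `config.py` (the values below are placeholders, not real credentials):

```python
# config.py — placeholder values
GITHUB_TOKEN = "<your GitHub personal access token>"
ZHIPUAI_API_KEY = "<your ZhipuAI API key>"
```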



## Classification criteria

This project uses the `analyze_repo_with_agent` function in the `llm_classification.py` module to analyze repository metadata and classify the repository along the following dimensions:
- Application domain
- Development stage
- Community activity
- Tech stack

This project uses the `analyze_code_quality` function in the `llm_classification.py` module to analyze code quality. It randomly samples a number of code files, scores each of them against a predefined rubric, and finally aggregates these scores through the langchain workflow into an overall score for the project. The rubric covers:

- Readability
- Maintainability
- Consistency
- Conciseness
- Robustness
- Modularity
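Concretely, the aggregation step averages each dimension across the sampled files. A minimal standalone sketch of the idea (the real code does this with numpy inside `analyze_code_quality`; `aggregate_scores` is our illustrative name):

```python
from statistics import mean

def aggregate_scores(per_file_scores):
    # per_file_scores: one dict per sampled file, mapping dimension -> 0-100 score
    dims = per_file_scores[0].keys()
    return {d: mean(s[d] for s in per_file_scores) for d in dims}

result = aggregate_scores([
    {"readability": 80, "robustness": 60},
    {"readability": 70, "robustness": 70},
])
print(result)  # readability averages to 75, robustness to 65
```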



### Case Study

![image-20240729224307842](README.assets/image-20240729224307842.png)

### Team

| Task 5 lead | Li Dian | Research, technology selection, basic project framework design |
| :---------: | ------ | -------------------------------------- |
| Member | Zhang Qiming | Langchain development, core project logic design |
| Member | Ye Hanhui | Functions for automatically scraping GitHub code and metadata |
| Member | Wu Xingjian | Langchain development, main business code design |

20 changes: 20 additions & 0 deletions Task5/code_loader.py
@@ -0,0 +1,20 @@
import subprocess
import os

def code_loader(repo_url):
    # Local folder that cloned repositories are cached in
    local_folder = 'cache'

    # Make sure the target folder exists
    if not os.path.exists(local_folder):
        os.makedirs(local_folder)

    # Clone via subprocess. Passing cwd= instead of os.chdir() keeps the
    # process working directory unchanged (other modules build paths like
    # 'cache/<repo>' relative to it), and omitting `-b master` lets git check
    # out the repository's default branch (many repositories use `main`).
    try:
        subprocess.run(['git', 'clone', repo_url], cwd=local_folder, check=True)
        print(f"Code downloaded to {local_folder}")
    except subprocess.CalledProcessError as e:
        print(f"Download failed: {e}")
5 changes: 5 additions & 0 deletions Task5/config.py
@@ -0,0 +1,5 @@
# config.py

GITHUB_TOKEN = ""

ZHIPUAI_API_KEY = ""
81 changes: 81 additions & 0 deletions Task5/get_info.py
@@ -0,0 +1,81 @@
import requests


# Extract the "owner/repo" path from a GitHub repository URL
def get_repo_path_from_url(url):
    parts = url.split('/')
    if len(parts) >= 5:
        return f"{parts[3]}/{parts[4]}"
    else:
        return None
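For instance, a standard repository URL maps to the `owner/repo` pair the REST API expects (the helper is repeated here so the snippet runs on its own):

```python
def get_repo_path_from_url(url):
    # Same logic as above: owner is the 4th URL segment, repo the 5th
    parts = url.split('/')
    if len(parts) >= 5:
        return f"{parts[3]}/{parts[4]}"
    return None

print(get_repo_path_from_url("https://github.com/pallets/flask"))  # pallets/flask
print(get_repo_path_from_url("https://github.com"))                # None
```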


# Fetch basic repository information from the GitHub API
def get_repo_data(repo_path, headers):
    api_url = f"https://api.github.com/repos/{repo_path}"
    return requests.get(api_url, headers=headers).json()
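All of these functions take a `headers` dict. A minimal sketch of building it from the token in `config.py` (`build_headers` is our illustrative name, not part of the project; the `Authorization: token …` scheme is GitHub's classic personal-access-token format, and authenticated requests get a much higher rate limit):

```python
def build_headers(token):
    # GitHub recommends an explicit Accept header for its REST API
    headers = {"Accept": "application/vnd.github+json"}
    if token:
        headers["Authorization"] = f"token {token}"
    return headers
```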


# Fetch the repository's language breakdown
def get_languages_data(repo_path, headers):
    languages_url = f"https://api.github.com/repos/{repo_path}/languages"
    return requests.get(languages_url, headers=headers).json()


# Fetch the repository's contributors
def get_contributors_data(repo_path, headers):
    contributors_url = f"https://api.github.com/repos/{repo_path}/contributors"
    return requests.get(contributors_url, headers=headers).json()


# Fetch the repository's license information
def get_license_data(repo_path, headers):
    license_url = f"https://api.github.com/repos/{repo_path}/license"
    return requests.get(license_url, headers=headers).json()


# Fetch the repository's README content
def get_readme_content(repo_path, headers):
    readme_url = f"https://api.github.com/repos/{repo_path}/readme"
    readme_data = requests.get(readme_url, headers=headers).json()
    download_url = readme_data.get('download_url')
    if download_url:
        return requests.get(download_url).text
    else:
        return "Not available"


# Fetch the repository's file tree (tries `main`, falls back to `master`)
def get_repo_tree(repo_path, headers):
    tree_url = f"https://api.github.com/repos/{repo_path}/git/trees/main?recursive=1"
    response = requests.get(tree_url, headers=headers)
    if response.status_code == 404:
        tree_url = f"https://api.github.com/repos/{repo_path}/git/trees/master?recursive=1"
        response = requests.get(tree_url, headers=headers)
    return response.json()


# Fetch the content of a single file. The tree entry's `url` points at the
# Git blob API, which returns base64-encoded content wrapped in JSON, so the
# payload has to be decoded rather than read as plain text.
def get_file_content(file_url, headers):
    import base64  # stdlib; imported locally so this fix is self-contained
    blob = requests.get(file_url, headers=headers).json()
    return base64.b64decode(blob.get('content', '')).decode('utf-8', errors='replace')


# Print the main code files of a specific repository
def print_repo_code(repo_path, headers):
    repo_tree = get_repo_tree(repo_path, headers)

    if 'tree' not in repo_tree:
        print("Error: No tree data found in the repository response.")
        return

    # Keep only the main code files (e.g. .py, .js)
    code_files = [file for file in repo_tree['tree'] if
                  file['path'].endswith(('.py', '.js', '.java', '.cpp', '.c', '.rb', '.go'))]

    for file in code_files:
        file_url = file['url']
        file_content = get_file_content(file_url, headers)
        print(f"File: {file['path']}\n")
        print(file_content)
        print("\n" + "=" * 80 + "\n")
178 changes: 178 additions & 0 deletions Task5/llm_classification.py
@@ -0,0 +1,178 @@
import os
import random
import time
from pathlib import Path
import numpy as np
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.chat_models import ChatZhipuAI
from langchain.globals import set_debug
import config
from pydantic.v1 import validator, Field
from pydantic.v1 import BaseModel

set_debug(True)


class CodeQuality(BaseModel):
    readability: int = Field(default=70, description="Score the readability of the code (0-100):")
    maintainability: int = Field(default=70, description="Score the maintainability of the code (0-100):")
    consistencency: int = Field(default=70, description="Score the consistency of the code (0-100):")
    complexity: int = Field(default=70, description="Score the conciseness of the code (0-100):")
    robustness: int = Field(default=70, description="Score the robustness of the code (0-100):")
    modualrity: int = Field(default=70, description="Score the modularity of the code (0-100):")

    @validator('readability', 'maintainability', 'consistencency', 'complexity', 'robustness', 'modualrity')
    @classmethod
    def validate_score(cls, field):
        if not 0 <= field <= 100:
            raise ValueError("Score must be between 0 and 100")
        return field


class RepoAnalysis(BaseModel):
    application_domain: str = Field(default="machine learning", description="Based on the repository information above, give its application domain:")
    development_stage: str = Field(default="early-stage", description="Based on the repository information above, give its development stage:")
    community_activity: str = Field(default="active", description="Based on the repository information above, give its community activity level:")
    tech_stack: str = Field(default="Python", description="Based on the repository information above, give its tech stack:")


prompt_set = {
    "repo_analysis_prompt": """Analyze the following GitHub repository information and answer in Chinese:
- Repository name: {reponame}
- Description: {repodescription}
- Languages: {languages}
- Number of contributors: {contributors_count}
- Stars: {star}
- Forks: {fork}
- README content: {readme_content}

{format_instructions}
""",

    "code_quality_prompt": """{file_content}
Analyze the quality of the code above against the following criteria and answer in Chinese:
- Readability: is the code easy to understand, are names clear, is the formatting consistent.
- Maintainability: is the code easy to modify and extend, does it follow design patterns and principles.
- Consistency: is the code style uniform, does it follow the team's or project's coding standards.
- Conciseness: is the code concise, avoiding redundancy and complexity.
- Robustness: does the code handle errors and exceptional cases gracefully.
- Modularity: is the code split into modules by function, with low coupling between modules.

{format_instructions}
""",
}


def analyze_repo_with_agent(llm, repo_data, languages_data, contributors_count, readme_content):
    repo_analysis_parser = PydanticOutputParser(pydantic_object=RepoAnalysis)
    repo_analysis_prompt = PromptTemplate.from_template(template=prompt_set["repo_analysis_prompt"], partial_variables={
        "format_instructions": repo_analysis_parser.get_format_instructions()})
    repo_analysis_chain = repo_analysis_prompt | llm | repo_analysis_parser

    repo_analysis_report = repo_analysis_chain.invoke(
        {'reponame': repo_data.get('name'), 'repodescription': repo_data.get('description'),
         'languages': languages_data,
         'contributors_count': contributors_count, 'star': repo_data.get('stargazers_count'),
         'fork': repo_data.get('forks_count'), 'readme_content': readme_content})

    report = f'Application domain: {repo_analysis_report.application_domain}, development stage: {repo_analysis_report.development_stage}, community activity: {repo_analysis_report.community_activity}, tech stack: {repo_analysis_report.tech_stack}'
    return report


def analyze_code_quality(repo_name):
    path = f'cache/{repo_name}'
    files = get_all_files(path)
    code_files = [file for file in files if
                  file.endswith(('.py', '.js', '.java', '.cpp', '.c', '.rb', '.go'))]
    length = len(code_files)

    # Sample at most five code files to keep prompt sizes manageable
    max_length = min(length, 5)

    random.shuffle(code_files)
    code_files = code_files[:max_length]

    llm = ChatZhipuAI(model='glm-4', api_key=config.ZHIPUAI_API_KEY, temperature=0.5)
    code_quality_parser = PydanticOutputParser(pydantic_object=CodeQuality)
    code_quality_prompt = PromptTemplate.from_template(template=prompt_set["code_quality_prompt"], partial_variables={
        'format_instructions': code_quality_parser.get_format_instructions()})

    code_quality_chain = code_quality_prompt | llm | code_quality_parser

    file_content = [read_file_to_string(code_file) for code_file in code_files]

    evaluations = []
    cnt = 0
    for content in file_content:
        while True:
            try:
                evaluation = code_quality_chain.invoke({'file_content': content})
                break
            except Exception:
                cnt += 1
                if cnt > 5:
                    print("Something is likely wrong; check the API configuration and network settings. Exiting...")
                    exit(-1)
                print("Request failed, retrying...")
                time.sleep(5)  # wait 5 seconds before retrying

        cnt = 0
        evaluations.append(evaluation)

    readability, maintainability, consistencency, complexity, robustness, modualrity = [], [], [], [], [], []

    for evaluation in evaluations:
        readability.append(evaluation.readability)
        maintainability.append(evaluation.maintainability)
        consistencency.append(evaluation.consistencency)
        complexity.append(evaluation.complexity)
        robustness.append(evaluation.robustness)
        modualrity.append(evaluation.modualrity)

    code_quality = f'Readability: {np.mean(readability)}\nMaintainability: {np.mean(maintainability)}\nConsistency: {np.mean(consistencency)}\nConciseness: {np.mean(complexity)}\nRobustness: {np.mean(robustness)}\nModularity: {np.mean(modualrity)}'
    return code_quality
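The shuffle-then-slice sampling above works, but `random.sample` expresses the same intent without mutating the caller's list. A small sketch (the `sample_files` helper and its `seed` parameter are our additions, not part of the project):

```python
import random

def sample_files(code_files, k=5, seed=None):
    # Draw at most k files without replacement; a dedicated Random
    # instance keeps the sampling reproducible and side-effect free
    rng = random.Random(seed)
    return rng.sample(code_files, min(k, len(code_files)))

files = [f"file_{i}.py" for i in range(10)]
picked = sample_files(files, seed=42)
print(len(picked))  # 5
```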


def analyze_with_agent(repo_name, repo_data, languages_data, contributors_count, readme_content):
    llm = ChatZhipuAI(model='glm-4', api_key=config.ZHIPUAI_API_KEY)
    cnt = 0
    while True:
        try:
            repo_analysis_report = analyze_repo_with_agent(llm, repo_data, languages_data, contributors_count,
                                                           readme_content)
            break
        except Exception:
            cnt += 1
            if cnt > 5:
                print("Something is likely wrong; check the API configuration and network settings. Exiting...")
                exit(-1)
            print("Repository analysis failed, retrying...")
            time.sleep(5)  # wait 5 seconds before retrying

    code_quality_score = analyze_code_quality(repo_name)

    return repo_analysis_report, code_quality_score
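The same retry-on-exception pattern appears in both analysis functions and could be factored into a single helper. A sketch of one way to do that (the `with_retries` name is ours, not part of the project):

```python
import time

def with_retries(fn, attempts=5, delay=5):
    # Call fn() until it succeeds, retrying on any exception;
    # re-raise the last error once the attempt budget is spent
    for i in range(attempts):
        try:
            return fn()
        except Exception:
            if i == attempts - 1:
                raise
            time.sleep(delay)

calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

print(with_retries(flaky, delay=0))  # ok
```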


def get_all_files(directory):
    files_list = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            files_list.append(os.path.join(root, file))
    return files_list
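A quick standalone check of the directory walker against a throwaway tree (the function is repeated here so the snippet runs on its own):

```python
import os
import tempfile

def get_all_files(directory):
    # Recursively collect every file path under `directory`
    files_list = []
    for root, dirs, files in os.walk(directory):
        for file in files:
            files_list.append(os.path.join(root, file))
    return files_list

with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "a.py"), "w").close()
    os.makedirs(os.path.join(d, "sub"))
    open(os.path.join(d, "sub", "b.go"), "w").close()
    names = sorted(os.path.basename(p) for p in get_all_files(d))
    print(names)  # ['a.py', 'b.go']
```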


def read_file_to_string(file_path):
    try:
        path = Path(file_path)
        content = path.read_text(encoding='utf-8')  # read the file as UTF-8 text
        return content
    except FileNotFoundError:
        print(f"The file {file_path} was not found.")
    except IOError as e:
        print(f"An I/O error occurred: {e}")
    return ""  # fall back to an empty string so callers never receive None