-
Notifications
You must be signed in to change notification settings - Fork 867
/
Copy pathDownload_model.py
69 lines (59 loc) · 1.64 KB
/
Download_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import os
from huggingface_hub import HfApi, snapshot_download
def dir_path(path_str):
    """Return ``path_str`` once it refers to a directory, creating it on request.

    Written to be used as an ``argparse`` ``type=`` callable. If the path does
    not exist, the user is asked on stdin whether it should be created.

    Args:
        path_str: Filesystem path to validate.

    Returns:
        ``path_str`` unchanged.

    Raises:
        NotADirectoryError: If the path exists but is not a directory, if the
            user declines to create it, or if no answer can be read from stdin.
    """
    if os.path.isdir(path_str):
        return path_str

    # Something other than a directory already lives at this path; makedirs
    # could never succeed, so fail now instead of asking a misleading question.
    if os.path.exists(path_str):
        raise NotADirectoryError(path_str)

    try:
        answer = input(f"{path_str} does not exist, create directory? [y/n]")
    except EOFError:
        # stdin is closed (non-interactive run): treat it as a refusal.
        answer = ""

    # Tolerate surrounding whitespace and a spelled-out "yes".
    if answer.strip().lower() in ("y", "yes"):
        # exist_ok covers the directory appearing while we waited for input.
        os.makedirs(path_str, exist_ok=True)
        return path_str

    raise NotADirectoryError(path_str)
class HFModelNotFoundError(Exception):
    """Error signalling that a HuggingFace model name could not be found."""

    def __init__(self, model_str):
        # Compose the user-facing message first, then hand it to Exception.
        message = f"HuggingFace model not found: '{model_str}'"
        super().__init__(message)
def hf_model(model_str):
    """Check that ``model_str`` names a model on the HuggingFace Hub.

    Written to be used as an ``argparse`` ``type=`` callable.

    Args:
        model_str: Model repository id, e.g. ``"org/model-name"``.

    Returns:
        ``model_str`` unchanged when the Hub knows the model.

    Raises:
        HFModelNotFoundError: If the Hub reports no such model, or the id is
            not a valid repository id.
    """
    # Local import keeps this function self-contained; the exception class
    # ships with the huggingface_hub package the file already depends on.
    from huggingface_hub.utils import RepositoryNotFoundError

    try:
        # A single metadata lookup for this one repo. Listing every model on
        # the Hub just to test membership is needlessly slow.
        HfApi().model_info(model_str)
    except (RepositoryNotFoundError, ValueError) as err:
        # huggingface_hub reports malformed repo ids with a ValueError
        # subclass; treat those the same as "not found", as a membership
        # test against the full listing would have.
        raise HFModelNotFoundError(model_str) from err
    return model_str
def _str_to_bool(value):
    """Convert a command-line token such as ``true``/``false`` into a bool.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    # The empty string stays False, matching what ``bool("")`` used to give.
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_path",
    "-o",
    type=dir_path,
    default="model",
    help="Output directory for downloaded model files",
)
parser.add_argument(
    "--model_name", "-m", type=hf_model, required=True, help="HuggingFace model name"
)
parser.add_argument(
    "--use_auth_token",
    "-t",
    # ``type=bool`` would map every non-empty string, including "False", to
    # True, so parse the text explicitly. ``nargs="?"`` with ``const=True``
    # also lets a bare ``-t`` enable the token.
    type=_str_to_bool,
    nargs="?",
    const=True,
    default=False,
    help="Use HF authentication token",
)
parser.add_argument("--revision", "-r", type=str, default="main", help="Revision")
args = parser.parse_args()

# Only download pytorch checkpoint files
allow_patterns = [
    "*.json",
    "*.pt",
    "*.bin",
    "*.txt",
    "*.model",
    "*.pth",
    "*.safetensors",
]

# Fetch the matching files of the requested revision into the chosen directory.
# NOTE(review): newer huggingface_hub releases appear to prefer ``token=`` over
# ``use_auth_token=`` - confirm against the version this project pins.
snapshot_path = snapshot_download(
    repo_id=args.model_name,
    revision=args.revision,
    allow_patterns=allow_patterns,
    cache_dir=args.model_path,
    use_auth_token=args.use_auth_token,
)
print(f"Files for '{args.model_name}' is downloaded to '{snapshot_path}'")