Skip to content

Commit

Permalink
feat: use gpu if available
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed Jun 11, 2024
1 parent a93e93a commit 40b9ff6
Showing 1 changed file with 22 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from sparv import api as sparv_api # type: ignore [import-untyped]
from transformers import ( # type: ignore [import-untyped]
BertForMaskedLM,
BertTokenizer,
FillMaskPipeline,
)

logger = sparv_api.get_logger(__name__)

SCORE_FORMATS = {
1: ("{:.1f}", lambda s: s.endswith(".0")),
2: ("{:.2f}", lambda s: s.endswith(".00")),
Expand Down Expand Up @@ -36,9 +40,25 @@ def __init__(

@classmethod
def _default_model(cls) -> BertForMaskedLM:
    """Load the default KB-BERT masked-LM model, placed on GPU when available.

    Device/dtype strategy:
      - CUDA available: use float16 weights to halve memory.
        * multiple GPUs: let HF shard the model via ``device_map="auto"``;
        * exactly one GPU: load normally, then move with ``.cuda()``.
      - CPU only: float32 weights (fp16 on CPU is slow/unsupported for many ops).

    Returns:
        BertForMaskedLM: the pretrained model, already on its target device(s).
    """
    # Query CUDA state once and reuse — avoids three separate
    # torch.cuda.is_available() calls and two device_count() calls.
    use_cuda = torch.cuda.is_available()
    gpu_count = torch.cuda.device_count() if use_cuda else 0

    if use_cuda:
        logger.info("Using GPU (cuda)")
        dtype = torch.float16
    else:
        logger.warning("Using CPU, is cuda available?")
        dtype = torch.float32

    model = BertForMaskedLM.from_pretrained(
        MODELS["kb-bert"].model_name,
        revision=MODELS["kb-bert"].model_revision,
        torch_dtype=dtype,
        # HF sharding only helps with >1 GPU; a single GPU is handled
        # with the explicit .cuda() below, and CPU gets no device_map.
        device_map="auto" if gpu_count > 1 else None,
    )
    if gpu_count == 1:
        model = model.cuda()
    return model

@classmethod
def _default_tokenizer(cls) -> BertTokenizer:
Expand Down

0 comments on commit 40b9ff6

Please sign in to comment.