Skip to content

Commit

Permalink
fix gpt bigcode model laoding with fp16 weights precision
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jan 6, 2025
1 parent 753f84d commit de0777e
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 0 deletions.
21 changes: 21 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
CodeGenOnnxConfig,
FalconOnnxConfig,
GemmaOnnxConfig,
GPTBigCodeOnnxConfig,
GPTJOnnxConfig,
GPTNeoOnnxConfig,
GPTNeoXOnnxConfig,
Expand Down Expand Up @@ -68,6 +69,7 @@
FalconModelPatcher,
FluxTransfromerModelPatcher,
Gemma2ModelPatcher,
GptBigCodeModelPatcher,
GptJModelPatcher,
GptNeoModelPatcher,
GptNeoxJapaneseModelPatcher,
Expand Down Expand Up @@ -2554,3 +2556,22 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
)
class GLMOpenVINOConfig(LlamaOpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "4.46.0"


@register_in_tasks_manager(
"gpt-bigcode",
*[
"feature-extraction",
"feature-extraction-with-past",
"text-generation",
"text-generation-with-past",
"text-classification",
"token-classification",
],
library_name="transformers",
)
class GPTBigCodeOpenVINOConfig(GPTBigCodeOnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs)
90 changes: 90 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3601,3 +3601,93 @@ def __exit__(self, exc_type, exc_value, traceback):
for block in self._model.blocks:
block.forward = block._orig_forward
block.attn.forward = block.attn._orig_forward


# copied from https://github.com/huggingface/transformers/blob/v4.46.3/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L401
def gpt_bigcode_attn(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
# The super dispatch is done in the forward.
raise ValueError("PyTorch SDPA does not support head_mask. Please open an issue in Transformers repository.")

scale = None
if not self.scale_attn_weights:
scale = 1

# MQA models: (batch_size, query_length, num_heads * head_dim)
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]
key.shape[-2]

if self.multi_query:
query_length = query_shape[1]

# SDPA requires the dimension [..., sequence_length, head_dim].
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)

# Without these unsqueeze, SDPA complains as the query and key/value have a different number of dimensions.
key = key.unsqueeze(1)
value = value.unsqueeze(1)

# Although these expand are not numerically useful, PyTorch can not dispatch to memory-efficient backend
# and flash attention backend (No available kernel. Aborting execution.) from the shapes
# query = [batch_size, num_heads, query_length, head_dim]
# key = [batch_size, 1, past_length, head_dim]
# value = [batch_size, 1, past_length, head_dim]
#
# torch==2.1.2 is bugged with non-contiguous inputs with custom attn_mask (https://github.com/pytorch/pytorch/issues/112577), hence the check.
if is_torch_version(">=", "2.2.0"):
key = key.expand(-1, self.num_heads, -1, -1)
value = value.expand(-1, self.num_heads, -1, -1)
else:
query_length = query_shape[-1]

# See the comment above.
if query.device.type == "cuda" and attention_mask is not None:
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()

# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The query_length > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not
# create a causal mask in case query_length == 1.
is_causal = True if self.is_causal and attention_mask is None and query_length > 1 else False
# different from original, due to loading model weights in original format transformer.wte dtype may be different from query dtype
if attention_mask is not None:
attention_mask = attention_mask.to(query.dtype)
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=self.attn_pdrop if self.training else 0.0,
is_causal=is_causal,
scale=scale,
)

if self.multi_query:
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
sdpa_result = sdpa_result.transpose(1, 2)

# Reshape is kind of expensive here, as it does a memory copy,
# but I did not manage to make away without it (logits do not match when using view)
# (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
sdpa_result = sdpa_result.reshape(query_shape)

return sdpa_result, None


class GptBigCodeModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
for layer in self._model.transformer.h:
layer.attn._orig_attn = layer.attn._attn
layer.attn._attn = types.MethodType(gpt_bigcode_attn, layer.attn)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
for layer in self._model.transformer.h:
layer.attn._attn = layer.attn._orig_attn

0 comments on commit de0777e

Please sign in to comment.