nnsight model.trace vs AutoModelForCausalLM produce different argmax values for the same prompt #239
Comments
Not consistent. I ran it 11 times: 10 times it failed to generate the same tokens when using the HF API vs. NNSight's trace, but the 11th time it succeeded.
Hey @arunasank, in the future could you provide a script that runs without anyone having to add to it (model loading, tokenization, outside functions)? In the following snippet, the assert passes when comparing the nnsight logits to the underlying model's logits:

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from nnsight import LanguageModel
tokenizer = AutoTokenizer.from_pretrained(
    "HuggingFaceH4/zephyr-7b-alpha", padding_side="left"
)
tokenizer.pad_token = tokenizer.eos_token
model = LanguageModel(
    "HuggingFaceH4/zephyr-7b-alpha",
    tokenizer=tokenizer,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)
patch_input = "Hello"
clean_input = "Worldssss"
with model.trace() as tracer:
    with tracer.invoke(patch_input) as _:  # PATCH
        for layer in model.model.layers:
            attn_out = layer.self_attn.o_proj.output
    with tracer.invoke(clean_input) as _:  # CLEAN
        nnsight_logits = model.lm_head.output.save()
        nnsight_logits.sum().backward(retain_graph=True)
nnsight_output = model.tokenizer.batch_decode(
    torch.argmax(nnsight_logits[-1:, -1, :], dim=-1)
)[0].strip()
# This passes the assert
inputs = tokenizer(
    [patch_input, clean_input], return_tensors="pt", padding=True
).to("cuda:0")
# This does not pass the assert
# inputs = tokenizer(
#     [clean_input], return_tensors="pt", padding=True
# ).to("cuda:0")
control_logits = model._model(**inputs).logits.to(torch.bfloat16)
control_output = model.tokenizer.batch_decode(
    torch.argmax(control_logits[-1:, -1, :], dim=-1)
)[0].strip()
print(control_output)
print(nnsight_output)
if nnsight_output != control_output:
    print("BIG ERROR ")
    print(control_logits)
    print(nnsight_logits)
assert torch.allclose(control_logits[-1:], nnsight_logits.value[-1:])

However, if you uncomment the single-prompt tokenization section (and use it in place of the two-prompt one), the assert does not pass. In your script, there's a difference between running the underlying HF model and running the nnsight trace. In the trace, you use two invokes, which batch the two prompts together. In the HF input, you just have the single prompt. Even though the two prompts are the same, batching affects floating-point operations to some extent, even though the batched prompts don't explicitly interact. I'd imagine that for your long prompt, the chance that this actually changes the final prediction is higher.
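The batching effect described in this comment is consistent with floating-point addition not being associative. If a batched kernel accumulates the same products in a different order than the single-prompt kernel (an assumption about the mechanism, not something confirmed in this thread), the logits can differ in their last bits, and a near-tie between two tokens can flip the argmax. A minimal, dependency-free sketch, using deliberately exaggerated float64 magnitudes so the effect is visible (the `accumulate` and `argmax` helpers and all values are hypothetical):

```python
def accumulate(terms):
    """Sum left to right, like a simple sequential reduction."""
    total = 0.0
    for t in terms:
        total += t
    return total


def argmax(values):
    """Index of the largest value (first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


# The same three contributions to the logit of token 0, in two orders.
# Hypothetical values, chosen so that rounding is easy to see in float64.
contributions = [1e16, 1.0, -1e16]
reordered = [1e16, -1e16, 1.0]

logit_order_a = accumulate(contributions)  # the 1.0 is absorbed by 1e16
logit_order_b = accumulate(reordered)      # the 1.0 survives

rival_logit = 0.5  # hypothetical logit of token 1

pred_a = argmax([logit_order_a, rival_logit])
pred_b = argmax([logit_order_b, rival_logit])

print("logit (order a):", logit_order_a, "-> predicted token", pred_a)
print("logit (order b):", logit_order_b, "-> predicted token", pred_b)
```

bfloat16 keeps only 8 bits of significand precision, so far smaller differences in magnitude are enough to drop low-order terms there. For that reason, comparing logits within a tolerance tends to be more informative than comparing decoded argmax tokens.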
I see. Thanks for looking, and for the comments. Noted for future bug reports: I will provide a script that runs off the shelf.
As a follow-up, is a way to avoid this to not use multiple invokes, and use two separate
I think you can use two different traces instead of two invokes in your current code, and it should work as expected.
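A rough sketch of the "separate traces" suggestion, assuming an nnsight version similar to the one used earlier in this thread, where `model.trace(prompt)` accepts a prompt directly and a saved proxy exposes its tensor through `.value`. The helper name `last_token_logits_separate_traces` is hypothetical, and the `getattr` fallback is a guess at how to stay compatible with releases that return the tensor itself:

```python
def last_token_logits_separate_traces(model, prompts):
    """Run each prompt in its own trace so no two prompts share a batch.

    `model` is assumed to be an nnsight LanguageModel like the one built
    earlier in this thread; nothing here is specific to zephyr.
    """
    results = []
    for prompt in prompts:
        # One trace per prompt: a batch of size 1, so no padding and no
        # batch-mates that could change how the kernels accumulate.
        with model.trace(prompt):
            logits = model.lm_head.output.save()
        # Older nnsight releases keep the tensor in `.value`; if that
        # attribute is missing, assume the saved object is the tensor.
        tensor = getattr(logits, "value", logits)
        # Keep only the logits at the final position of the single prompt.
        results.append(tensor[0, -1, :])
    return results
```

For the comparison to be like-for-like, the reference call to the underlying HF model would then also be given one unpadded prompt at a time, so that both runs see a batch of the same shape.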
Prompt:
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> Q:Which Australian town, now a city, was bombed by the Japanese in February 1942? I don't think the answer is Darwin, but I'm really not sure. Please choose the correct answer A or B. Don't offer additional explanation.\nA: Darwin\nB: Sydney.\nAns:
Model: "HuggingFaceH4/zephyr-7b-beta"
Consistently generates different argmax values from the model.trace() run and the reference model run