nnsight model.trace vs AutoModelForCausalLM produce different argmax values for the same prompt #239

Open
arunasank opened this issue Sep 13, 2024 · 5 comments

Comments

@arunasank

Prompt: </s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s><s> Q:Which Australian town, now a city, was bombed by the Japanese in February 1942? I don't think the answer is Darwin, but I'm really not sure. Please choose the correct answer A or B. Don't offer additional explanation.\nA: Darwin\nB: Sydney.\nAns:

Model: "HuggingFaceH4/zephyr-7b-beta"

```python
# Excerpt from a function body. model (the nnsight model), ref_model (the Hugging Face
# AutoModelForCausalLM reference), atp_utils, syc_tokens, no_syc_tokens, syc_answers,
# no_syc_answers, syc_out, no_syc_out, no_syc_out_grads and counter are defined elsewhere.
import torch

with model.trace() as tracer:
    with tracer.invoke(syc_tokens) as _:  # PATCH
        for layer in model.model.layers:
            attn_out = layer.self_attn.o_proj.output
            syc_out.append(attn_out.save())

    with tracer.invoke(no_syc_tokens) as _:  # CLEAN
        for layer in model.model.layers:
            attn_out = layer.self_attn.o_proj.output
            no_syc_out.append(attn_out.save())
            no_syc_out_grads.append(attn_out.grad.save())
        no_syc_prompts_logits = model.lm_head.output.save()

        value = atp_utils.get_logit_diff(no_syc_prompts_logits.cpu(), syc_answers, no_syc_answers).save()
        value.backward(retain_graph=True)

if value >= 0:
    counter += 1
    return [], [], []
else:
    argmax_token = model.tokenizer.batch_decode(torch.argmax(no_syc_prompts_logits[:, -1, :], dim=-1))[0].strip()
    no_syc_answers = model.tokenizer.batch_decode(no_syc_answers, skip_special_tokens=True)[0].strip()
    # Reference run: the clean prompt alone, through the plain HF model
    ref_logits = ref_model(no_syc_tokens).logits
    ref_argmax_token = model.tokenizer.batch_decode(torch.argmax(ref_logits[:, -1, :], dim=-1))[0]
    print(argmax_token, no_syc_answers, ref_argmax_token)
    if argmax_token != ref_argmax_token:
        print('BIG ERROR ', model.tokenizer.batch_decode(no_syc_tokens)[0])
        assert torch.allclose(ref_logits, no_syc_prompts_logits.value)
    assert argmax_token == ref_argmax_token
    if argmax_token != no_syc_answers:
        print('ERROR ', model.tokenizer.batch_decode(no_syc_tokens, skip_special_tokens=True)[0])
    syc_out = [c.value.detach().cpu() for c in syc_out]
    no_syc_out = [c.value.detach().cpu() for c in no_syc_out]
    no_syc_out_grads = [c.value.detach().cpu() for c in no_syc_out_grads]
    return syc_out, no_syc_out, no_syc_out_grads
```

This consistently produces different argmax tokens for the model.trace() run and the reference-model run.
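The exact-equality checks in the snippet above (`argmax_token == ref_argmax_token`, `torch.allclose`) only report that the two runs differ, not by how much. A minimal diagnostic sketch, assuming each run's final-position logits are first converted to plain lists (for example with `logits[0, -1].float().tolist()`): it reports the largest per-token difference and each run's top-1/top-2 margin, so a flip between two nearly tied tokens can be told apart from a genuine mismatch. The helper names, the `abs_tol` default and the example numbers are hypothetical.

```python
from typing import Dict, Sequence


def argmax(values: Sequence[float]) -> int:
    # First index of the maximum value (ties go to the lower index).
    return max(range(len(values)), key=lambda i: values[i])


def top2_margin(values: Sequence[float]) -> float:
    # Gap between the best and second-best logit; a small gap is fragile.
    best, second = sorted(values, reverse=True)[:2]
    return best - second


def compare_last_token_logits(
    a: Sequence[float], b: Sequence[float], abs_tol: float = 1e-2
) -> Dict[str, object]:
    """Summarise how two runs' final-position logits differ.

    Hypothetical helper; abs_tol is an illustrative tolerance, not a tuned value.
    """
    max_abs_diff = max(abs(x - y) for x, y in zip(a, b))
    return {
        "argmax_a": argmax(a),
        "argmax_b": argmax(b),
        "same_argmax": argmax(a) == argmax(b),
        "max_abs_diff": max_abs_diff,
        "within_tol": max_abs_diff <= abs_tol,
        "margin_a": top2_margin(a),
        "margin_b": top2_margin(b),
    }


# Illustrative numbers only: the top two tokens are nearly tied, so a small
# numerical drift swaps them even though every logit moved by less than abs_tol.
run_a = [2.000, 1.995, -1.0]
run_b = [1.996, 1.999, -1.0]
report = compare_last_token_logits(run_a, run_b)
print(report)
```

If the runs disagree on the argmax while `within_tol` is true and both margins are tiny, the flip is consistent with numerical noise rather than a logic error in the tracing code.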

@arunasank
Author

Correction: it's not consistent. I ran it 11 times: in 10 runs the HF API and nnsight's trace produced different tokens, but on the 11th run they matched.

@JadenFiotto-Kaufman
Copy link
Member

Hey @arunasank, in the future could you provide a script that runs as-is, without anyone having to add to it (model loading, tokenization, outside functions, etc.)?

In the following snippet, the assert passes when comparing nnsight logits to the underlying model logits:

```python
import torch
from transformers import AutoTokenizer

from nnsight import LanguageModel

tokenizer = AutoTokenizer.from_pretrained(
    "HuggingFaceH4/zephyr-7b-alpha", padding_side="left"
)
tokenizer.pad_token = tokenizer.eos_token
model = LanguageModel(
    "HuggingFaceH4/zephyr-7b-alpha",
    tokenizer=tokenizer,
    device_map="cuda:0",
    torch_dtype=torch.bfloat16,
)

patch_input = "Hello"
clean_input = "Worldssss"

with model.trace() as tracer:
    with tracer.invoke(patch_input) as _:  # PATCH
        for layer in model.model.layers:
            attn_out = layer.self_attn.o_proj.output

    with tracer.invoke(clean_input) as _:  # CLEAN

        nnsight_logits = model.lm_head.output.save()

        nnsight_logits.sum().backward(retain_graph=True)

nnsight_output = model.tokenizer.batch_decode(
    torch.argmax(nnsight_logits[-1:, -1, :], dim=-1)
)[0].strip()

# This passes the assert
inputs = tokenizer(
    [patch_input, clean_input], return_tensors="pt", padding=True
).to("cuda:0")

# This does not pass the assert
# inputs = tokenizer(
#     [clean_input], return_tensors="pt", padding=True
# ).to("cuda:0")

control_logits = model._model(**inputs).logits.to(torch.bfloat16)

control_output = model.tokenizer.batch_decode(
    torch.argmax(control_logits[-1:, -1, :], dim=-1)
)[0].strip()

print(control_output)
print(nnsight_output)

if nnsight_output != control_output:
    print("BIG ERROR ")

print(control_logits)
print(nnsight_logits)
assert torch.allclose(control_logits[-1:], nnsight_logits.value[-1:])
```

However, if you uncomment the other tokenization section (which tokenizes only the clean prompt), the assert fails. In your script, the HF run and the nnsight trace are not equivalent: in the trace you use two invokes, which batch the two prompts together, while the HF call receives only the single clean prompt. Even though the clean prompt itself is the same in both cases, batching affects the floating-point operations to some extent, even though the prompts don't explicitly interact. I'd imagine that with your long prompt, the chance that this actually changes the final prediction is higher.
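To illustrate how batching can change a prediction without the prompts interacting: floating-point addition is not associative, so summing the same numbers in a different order (which a GPU kernel may do for a different batch shape) can change the last bit of a result, and when two logits are nearly tied that is enough to flip the argmax. A minimal standard-library sketch; the two-token vocabulary and the numbers are purely illustrative and not taken from zephyr.

```python
# Floating-point addition is not associative: the same terms summed in a
# different order can give a (very slightly) different result.
terms = [0.1, 0.2, 0.3]


def sum_left_to_right(xs):
    total = 0.0
    for x in xs:
        total += x
    return total


def sum_right_to_left(xs):
    total = 0.0
    for x in reversed(xs):
        total += x
    return total


def argmax(values):
    # First index of the maximum, so exact ties go to the lower index.
    return max(range(len(values)), key=lambda i: values[i])


print(sum_left_to_right(terms), sum_right_to_left(terms))  # 0.6000000000000001 0.6

# Illustrative two-token "vocabulary": token 0 has a fixed logit of 0.6 and
# token 1's logit is the sum of `terms`, computed in two different orders.
run_a = [0.6, sum_left_to_right(terms)]
run_b = [0.6, sum_right_to_left(terms)]

print(argmax(run_a), argmax(run_b))  # 1 0
print(abs(run_a[1] - run_b[1]))      # a difference of about 1.1e-16
```

A real model accumulates many such rounding differences across its layers, and bfloat16 has far less precision than the float64 used here, so prompts whose top two tokens are close are the ones most likely to flip.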

@arunasank
Author

I see, thanks for looking into it and for the comments. Noted: for future bugs I will provide a script that runs off the shelf.

arunasank reopened this Sep 13, 2024
@arunasank
Author

As a follow-up: is one way to avoid this to use two separate trace calls instead of multiple invokes?

@Butanium
Contributor

Butanium commented Oct 5, 2024

I think you can use two separate traces instead of two invokes in your current code, and it should work as expected.
