cfg-interval #67

Open
wangyf8848 opened this issue Oct 27, 2024 · 0 comments

Hello, I am curious why we can obtain `cond_logits` and `uncond_logits` simply by concatenating two identical `x` tensors along the batch dimension and forwarding the result through the model. Also, what is the meaning of the `cfg_interval` parameter?

```python
def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, **sampling_kwargs):
    assert input_pos.shape[-1] == 1
    if cfg_scale > 1.0:
        x_combined = torch.cat([x, x])
        logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos)
        logits_combined = logits
        cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
        if cfg_flag:
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
        else:
            logits = cond_logits
    else:
        logits, _ = model(x, cond_idx=None, input_pos=input_pos)
    return sample(logits, **sampling_kwargs)
```
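For what it's worth, in typical classifier-free-guidance decoders the two batch halves share the same decode-step input but were prefilled with different contexts (the class condition vs. a null/unconditional embedding), so their KV caches, and therefore their logits, differ even though `x` is duplicated. The guidance arithmetic applied to the split logits can be isolated in a small sketch (the helper name `apply_cfg` is my own, not from the repo):

```python
import torch

def apply_cfg(cond_logits: torch.Tensor, uncond_logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Classifier-free guidance: start from the unconditional logits and
    # extrapolate in the direction of the conditional ones.
    return uncond_logits + (cond_logits - uncond_logits) * cfg_scale

cond = torch.tensor([2.0, 0.0])
uncond = torch.tensor([1.0, 0.0])

# cfg_scale = 1.0 reproduces the conditional logits exactly.
print(apply_cfg(cond, uncond, 1.0))  # tensor([2., 0.])
# cfg_scale > 1.0 pushes further toward the conditional prediction.
print(apply_cfg(cond, uncond, 4.0))  # tensor([5., 0.])
```

So `cfg_scale > 1.0` amplifies whatever distinguishes the conditional prediction from the unconditional one, which is why the branch in `decode_one_token` only does the combined forward pass when `cfg_scale > 1.0`.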

```python
def decode_n_tokens(
    model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
    cfg_scale: float, cfg_interval: int,
    **sampling_kwargs):
    new_tokens, new_probs = [], []
    cfg_flag = True
    for i in range(num_new_tokens):
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):  # Actually better for Inductor to codegen attention here
            if cfg_interval > -1 and i > cfg_interval:
                cfg_flag = False
            next_token, next_prob = decode_one_token(
                model, cur_token, input_pos, cfg_scale, cfg_flag, **sampling_kwargs
            )
            input_pos += 1
            new_tokens.append(next_token.clone())
            new_probs.append(next_prob.clone())
            cur_token = next_token.view(-1, 1)
    return new_tokens, new_probs
```
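Reading the loop above, `cfg_interval` appears to act as a cutoff step for guidance: with `cfg_interval = -1` the condition never fires and `cfg_flag` stays `True` for all tokens, while a non-negative value disables guidance once the step index `i` exceeds it. A self-contained sketch of that schedule (the helper name `cfg_schedule` is hypothetical):

```python
def cfg_schedule(num_new_tokens: int, cfg_interval: int) -> list:
    # Reproduces the cfg_flag logic from decode_n_tokens in isolation:
    # the flag is True until the step index exceeds cfg_interval.
    flags = []
    cfg_flag = True
    for i in range(num_new_tokens):
        if cfg_interval > -1 and i > cfg_interval:
            cfg_flag = False
        flags.append(cfg_flag)
    return flags

print(cfg_schedule(6, -1))  # [True, True, True, True, True, True]
print(cfg_schedule(6, 2))   # [True, True, True, False, False, False]
```

So, if my reading is right, setting `cfg_interval = 2` would apply classifier-free guidance only to the first three generated tokens and fall back to purely conditional logits afterwards, saving the guidance arithmetic on later steps.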
