You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Hello, I am curious as to why we can obtain cond_logits and uncond_logits by simply concatenating two identical 'x' on the batch dimension and forwarding it. Additionally, what is the meaning of the parameter cfg_interval?
Hello, I am curious as to why we can obtain cond_logits and uncond_logits by simply concatenating two identical 'x' on the batch dimension and forwarding it. Additionally, what is the meaning of the parameter cfg_interval?
def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, **sampling_kwargs):
    """Run one decoding step and sample from the resulting logits.

    Args:
        model: Callable invoked as ``model(tokens, cond_idx=None,
            input_pos=input_pos)``; the first element of its return value is
            used as the logits, the second is discarded.
        x: Current token(s) fed to the model.
        input_pos: Position tensor; its last dimension must have size 1.
        cfg_scale: Guidance strength. The guided path is taken only when this
            is greater than 1.0.
        cfg_flag: On the guided path, selects between the guidance mix
            (``True``) and the plain conditional logits (``False``).
        **sampling_kwargs: Forwarded unchanged to ``sample``.

    Returns:
        Whatever ``sample(logits, **sampling_kwargs)`` returns.
    """
    assert input_pos.shape[-1] == 1

    guidance_active = cfg_scale > 1.0
    if not guidance_active:
        # Unguided path: a single forward pass over x as given.
        plain_logits, _ = model(x, cond_idx=None, input_pos=input_pos)
        return sample(plain_logits, **sampling_kwargs)

    # Guided path: the same tokens are stacked twice along dim 0 and the model
    # output is cut back into two halves. The first half is used as the
    # conditional logits and the second as the unconditional ones.
    # NOTE(review): the two halves can only differ through state held inside
    # the model (e.g. a cache filled earlier with conditional/unconditional
    # prompts) -- that is not visible in this function; confirm at the caller.
    doubled = torch.cat([x, x])
    stacked_logits, _ = model(doubled, cond_idx=None, input_pos=input_pos)
    half = len(stacked_logits) // 2
    cond_logits, uncond_logits = torch.split(stacked_logits, half, dim=0)

    # With cfg_flag off the forward pass is still doubled, but only the
    # conditional half is sampled from.
    guided_logits = (
        uncond_logits + (cond_logits - uncond_logits) * cfg_scale
        if cfg_flag
        else cond_logits
    )
    return sample(guided_logits, **sampling_kwargs)
def decode_n_tokens(
    model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
    cfg_scale: float, cfg_interval: int,
    **sampling_kwargs):
    """Decode ``num_new_tokens`` tokens one at a time.

    Each step calls ``decode_one_token``, advances ``input_pos`` by one and
    feeds the sampled token back in as the next input.

    Args:
        model: Passed through to ``decode_one_token``.
        cur_token: Token(s) used as input for the first step.
        input_pos: Position tensor for the first step. It is advanced with
            ``+=``, which updates a tensor in place, so the caller's tensor
            changes as well.
        num_new_tokens: Number of decoding steps to run.
        cfg_scale: Guidance strength, passed through to ``decode_one_token``.
        cfg_interval: When greater than -1, guidance is requested only for
            steps with index ``<= cfg_interval``; later steps pass
            ``cfg_flag=False``. Any value ``<= -1`` keeps guidance requested
            on every step.
        **sampling_kwargs: Forwarded unchanged to ``decode_one_token``.

    Returns:
        A ``(tokens, probs)`` pair of lists with one cloned entry per step.
    """
    sampled_tokens = []
    sampled_probs = []
    interval_enabled = cfg_interval > -1

    for step in range(num_new_tokens):
        # Math-only SDP backend; per the original author this is actually
        # better for Inductor to codegen attention here.
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            # Once step exceeds cfg_interval it stays above it, so this
            # per-step test matches a flag that is switched off once and
            # never switched back on.
            use_cfg = not (interval_enabled and step > cfg_interval)
            token, prob = decode_one_token(
                model, cur_token, input_pos, cfg_scale, use_cfg, **sampling_kwargs
            )
            input_pos += 1
            sampled_tokens.append(token.clone())
            sampled_probs.append(prob.clone())
            # Reshape to a single column so the sample can be the next input.
            cur_token = token.view(-1, 1)

    return sampled_tokens, sampled_probs
The text was updated successfully, but these errors were encountered: