Fix eval_imagenet and add eval_winoground #193

Open · wants to merge 2 commits into base: main
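This PR fixes device handling in `evaluate_imagenet` (replacing hard-coded `.cuda()` calls with `.to(eval_model.device)`) and adds a Winoground evaluation: a new `--eval_winoground` flag, an `--hf_auth_token` argument for downloading the dataset, an `evaluate_winoground` function that scores every caption/image pairing by the product of per-token probabilities and reports the Winoground text/image/group scores, and a `run_eval_winoground.sh` launch script.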
113 changes: 109 additions & 4 deletions open_flamingo/eval/evaluate.py
@@ -73,6 +73,12 @@
default=False,
help="Whether to evaluate on ImageNet.",
)
parser.add_argument(
"--eval_winoground",
action="store_true",
default=False,
help="Whether to evaluate on ImageNet.",
)

parser.add_argument(
"--eval_flickr30",
@@ -150,6 +156,14 @@
## Imagenet dataset
parser.add_argument("--imagenet_root", type=str, default="/tmp")

## Winoground dataset
parser.add_argument(
"--hf_auth_token",
type=str,
default=None,
help="HuggingFace auth token is needed to download the Winoground dataset",
)

parser.add_argument(
"--model",
type=str,
@@ -277,6 +291,16 @@ def main():
{"shots": shot, "trials": scores, "mean": np.mean(scores)}
)

if args.eval_winoground:
print("Evaluating on Winoground...")
winoground_result = evaluate_winoground(eval_model, args.hf_auth_token)
results["winoground"].append(winoground_result)
print(
f'Text Score {winoground_result["text_score"] * 100:4.1f} | '
f'Image Score {winoground_result["image_score"] * 100:4.1f} | '
f'Group Score {winoground_result["group_score"] * 100:4.1f}'
)

if args.results_file is not None:
with open(args.results_file, "w") as f:
json.dump(results, f)
@@ -643,7 +667,7 @@ def evaluate_imagenet(
] + [eval_model.image_processor(batch["image"]).unsqueeze(0)]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
model._encode_vision_x(vision_x.cuda())
model._encode_vision_x(vision_x.to(eval_model.device))

context_class_names = [
in_context_samples[i]["class_name"] for i in range(effective_num_shots)
@@ -672,16 +696,16 @@

outputs = model(
vision_x=None,
lang_x=lang_x["input_ids"].cuda(),
attention_mask=lang_x["attention_mask"].cuda(),
lang_x=lang_x["input_ids"].to(eval_model.device),
attention_mask=lang_x["attention_mask"].to(eval_model.device),
clear_conditioned_layers=False,
use_cached_vision_x=True,
)
probs = torch.softmax(outputs.logits, dim=-1).detach()
# collect the probability of the generated token -- probability
# at index 0 corresponds to the token at index 1
probs = probs[:, :-1, :]
input_ids = lang_x["input_ids"][:, 1:].cuda()
input_ids = lang_x["input_ids"][:, 1:].to(eval_model.device)
gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)

probs = []
@@ -712,5 +736,86 @@ def evaluate_imagenet(
return float(acc1) / num_samples


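# Winoground scoring: each sample has two captions (c0, c1) and two images
# (i0, i1), and `result` holds the model's score for every caption/image
# pairing (e.g. result["c0_i1"] is caption 0 scored against image 1). The
# text score counts a sample as correct when the matching caption wins for
# both images, the image score when the matching image wins for both
# captions, and the group score when both hold.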
def text_correct(result):
return result["c0_i0"] > result["c1_i0"] and result["c1_i1"] > result["c0_i1"]


def image_correct(result):
return result["c0_i0"] > result["c0_i1"] and result["c1_i1"] > result["c1_i0"]


def group_correct(result):
return image_correct(result) and text_correct(result)


def winoground_acc(scores):
text_correct_count = 0
image_correct_count = 0
group_correct_count = 0
for result in scores:
text_correct_count += 1 if text_correct(result) else 0
image_correct_count += 1 if image_correct(result) else 0
group_correct_count += 1 if group_correct(result) else 0

denominator = len(scores)
return {
"text_score": text_correct_count / denominator,
"image_score": image_correct_count / denominator,
"group_score": group_correct_count / denominator,
}


def evaluate_winoground(eval_model, auth_token):
if not hasattr(eval_model, "model") or not hasattr(eval_model, "tokenizer"):
raise NotImplementedError(
"evaluate_winoground is currently only supported for OpenFlamingo models"
)
model, tokenizer = eval_model.model, eval_model.tokenizer
assert isinstance(model, Flamingo)

from datasets import load_dataset

winoground = load_dataset("facebook/winoground", use_auth_token=auth_token)["test"]

prompt_text = "<image>A photo of a"

all_results = []
for sample_idx in tqdm(range(len(winoground))):
sample = winoground[sample_idx]
images = [sample[f"image_{i}"].convert("RGB") for i in range(2)]
captions = [sample[f"caption_{i}"] for i in range(2)]
cur_res = {"c0_i0": 0, "c1_i0": 0, "c0_i1": 0, "c1_i1": 0}

for img_idx, img in enumerate(images):
vision_x = eval_model.image_processor(img).unsqueeze(0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
model._encode_vision_x(vision_x.to(eval_model.device))

for cap_idx, caption in enumerate(captions):
target_text = f"{prompt_text} {caption}"
lang_x = tokenizer([target_text], return_tensors="pt")

outputs = model(
vision_x=None,
lang_x=lang_x["input_ids"].to(eval_model.device),
attention_mask=lang_x["attention_mask"].to(eval_model.device),
clear_conditioned_layers=False,
use_cached_vision_x=True,
)
probs = torch.softmax(outputs.logits, dim=-1).detach()

# collect the probability of the generated token -- probability
# at index 0 corresponds to the token at index 1
probs = probs[:, :-1, :]
input_ids = lang_x["input_ids"][:, 1:].to(eval_model.device)
gen_probs = torch.gather(probs, 2, input_ids[:, :, None]).squeeze(-1)
cur_prob = torch.prod(gen_probs.squeeze()).item()
cur_res[f"c{cap_idx}_i{img_idx}"] = cur_prob

all_results.append(cur_res)

return winoground_acc(all_results)


if __name__ == "__main__":
main()
29 changes: 29 additions & 0 deletions open_flamingo/scripts/run_eval_winoground.sh
@@ -0,0 +1,29 @@
echo 'activating virtual environment'
source ~/.bashrc
eval "$(conda shell.bash hook)"
conda activate openflamingo
which python

LM_PATH="luodian/llama-7b-hf"
LM_TOKENIZER_PATH="luodian/llama-7b-hf"
VISION_ENCODER_NAME="ViT-L-14"
VISION_ENCODER_PRETRAINED="openai"
CKPT_PATH="openflamingo/OpenFlamingo-9B/checkpoint.pt"
DEVICE="0"

RANDOM_ID=$$
RESULTS_FILE="results_${RANDOM_ID}.json"

python open_flamingo/eval/evaluate.py \
--lm_path $LM_PATH \
--lm_tokenizer_path $LM_TOKENIZER_PATH \
--vision_encoder_path $VISION_ENCODER_NAME \
--vision_encoder_pretrained $VISION_ENCODER_PRETRAINED \
--checkpoint_path $CKPT_PATH \
--cross_attn_every_n_layers 4 \
--device $DEVICE \
--results_file $RESULTS_FILE \
--eval_winoground
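# Note: facebook/winoground is a gated dataset on the Hugging Face Hub, so the
# download in evaluate_winoground may require adding --hf_auth_token <token> to
# the command above (the script does not pass it by default).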


echo "evaluation complete! results written to ${RESULTS_FILE}"