Patch utils and models (#167)
* allow single value resize shape

* fix generate step kwargs

* fix dolphin vision

* fix LM only for DS-VL1

* fix quantize_model

* remove assert for DS-VL2

* fix paligemma LM only

* add smoke test

* update model card

* add sys info

* skip test_smoke file

* add transformers

* bump version
Blaizzy authored Jan 3, 2025
1 parent 564b8de commit 593a5b8
Showing 9 changed files with 247 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -39,4 +39,4 @@ jobs:
      - name: Run Python tests
        run: |
          cd mlx_vlm/
-          pytest -s ./tests
+          pytest -s ./tests --ignore=tests/test_smoke.py
14 changes: 9 additions & 5 deletions mlx_vlm/generate.py
@@ -39,7 +39,7 @@ def parse_arguments():
    parser.add_argument(
        "--resize-shape",
        type=int,
-        nargs=2,
+        nargs="+",
        default=None,
        help="Resize shape for the image.",
    )
@@ -84,10 +84,14 @@ def main():

    kwargs = {}
    if args.resize_shape is not None:
-        assert (
-            len(args.resize_shape) == 2
-        ), "Resize shape must be a tuple of two integers"
-        kwargs["resize_shape"] = args.resize_shape
+        resize_shape = args.resize_shape
+        if len(resize_shape) not in [1, 2]:
+            raise ValueError("Resize shape must be 1 or 2 integers")
+        kwargs["resize_shape"] = (
+            (resize_shape[0], resize_shape[0])
+            if len(resize_shape) == 1
+            else resize_shape
+        )

    output = generate(
        model,
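With `nargs="+"`, `--resize-shape` now accepts either one or two integers, and `main()` expands a single value into a square target (so `--resize-shape 448` behaves like `--resize-shape 448 448`). A minimal sketch of that normalization, pulled out into a standalone helper purely for illustration (the helper name is not part of the patch):

```python
def normalize_resize_shape(resize_shape):
    # One value -> square (size, size); two values -> used as given.
    if len(resize_shape) not in [1, 2]:
        raise ValueError("Resize shape must be 1 or 2 integers")
    return (
        (resize_shape[0], resize_shape[0])
        if len(resize_shape) == 1
        else resize_shape
    )

print(normalize_resize_shape([448]))       # (448, 448)
print(normalize_resize_shape([336, 448]))  # [336, 448]
```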
2 changes: 1 addition & 1 deletion mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py
@@ -401,7 +401,7 @@ def get_input_embeddings(
            total_tiles.append(pixel_values[idx, : batch_num_tiles[idx]])

        total_tiles = mx.concatenate(total_tiles, axis=0)
-        assert total_tiles.shape[0] == sum(batch_num_tiles)
+
        if total_tiles.shape[0] == 0:
            return self.language_model.model.embed_tokens(input_ids)

10 changes: 10 additions & 0 deletions mlx_vlm/models/llava_bunny/llava_bunny.py
@@ -42,6 +42,16 @@ class ModelConfig:

    @classmethod
    def from_dict(cls, params):
+        if not params.get("text_config", {}):
+            # Copy text config parameters from root level
+            excluded_keys = {"vision_config"}
+            params["text_config"] = dict(
+                filter(lambda x: x[0] not in excluded_keys, params.items())
+            )
+        if not params.get("vision_config", {}).get("model_type", {}):
+            # Set default model type
+            params["vision_config"]["model_type"] = "siglip_vision_model"
+
        return cls(
            **{
                k: v
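The fallback above lets a checkpoint whose config has no nested `text_config` (or no vision `model_type`) still produce a complete `ModelConfig`. A standalone sketch of the same logic over a made-up config dict (the keys and values are illustrative, and the dict comprehension is equivalent to the `filter` call in the patch):

```python
params = {
    "model_type": "llava-qwen2",  # root-level text settings, illustrative values
    "hidden_size": 1024,
    "vocab_size": 151936,
    "vision_config": {},
}

if not params.get("text_config", {}):
    # Copy text config parameters from the root level, minus the vision section
    params["text_config"] = {k: v for k, v in params.items() if k != "vision_config"}
if not params.get("vision_config", {}).get("model_type", {}):
    # Default vision tower type
    params["vision_config"]["model_type"] = "siglip_vision_model"

print(params["text_config"]["hidden_size"])   # 1024
print(params["vision_config"]["model_type"])  # siglip_vision_model
```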
2 changes: 1 addition & 1 deletion mlx_vlm/models/multi_modality/multi_modality.py
@@ -303,7 +303,7 @@ def get_input_embeddings(
        pixel_values: Optional[mx.array] = None,
    ):
        if pixel_values is None:
-            return self.language_model(input_ids)
+            return self.language_model.model.embed_tokens(input_ids)

        image_token_index = self.config.image_token_index
        num_image_tokens = self.config.num_image_tokens
2 changes: 1 addition & 1 deletion mlx_vlm/models/paligemma/paligemma.py
@@ -66,7 +66,7 @@ def get_input_embeddings(
        mask: Optional[mx.array] = None,
    ):
        if pixel_values is None:
-            return self.language_model(input_ids)
+            return self.language_model.model.embed_tokens(input_ids), None

        inputs_embeds = self.language_model.model.embed_tokens(input_ids)

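This fix and the `multi_modality` one above address the same case: with `pixel_values=None`, `get_input_embeddings` should return token embeddings instead of invoking the whole language model (PaliGemma additionally returns `None` as the second element of its tuple). That is what makes language-only prompts work through the public API; a sketch of such a call, assuming the example model has been downloaded (the hub path is illustrative):

```python
from mlx_vlm import generate, load
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "mlx-community/paligemma-3b-mix-448-8bit"  # example path
model, processor = load(model_path)
config = load_config(model_path)

# No image: num_images=0 and no `image` argument, which exercises the LM-only path.
prompt = apply_chat_template(processor, config, "Hi, how are you?", num_images=0)
print(generate(model=model, processor=processor, prompt=prompt, max_tokens=50))
```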
218 changes: 218 additions & 0 deletions mlx_vlm/tests/test_smoke.py
@@ -0,0 +1,218 @@
import argparse
import json
import platform
import subprocess
import sys
import textwrap

import mlx.core as mx
import psutil
from rich.console import Console
from rich.panel import Panel
from tqdm import tqdm
from transformers import __version__ as transformers_version

from mlx_vlm import generate, load
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config
from mlx_vlm.version import __version__

# Initialize console
console = Console()


def parse_args():
    parser = argparse.ArgumentParser(description="Test MLX-VLM models")
    parser.add_argument(
        "--models-file",
        type=str,
        required=True,
        help="Path to file containing model paths, one per line",
    )
    parser.add_argument(
        "--image", type=str, nargs="+", required=True, help="Path(s) to test image(s)"
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Describe this image.",
        help="Vision-language prompt to test",
    )
    parser.add_argument(
        "--language-only-prompt",
        type=str,
        default="Hi, how are you?",
        help="Language-only prompt to test",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--max-tokens", type=int, default=100, help="Maximum tokens to generate"
    )
    return parser.parse_args()


def get_device_info():
    # Disable tokenizers parallelism to avoid deadlocks after forking
    import os

    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    try:
        data = subprocess.check_output(
            ["system_profiler", "SPDisplaysDataType", "-json"], text=True
        )
        device_info = json.loads(data)
        return device_info
    except Exception as e:
        print(f"Could not retrieve GPU information: {e}")
        return None


def test_model_loading(model_path):
    try:
        console.print("[bold green]Loading model...")
        model, processor = load(model_path, trust_remote_code=True)
        config = load_config(model_path, trust_remote_code=True)
        console.print("[bold green]✓[/] Model loaded successfully")
        return model, processor, config, False
    except Exception as e:
        console.print(f"[bold red]✗[/] Failed to load model: {str(e)}")
        return None, None, None, True


def test_generation(
    model, processor, config, model_path, test_inputs, vision_language=True
):
    try:
        test_type = "vision-language" if vision_language else "language-only"
        console.print(f"[bold yellow]Testing {test_type} generation...")

        prompt = (
            test_inputs["prompt"]
            if vision_language
            else test_inputs["language_only_prompt"]
        )
        num_images = len(test_inputs["image"]) if vision_language else 0

        formatted_prompt = apply_chat_template(
            processor, config, prompt, num_images=num_images
        )

        generate_args = {
            "model": model,
            "processor": processor,
            "prompt": formatted_prompt,
            "verbose": True,
            **test_inputs["kwargs"],
        }
        if vision_language:
            generate_args["image"] = test_inputs["image"]

        output = generate(**generate_args)

        # Deepseek-vl2-tiny outputs are empty on VLM generation
        # Paligemma outputs are empty on language-only generation
        # So we skip the assertion for these models
        if ("deepseek-vl2-tiny" not in model_path and vision_language) or (
            "paligemma" not in model_path and not vision_language
        ):
            assert isinstance(output, str) and len(output) > 0

        console.print(f"[bold green]✓[/] {test_type} generation successful")
        return False
    except Exception as e:
        console.print(f"[bold red]✗[/] {test_type} generation failed: {str(e)}")
        return True


def main():
    args = parse_args()

    # Load models list
    with open(args.models_file, "r", encoding="utf-8") as f:
        models = [line.strip() for line in f.readlines()]

    # Test inputs dictionary
    test_inputs = {
        "image": args.image,
        "prompt": args.prompt,
        "language_only_prompt": args.language_only_prompt,
        "kwargs": {
            "temp": args.temperature,
            "max_tokens": args.max_tokens,
        },
    }

    results = []

    for model_path in tqdm(models):
        console.print(Panel(f"Testing {model_path}", style="bold blue"))

        # Run tests
        model, processor, config, error = test_model_loading(model_path)

        if not error and model:
            print("\n")
            # Test vision-language generation
            error |= test_generation(
                model, processor, config, model_path, test_inputs, vision_language=True
            )

            print("\n")

            # Clear cache and reset peak memory for next test
            mx.metal.clear_cache()
            mx.metal.reset_peak_memory()

            # Test language-only generation
            error |= test_generation(
                model, processor, config, model_path, test_inputs, vision_language=False
            )
            print("\n")

        console.print("[bold blue]Cleaning up...")
        del model, processor
        mx.metal.clear_cache()
        mx.metal.reset_peak_memory()
        console.print("[bold green]✓[/] Cleanup complete\n")
        results.append(
            f"[bold {'green' if not error else 'red'}]{'✓' if not error else '✗'}[/] {model_path}"
        )

    print("\n")
    success = all(result.startswith("[bold green]") for result in results)
    panel_style = "bold green" if success else "bold red"
    console.print(Panel("\n".join(results), title="Results", style=panel_style))
    console.print(
        f"[bold {'green' if success else 'red'}]{'All' if success else 'Some'} models tested {'successfully' if success else 'failed to test'}"
    )

    print("\n")
    device_info = get_device_info()
    console.print(
        Panel(
            title="System Information",
            renderable=textwrap.dedent(
                f"""{platform.machine() == 'arm64' and f'''
MAC OS: v{platform.mac_ver()[0]}
Python: v{sys.version.split()[0]}
MLX: v{mx.__version__}
MLX-VLM: v{__version__}
Transformers: v{transformers_version}
Hardware:
• Chip: {device_info['SPDisplaysDataType'][0]['_name']}
• RAM: {psutil.virtual_memory().total / (1024 ** 3):.1f} GB
• CPU Cores: {psutil.cpu_count(logical=False)}
• GPU Cores: {device_info['SPDisplaysDataType'][0]['sppci_cores']}
''' or 'Not running on Apple Silicon'}"""
            ),
            style="bold blue",
        )
    )


if __name__ == "__main__":
    main()
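Note that the default pytest run skips this file (see the workflow change above); it is meant to be invoked directly with a models file and at least one image, along the lines of `python mlx_vlm/tests/test_smoke.py --models-file models.txt --image example.jpg`, where `models.txt` lists one Hugging Face model path per line. The file and image names here are placeholders, not part of the patch.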
10 changes: 5 additions & 5 deletions mlx_vlm/utils.py
@@ -409,7 +409,7 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
        ```
        ```bash
-        python -m mlx_vlm.generate --model {upload_repo} --max-tokens 100 --temp 0.0
+        python -m mlx_vlm.generate --model {upload_repo} --max-tokens 100 --temp 0.0 --prompt "Describe this image." --image <path_to_image>
        ```
        """
    )
@@ -666,7 +666,6 @@ def convert(
    revision: Optional[str] = None,
    dequantize: bool = False,
    skip_vision: bool = False,
-    skip_vision_non_divisible: bool = False,
    trust_remote_code: bool = True,
):
    print("[INFO] Loading")
@@ -686,7 +685,7 @@
        print("[INFO] Quantizing")
        model.load_weights(list(weights.items()))
        weights, config = quantize_model(
-            model, config, q_group_size, q_bits, skip_vision, skip_vision_non_divisible
+            model, config, q_group_size, q_bits, skip_vision
        )

    if dequantize:
@@ -838,6 +837,7 @@ def generate_step(
    model: nn.Module,
    pixel_values,
    mask,
+    *,
    max_tokens: int = 256,
    temp: float = 0.0,
    repetition_penalty: Optional[float] = None,
@@ -1009,19 +1009,19 @@ def stream_generate(
    if not image:
        input_ids = prompt_tokens[None, :]
        pixel_values = mask = None
-        kwargs = {}
    else:
        inputs = prepare_inputs(
            processor, image, prompt, image_token_index, resize_shape
        )
        input_ids = inputs["input_ids"]
        pixel_values = inputs["pixel_values"]
        mask = inputs["attention_mask"]
-        kwargs = {
+        data_kwargs = {
            k: v
            for k, v in inputs.items()
            if k not in ["input_ids", "pixel_values", "attention_mask"]
        }
+        kwargs.update(data_kwargs)

    detokenizer = processor.detokenizer
    detokenizer.reset()
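Two related cleanups sit in this file: `generate_step` gains a bare `*` so its sampling options are keyword-only, and `stream_generate` now merges the processor-derived inputs into the existing `kwargs` (via `data_kwargs`) rather than rebinding `kwargs`, apparently so that caller-supplied generation options such as `temp` and `max_tokens` are no longer clobbered. A toy illustration of the keyword-only change (not the real signature):

```python
def step(model, pixel_values, mask, *, max_tokens=256, temp=0.0):
    # Everything after the bare `*` must be passed by name.
    return max_tokens, temp

print(step("model", None, None, max_tokens=100, temp=0.7))  # (100, 0.7)
# step("model", None, None, 100, 0.7)  # TypeError: takes 3 positional arguments but 5 were given
```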
2 changes: 1 addition & 1 deletion mlx_vlm/version.py
@@ -1 +1 @@
-__version__ = "0.1.9"
+__version__ = "0.1.10"
