From 593a5b84bf732d4c9769bc5000c1a6ae217c60b7 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 3 Jan 2025 16:02:21 +0100 Subject: [PATCH] Patch utils and models (#167) * allow single value resize shape * fix generate step kwargs * fix dolphin vision * fix LM only for DS-VL1 * fix quantize_model * remove assert for DS-VL2 * fix paligemma LM only * add smoke test * update model card * add sys info * skip test_smoke file * add transformers * bump version --- .github/workflows/tests.yml | 2 +- mlx_vlm/generate.py | 14 +- .../models/deepseek_vl_v2/deepseek_vl_v2.py | 2 +- mlx_vlm/models/llava_bunny/llava_bunny.py | 10 + .../models/multi_modality/multi_modality.py | 2 +- mlx_vlm/models/paligemma/paligemma.py | 2 +- mlx_vlm/tests/test_smoke.py | 218 ++++++++++++++++++ mlx_vlm/utils.py | 10 +- mlx_vlm/version.py | 2 +- 9 files changed, 247 insertions(+), 15 deletions(-) create mode 100644 mlx_vlm/tests/test_smoke.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 41bc0a6..e1e4ce9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -39,4 +39,4 @@ jobs: - name: Run Python tests run: | cd mlx_vlm/ - pytest -s ./tests + pytest -s ./tests --ignore=tests/test_smoke.py diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 549424b..16fbde4 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -39,7 +39,7 @@ def parse_arguments(): parser.add_argument( "--resize-shape", type=int, - nargs=2, + nargs="+", default=None, help="Resize shape for the image.", ) @@ -84,10 +84,14 @@ def main(): kwargs = {} if args.resize_shape is not None: - assert ( - len(args.resize_shape) == 2 - ), "Resize shape must be a tuple of two integers" - kwargs["resize_shape"] = args.resize_shape + resize_shape = args.resize_shape + if len(resize_shape) not in [1, 2]: + raise ValueError("Resize shape must be 1 or 2 integers") + kwargs["resize_shape"] = ( + (resize_shape[0], resize_shape[0]) + if len(resize_shape) == 1 + else resize_shape + ) output = generate( model, diff --git a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py index 1796083..3629bcc 100644 --- a/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py +++ b/mlx_vlm/models/deepseek_vl_v2/deepseek_vl_v2.py @@ -401,7 +401,7 @@ def get_input_embeddings( total_tiles.append(pixel_values[idx, : batch_num_tiles[idx]]) total_tiles = mx.concatenate(total_tiles, axis=0) - assert total_tiles.shape[0] == sum(batch_num_tiles) + if total_tiles.shape[0] == 0: return self.language_model.model.embed_tokens(input_ids) diff --git a/mlx_vlm/models/llava_bunny/llava_bunny.py b/mlx_vlm/models/llava_bunny/llava_bunny.py index a9c6ad0..730d4dc 100644 --- a/mlx_vlm/models/llava_bunny/llava_bunny.py +++ b/mlx_vlm/models/llava_bunny/llava_bunny.py @@ -42,6 +42,16 @@ class ModelConfig: @classmethod def from_dict(cls, params): + if not params.get("text_config", {}): + # Copy text config parameters from root level + excluded_keys = {"vision_config"} + params["text_config"] = dict( + filter(lambda x: x[0] not in excluded_keys, params.items()) + ) + if not params.get("vision_config", {}).get("model_type", {}): + # Set default model type + params["vision_config"]["model_type"] = "siglip_vision_model" + return cls( **{ k: v diff --git a/mlx_vlm/models/multi_modality/multi_modality.py b/mlx_vlm/models/multi_modality/multi_modality.py index 130edf6..d512abc 100644 --- a/mlx_vlm/models/multi_modality/multi_modality.py +++ b/mlx_vlm/models/multi_modality/multi_modality.py @@ -303,7 +303,7 @@ 
def get_input_embeddings( pixel_values: Optional[mx.array] = None, ): if pixel_values is None: - return self.language_model(input_ids) + return self.language_model.model.embed_tokens(input_ids) image_token_index = self.config.image_token_index num_image_tokens = self.config.num_image_tokens diff --git a/mlx_vlm/models/paligemma/paligemma.py b/mlx_vlm/models/paligemma/paligemma.py index 7125940..b8179b9 100644 --- a/mlx_vlm/models/paligemma/paligemma.py +++ b/mlx_vlm/models/paligemma/paligemma.py @@ -66,7 +66,7 @@ def get_input_embeddings( mask: Optional[mx.array] = None, ): if pixel_values is None: - return self.language_model(input_ids) + return self.language_model.model.embed_tokens(input_ids), None inputs_embeds = self.language_model.model.embed_tokens(input_ids) diff --git a/mlx_vlm/tests/test_smoke.py b/mlx_vlm/tests/test_smoke.py new file mode 100644 index 0000000..f74d894 --- /dev/null +++ b/mlx_vlm/tests/test_smoke.py @@ -0,0 +1,218 @@ +import argparse +import json +import platform +import subprocess +import sys +import textwrap + +import mlx.core as mx +import psutil +from rich.console import Console +from rich.panel import Panel +from tqdm import tqdm +from transformers import __version__ as transformers_version + +from mlx_vlm import generate, load +from mlx_vlm.prompt_utils import apply_chat_template +from mlx_vlm.utils import load_config +from mlx_vlm.version import __version__ + +# Initialize console +console = Console() + + +def parse_args(): + parser = argparse.ArgumentParser(description="Test MLX-VLM models") + parser.add_argument( + "--models-file", + type=str, + required=True, + help="Path to file containing model paths, one per line", + ) + parser.add_argument( + "--image", type=str, nargs="+", required=True, help="Path(s) to test image(s)" + ) + parser.add_argument( + "--prompt", + type=str, + default="Describe this image.", + help="Vision-language prompt to test", + ) + parser.add_argument( + "--language-only-prompt", + type=str, + default="Hi, how are you?", + help="Language-only prompt to test", + ) + parser.add_argument( + "--temperature", type=float, default=0.7, help="Sampling temperature" + ) + parser.add_argument( + "--max-tokens", type=int, default=100, help="Maximum tokens to generate" + ) + return parser.parse_args() + + +def get_device_info(): + # Disable tokenizers parallelism to avoid deadlocks after forking + import os + + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + try: + data = subprocess.check_output( + ["system_profiler", "SPDisplaysDataType", "-json"], text=True + ) + device_info = json.loads(data) + return device_info + except Exception as e: + print(f"Could not retrieve GPU information: {e}") + return None + + +def test_model_loading(model_path): + try: + console.print("[bold green]Loading model...") + model, processor = load(model_path, trust_remote_code=True) + config = load_config(model_path, trust_remote_code=True) + console.print("[bold green]✓[/] Model loaded successfully") + return model, processor, config, False + except Exception as e: + console.print(f"[bold red]✗[/] Failed to load model: {str(e)}") + return None, None, None, True + + +def test_generation( + model, processor, config, model_path, test_inputs, vision_language=True +): + try: + test_type = "vision-language" if vision_language else "language-only" + console.print(f"[bold yellow]Testing {test_type} generation...") + + prompt = ( + test_inputs["prompt"] + if vision_language + else test_inputs["language_only_prompt"] + ) + num_images = len(test_inputs["image"]) if 
vision_language else 0 + + formatted_prompt = apply_chat_template( + processor, config, prompt, num_images=num_images + ) + + generate_args = { + "model": model, + "processor": processor, + "prompt": formatted_prompt, + "verbose": True, + **test_inputs["kwargs"], + } + if vision_language: + generate_args["image"] = test_inputs["image"] + + output = generate(**generate_args) + + # Deepseek-vl2-tiny outputs are empty on VLM generation + # Paligemma outputs are empty on language-only generation + # So we skip the assertion for these models + if ("deepseek-vl2-tiny" not in model_path and vision_language) or ( + "paligemma" not in model_path and not vision_language + ): + assert isinstance(output, str) and len(output) > 0 + + console.print(f"[bold green]✓[/] {test_type} generation successful") + return False + except Exception as e: + console.print(f"[bold red]✗[/] {test_type} generation failed: {str(e)}") + return True + + +def main(): + args = parse_args() + + # Load models list + with open(args.models_file, "r", encoding="utf-8") as f: + models = [line.strip() for line in f.readlines()] + + # Test inputs dictionary + test_inputs = { + "image": args.image, + "prompt": args.prompt, + "language_only_prompt": args.language_only_prompt, + "kwargs": { + "temp": args.temperature, + "max_tokens": args.max_tokens, + }, + } + + results = [] + + for model_path in tqdm(models): + console.print(Panel(f"Testing {model_path}", style="bold blue")) + + # Run tests + model, processor, config, error = test_model_loading(model_path) + + if not error and model: + print("\n") + # Test vision-language generation + error |= test_generation( + model, processor, config, model_path, test_inputs, vision_language=True + ) + + print("\n") + + # Clear cache and reset peak memory for next test + mx.metal.clear_cache() + mx.metal.reset_peak_memory() + + # Test language-only generation + error |= test_generation( + model, processor, config, model_path, test_inputs, vision_language=False + ) + print("\n") + + console.print("[bold blue]Cleaning up...") + del model, processor + mx.metal.clear_cache() + mx.metal.reset_peak_memory() + console.print("[bold green]✓[/] Cleanup complete\n") + results.append( + f"[bold {'green' if not error else 'red'}]{'✓' if not error else '✗'}[/] {model_path}" + ) + + print("\n") + success = all(result.startswith("[bold green]") for result in results) + panel_style = "bold green" if success else "bold red" + console.print(Panel("\n".join(results), title="Results", style=panel_style)) + console.print( + f"[bold {'green' if success else 'red'}]{'All' if success else 'Some'} models tested {'successfully' if success else 'failed to test'}" + ) + + print("\n") + device_info = get_device_info() + console.print( + Panel( + title="System Information", + renderable=textwrap.dedent( + f"""{platform.machine() == 'arm64' and f''' + MAC OS: v{platform.mac_ver()[0]} + Python: v{sys.version.split()[0]} + MLX: v{mx.__version__} + MLX-VLM: v{__version__} + Transformers: v{transformers_version} + + Hardware: + • Chip: {device_info['SPDisplaysDataType'][0]['_name']} + • RAM: {psutil.virtual_memory().total / (1024 ** 3):.1f} GB + • CPU Cores: {psutil.cpu_count(logical=False)} + • GPU Cores: {device_info['SPDisplaysDataType'][0]['sppci_cores']} + ''' or 'Not running on Apple Silicon'}""" + ), + style="bold blue", + ) + ) + + +if __name__ == "__main__": + main() diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index ffc71e4..0e262f5 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -409,7 +409,7 @@ def 
upload_to_hub(path: str, upload_repo: str, hf_path: str): ``` ```bash - python -m mlx_vlm.generate --model {upload_repo} --max-tokens 100 --temp 0.0 + python -m mlx_vlm.generate --model {upload_repo} --max-tokens 100 --temp 0.0 --prompt "Describe this image." --image ``` """ ) @@ -666,7 +666,6 @@ def convert( revision: Optional[str] = None, dequantize: bool = False, skip_vision: bool = False, - skip_vision_non_divisible: bool = False, trust_remote_code: bool = True, ): print("[INFO] Loading") @@ -686,7 +685,7 @@ def convert( print("[INFO] Quantizing") model.load_weights(list(weights.items())) weights, config = quantize_model( - model, config, q_group_size, q_bits, skip_vision, skip_vision_non_divisible + model, config, q_group_size, q_bits, skip_vision ) if dequantize: @@ -838,6 +837,7 @@ def generate_step( model: nn.Module, pixel_values, mask, + *, max_tokens: int = 256, temp: float = 0.0, repetition_penalty: Optional[float] = None, @@ -1009,7 +1009,6 @@ def stream_generate( if not image: input_ids = prompt_tokens[None, :] pixel_values = mask = None - kwargs = {} else: inputs = prepare_inputs( processor, image, prompt, image_token_index, resize_shape @@ -1017,11 +1016,12 @@ def stream_generate( input_ids = inputs["input_ids"] pixel_values = inputs["pixel_values"] mask = inputs["attention_mask"] - kwargs = { + data_kwargs = { k: v for k, v in inputs.items() if k not in ["input_ids", "pixel_values", "attention_mask"] } + kwargs.update(data_kwargs) detokenizer = processor.detokenizer detokenizer.reset() diff --git a/mlx_vlm/version.py b/mlx_vlm/version.py index c11f861..569b121 100644 --- a/mlx_vlm/version.py +++ b/mlx_vlm/version.py @@ -1 +1 @@ -__version__ = "0.1.9" +__version__ = "0.1.10"
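
A hedged usage sketch for the `--resize-shape` change in `generate.py` above: a single integer is now expanded to a square, so `--resize-shape 448` behaves like `--resize-shape 448 448`. The Python call below assumes `generate` forwards `resize_shape`, `temp`, and `max_tokens` the same way the CLI and `test_smoke.py` do; the checkpoint name and image path are placeholders, not part of the patch.

```python
# Sketch only: exercises the single-value resize-shape behavior added above.
# The repo id and image path are placeholders; substitute your own.
from mlx_vlm import generate, load
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config

model_path = "mlx-community/Qwen2-VL-2B-Instruct-4bit"  # placeholder checkpoint
model, processor = load(model_path)
config = load_config(model_path)

prompt = apply_chat_template(
    processor, config, "Describe this image.", num_images=1
)

# `--resize-shape 512` on the CLI is now normalized to (512, 512);
# passing the tuple directly mirrors that normalization.
output = generate(
    model=model,
    processor=processor,
    prompt=prompt,
    image=["test_image.jpg"],  # placeholder path
    resize_shape=(512, 512),
    max_tokens=100,
    temp=0.0,
    verbose=True,
)
```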
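
The `llava_bunny` hunk adds a fallback for checkpoints whose config keeps the language-model fields at the root level rather than under `text_config`, and defaults the vision tower to `siglip_vision_model`. A self-contained toy illustration of that fallback logic; the dictionary below is invented for demonstration and does not come from a real config:

```python
# Toy config: no "text_config" key and an empty "vision_config".
params = {
    "model_type": "bunny-llama",
    "hidden_size": 4096,
    "vocab_size": 32000,
    "vision_config": {},
}

# Mirror of the fallback added to ModelConfig.from_dict:
if not params.get("text_config", {}):
    # Copy the root-level parameters (minus vision_config) into text_config.
    excluded_keys = {"vision_config"}
    params["text_config"] = {
        k: v for k, v in params.items() if k not in excluded_keys
    }
if not params.get("vision_config", {}).get("model_type", {}):
    params["vision_config"]["model_type"] = "siglip_vision_model"

print(params["text_config"])    # root-level fields reused as the text config
print(params["vision_config"])  # {'model_type': 'siglip_vision_model'}
```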
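
The bare `*` inserted into `generate_step` in `utils.py` makes `max_tokens`, `temp`, `repetition_penalty`, and the remaining sampling options keyword-only. A fragment illustrating the calling convention this implies; the leading `input_ids` argument name and the `(token, logprobs)` yield shape are assumptions inferred from the surrounding `stream_generate` code, not spelled out in the hunk:

```python
# Fragment only: input_ids, model, pixel_values, and mask are assumed to come
# from prepare_inputs()/load() as in stream_generate above.
# Positional sampling options, e.g.
#     generate_step(input_ids, model, pixel_values, mask, 100, 0.7)
# now raise a TypeError; after this patch they must be passed as keywords:
for token, logprobs in generate_step(
    input_ids,
    model,
    pixel_values,
    mask,
    max_tokens=100,
    temp=0.7,
):
    ...
```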