mlx-image (#1237)
* add mlx-image

* add mlx-image

* fix for mlx-image

* nit: blank line

---------

Co-authored-by: Pedro Cuenca <[email protected]>
riccardomusmeci and pcuenca authored Mar 14, 2024
1 parent f49e966 commit 16a782b
Showing 2 changed files with 100 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/hub/_toctree.yml
```diff
@@ -77,6 +77,8 @@
   title: Keras
 - local: ml-agents
   title: ML-Agents
+- local: mlx-image
+  title: mlxim
 - local: mlx
   title: MLX
 - local: open_clip
```
98 changes: 98 additions & 0 deletions docs/hub/mlx-image.md
@@ -0,0 +1,98 @@
# Using mlxim at Hugging Face

[`mlxim`](https://github.com/riccardomusmeci/mlx-image) is an image-models library built on Apple's [MLX](https://github.com/ml-explore/mlx). It aims to replicate Ross Wightman's great [timm](https://github.com/huggingface/pytorch-image-models) library, but for MLX models.


## Exploring MLX on the Hub

You can find `mlxim` models by filtering on the `mlxim` library name, as in [this query](https://huggingface.co/models?library=mlxim&sort=trending).
There's also an open [mlx-vision](https://huggingface.co/mlx-vision) community for contributors who convert and publish weights in MLX format.

Thanks to the MLX integration with the Hugging Face Hub, you can load MLX models with a few lines of code.
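
You can also query the Hub programmatically with `huggingface_hub`. The snippet below is a minimal sketch; it assumes your installed `huggingface_hub` version supports the `library` filter on `list_models`:

```python
from huggingface_hub import HfApi

api = HfApi()
# List models on the Hub tagged with the mlxim library
for model in api.list_models(library="mlxim", limit=10):
    print(model.id)
```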

## Installation

```bash
pip install mlx-image
```
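
As a quick smoke test after installing, you can import the library and print the available models (mirroring the `list_models` usage shown below):

```bash
python -c "from mlxim.model import list_models; list_models()"
```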

## Models

Model weights are available in the [`mlx-vision`](https://huggingface.co/mlx-vision) community on the Hugging Face Hub.

To load a model with pre-trained weights:

```python
from mlxim.model import create_model

# load weights from the Hugging Face Hub (https://huggingface.co/mlx-vision/resnet18-mlxim)
model = create_model("resnet18")  # pretrained weights downloaded from HF

# load weights from a local file
model = create_model("resnet18", weights="path/to/resnet18/model.safetensors")
```
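
Once loaded, inference is plain MLX code. The sketch below does ImageNet-style preprocessing by hand so it stays self-contained; the file name `cat.png` and the normalization constants are illustrative, not part of the `mlxim` API:

```python
import mlx.core as mx
import numpy as np
from PIL import Image
from mlxim.model import create_model

model = create_model("resnet18")
model.eval()

# Hand-rolled ImageNet-style preprocessing (illustrative values)
img = Image.open("cat.png").convert("RGB").resize((224, 224))
x = np.asarray(img, dtype=np.float32) / 255.0
x = (x - np.array([0.485, 0.456, 0.406], dtype=np.float32)) / np.array(
    [0.229, 0.224, 0.225], dtype=np.float32
)
x = mx.array(x)[None]  # (1, 224, 224, 3): MLX convolutions are channels-last

logits = model(x)
print(mx.argmax(logits, axis=-1)[0].item())  # predicted class index
```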

To list all available models:

```python
from mlxim.model import list_models
list_models()
```

> [!WARNING]
> As of today (2024-03-08), MLX does not support the `groups` parameter for `nn.Conv2d`. Architectures such as `resnext`, `regnet`, and `efficientnet` are therefore not yet supported in `mlxim`.

## ImageNet-1K Results

Check [results-imagenet-1k.csv](https://github.com/riccardomusmeci/mlx-image/blob/main/results/results-imagenet-1k.csv) for every model converted to `mlxim` and its performance on ImageNet-1K under different settings.

> **TL;DR:** performance is comparable to that of the original PyTorch implementations.

## Similarity to PyTorch and other familiar tools

`mlxim` tries to stay as close as possible to PyTorch (see the sketch after this list):
- `DataLoader` -> you can define your own `collate_fn` and use `num_workers` to speed up data loading
- `Dataset` -> `mlxim` already supports `LabelFolderDataset` (the good old PyTorch `ImageFolder`) and `FolderDataset` (a generic folder of images)
- `ModelCheckpoint` -> keeps track of the best model and saves it to disk (similar to PyTorch Lightning). It also suggests early stopping
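
A minimal sketch of the data utilities; argument names such as `root_dir` and `collate_fn` are assumed by analogy with the training example below:

```python
from mlxim.data import FolderDataset, DataLoader

# A generic folder of images, no labels required (path is illustrative)
dataset = FolderDataset(root_dir="path/to/images")

def collate_fn(batch):
    # Customize how samples are combined into a batch; identity here
    return batch

loader = DataLoader(
    dataset=dataset,
    batch_size=16,
    num_workers=4,
    collate_fn=collate_fn,
)
```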

## Training

Training is similar to PyTorch. Here's an example of how to train a model:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlxim.model import create_model
from mlxim.data import LabelFolderDataset, DataLoader

train_dataset = LabelFolderDataset(
    root_dir="path/to/train",
    class_map={0: "class_0", 1: "class_1", 2: ["class_2", "class_3"]}
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4
)
model = create_model("resnet18")  # pretrained weights loaded from HF
optimizer = optim.Adam(learning_rate=1e-3)

def train_step(model, inputs, targets):
    logits = model(inputs)
    loss = mx.mean(nn.losses.cross_entropy(logits, targets))
    return loss

# wrap train_step so it also returns gradients
# with respect to the model's trainable parameters
train_step_fn = nn.value_and_grad(model, train_step)

model.train()
for epoch in range(10):
    for batch in train_loader:
        x, target = batch
        loss, grads = train_step_fn(model, x, target)
        optimizer.update(model, grads)
        # force evaluation of the lazily computed parameter updates
        mx.eval(model.state, optimizer.state)
```
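
After training, a validation pass is symmetric. A hedged sketch follows; `val_loader` is assumed to be a `DataLoader` over a validation `LabelFolderDataset` built as above:

```python
model.eval()
correct, total = 0, 0
for x, target in val_loader:
    logits = model(x)
    pred = mx.argmax(logits, axis=-1)
    correct += mx.sum(pred == target).item()
    total += target.shape[0]
print(f"accuracy: {correct / total:.3f}")
```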

## Additional Resources

* [mlxim repository](https://github.com/riccardomusmeci/mlx-image)
* [All mlxim models on the Hub](https://huggingface.co/models?library=mlxim&sort=trending)
