Merge pull request #1 from Passiolife/passiofy
Passiofy
kells1986 authored Jun 25, 2023
2 parents 81b634d + c84016f commit 3856342
Showing 14 changed files with 713 additions and 87 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/test.yml
@@ -0,0 +1,27 @@
name: minLoRAplus Workflow

on:
  push:

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10", "3.11"]

    steps:
    - uses: actions/checkout@v3
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pytest
        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
    - name: Test with pytest
      run: |
        pytest
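The workflow simply runs `pytest` from the repository root, so any standard test module is picked up. For orientation, a minimal test of the kind this step would execute might look like the sketch below (illustrative only, not a file added by this commit; it assumes `add_lora` and `remove_lora` behave as in upstream minLoRA, where LoRA starts as a zero update):

```python
# test_lora_roundtrip.py -- illustrative sketch, not part of this commit
import torch
from torch import nn

from minloraplus import add_lora, remove_lora


def test_add_and_remove_lora_preserves_output():
    torch.manual_seed(0)
    model = nn.Linear(in_features=5, out_features=3)
    x = torch.randn(2, 5)
    y_before = model(x)

    add_lora(model)            # LoRA's B matrix starts at zero, so the output is unchanged
    y_with_lora = model(x)

    remove_lora(model)         # stripping LoRA should restore the plain Linear layer
    y_after = model(x)

    assert torch.allclose(y_before, y_with_lora, atol=1e-6)
    assert torch.allclose(y_before, y_after, atol=1e-6)
```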
46 changes: 33 additions & 13 deletions README.md
@@ -1,43 +1,63 @@
# minLoRA
# minLoRAplus

A fork of the excellent [minLoRA](https://github.com/cccntu/minLoRA) repo by [cccntu](https://github.com/cccntu), with functionality added for Passio use cases.

A minimal, but versatile PyTorch re-implementation of [LoRA](https://github.com/microsoft/LoRA). In only ~100 lines of code, minLoRA supports the following features:
A minimal, but versatile PyTorch re-implementation of [LoRA](https://github.com/microsoft/LoRA).

### Features
In only ~100 lines of code, minLoRA supports the following features:

- Functional, no need to modify the model definition
- Works everywhere, as long as you use `torch.nn.Module`
- PyTorch native, uses PyTorch's `torch.nn.utils.parametrize` to do all the heavy lifting (see the sketch below)
- Easily extendable, you can add your own LoRA parameterization
- Supports training, inference, and inference with multiple LoRA models

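The "PyTorch native" point is the core trick: LoRA is attached via `torch.nn.utils.parametrize`, so the model definition never changes. A stripped-down sketch of that mechanism (the `TinyLoRA` class, rank, and init scale below are illustrative stand-ins, not the library's `LoRAParametrization`):

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class TinyLoRA(nn.Module):
    """Illustrative low-rank update: weight -> weight + B @ A, with B initialised to zero."""

    def __init__(self, fan_out: int, fan_in: int, rank: int = 2):
        super().__init__()
        self.A = nn.Parameter(torch.randn(rank, fan_in) * 0.01)
        self.B = nn.Parameter(torch.zeros(fan_out, rank))

    def forward(self, weight):
        return weight + self.B @ self.A


layer = nn.Linear(5, 3)
parametrize.register_parametrization(layer, "weight", TinyLoRA(fan_out=3, fan_in=5))
# layer.weight is now recomputed on access as original_weight + B @ A;
# training only A and B is the LoRA recipe.
print(layer.weight.shape)  # torch.Size([3, 5])
```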
## Demo
### Plus:
- Finetune any [timm](https://github.com/huggingface/pytorch-image-models) model using LoRA (see the sketch below):
  - `ViTClassifier` module to train any `timm` model with a transformer architecture.
  - `CNNClassifier` module to train any `timm` CNN.

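The exact constructor arguments of `ViTClassifier` / `CNNClassifier` live in the new `minloraplus` modules in this commit; independent of those wrappers, the underlying pattern is just minLoRA applied to a `timm` backbone. A minimal sketch under that assumption (model name, class count, and learning rate are placeholders):

```python
import timm
import torch

from minloraplus import add_lora, get_lora_params

# Placeholder backbone; pretrained=True will download weights on first use.
model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=10)

add_lora(model)  # inject LoRA parametrizations into the supported (linear) layers

# Only the LoRA parameters are optimised; the pretrained weights are left untouched.
optimizer = torch.optim.AdamW(get_lora_params(model), lr=1e-3)

x = torch.randn(1, 3, 224, 224)
logits = model(x)  # the forward pass is unchanged
```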

## Basic Usage

- `demo.ipynb` shows the basic usage of the library
- `advanced_usage.ipynb` shows how you can add LoRA to other layers such as embedding, and how to tie weights

## Examples

- Finetuning GPT using LoRA + nanoGPT: https://github.com/cccntu/LoRAnanoGPT/pull/1/files
- Sample training code for `timm` ViT: `vit_trainer.ipynb`
- Sample training code for `timm` CNN: `cnn_trainer.ipynb`

## Running Trainers

If you just want to run the notebooks, make sure the requirements are installed first:

```
pip install -r requirements.txt
```

Then run `jupyter notebook`.

## Library Installation

If you want to `import minlora` into your project:
If you want to `import minloraplus` into your project:

```
git clone https://github.com/cccntu/minLoRA.git
cd minLoRA
git clone https://github.com/Passiolife/minLoRAplus.git
cd minLoRAplus
pip install -e .
```

## Usage

```python
import torch
from minlora import add_lora, apply_to_lora, disable_lora, enable_lora, get_lora_params, merge_lora, name_is_lora, remove_lora, load_multiple_lora, select_lora
from minloraplus import (
    add_lora, apply_to_lora, disable_lora, enable_lora, get_lora_params,
    merge_lora, name_is_lora, remove_lora, load_multiple_lora, select_lora,
)
```

### Training a model with minLoRA
### Training a model with minLoRAplus

```python
model = torch.nn.Linear(in_features=5, out_features=3)
@@ -58,7 +78,7 @@ optimizer = torch.optim.AdamW(parameters, lr=1e-3)
lora_state_dict = get_lora_state_dict(model)
```
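The middle of the block above is collapsed in the diff; a sketch of the full flow it describes, reconstructed from the upstream minLoRA README rather than copied verbatim from this commit:

```python
import torch

from minloraplus import add_lora, get_lora_params, get_lora_state_dict

model = torch.nn.Linear(in_features=5, out_features=3)

# Step 1: add LoRA parametrizations to the model
add_lora(model)

# Step 2: hand only the LoRA parameters to the optimizer
parameters = [{"params": list(get_lora_params(model))}]
optimizer = torch.optim.AdamW(parameters, lr=1e-3)

# Step 3: train as usual (one dummy step shown here)
x, y = torch.randn(8, 5), torch.randn(8, 3)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()

# Step 4: export only the LoRA weights
lora_state_dict = get_lora_state_dict(model)
```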

### Loading and Inferencing with minLoRA
### Loading and Inferencing with minLoRAplus

```python
# Step 1: Add LoRA to your model
@@ -96,8 +116,8 @@ Y2 = select_lora(model, 2)(x)
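This block is likewise collapsed; the shape of the workflow it covers, reconstructed from the functions imported above and the upstream minLoRA demo (so the exact call order is an assumption, not this commit's text):

```python
import torch

from minloraplus import add_lora, get_lora_state_dict, load_multiple_lora, select_lora

x = torch.randn(4, 5)
model = torch.nn.Linear(in_features=5, out_features=3)

# Stand-ins for adapters produced by separate training runs.
add_lora(model)
lora_state_dicts = [get_lora_state_dict(model) for _ in range(3)]

# Load all adapters at once, then pick one per forward pass (mirroring the upstream demo).
_ = load_multiple_lora(model, lora_state_dicts)
Y0 = select_lora(model, 0)(x)
Y2 = select_lora(model, 2)(x)
```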

- [microsoft/LoRA](https://github.com/microsoft/LoRA) has the official implementation of LoRA, in PyTorch
- [karpathy/minGPT](https://github.com/karpathy/minGPT) the structure of the repo is adapted from minGPT

- [cccntu/minLoRA](https://github.com/cccntu/minLoRA) the original repo minLoRAplus is based on

### TODO
- [x] A notebook to show how to configure LoRA parameters
- [x] Real training & inference examples
- [ ] Add pytorch-lightning training example for GPT
- [ ] Add conversion functions so vision models can be run in the Passio MindsEye app
103 changes: 71 additions & 32 deletions advanced_usage.ipynb
@@ -4,18 +4,24 @@
"cell_type": "code",
"execution_count": 1,
"id": "92a4ce86",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-22T10:08:07.016198Z",
"start_time": "2023-06-22T10:08:06.472063Z"
}
},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"import torch\n",
"from minlora import (\n",
"from minloraplus import (\n",
" LoRAParametrization,\n",
" add_lora,\n",
" apply_to_lora,\n",
" merge_lora,\n",
" apply_to_lora\n",
")\n",
"\n",
"from torch import nn\n",
"\n",
"_ = torch.set_grad_enabled(False)"
@@ -34,30 +40,16 @@
"cell_type": "code",
"execution_count": 2,
"id": "ec04a954",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-22T10:08:08.418456Z",
"start_time": "2023-06-22T10:08:08.413440Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"Sequential(\n",
" (0): ParametrizedEmbedding(\n",
" 3, 2\n",
" (parametrizations): ModuleDict(\n",
" (weight): ParametrizationList(\n",
" (0): LoRAParametrization()\n",
" )\n",
" )\n",
" )\n",
" (1): ParametrizedLinear(\n",
" in_features=2, out_features=3, bias=True\n",
" (parametrizations): ModuleDict(\n",
" (weight): ParametrizationList(\n",
" (0): LoRAParametrization()\n",
" )\n",
" )\n",
" )\n",
")"
]
"text/plain": "Sequential(\n (0): ParametrizedEmbedding(\n 3, 2\n (parametrizations): ModuleDict(\n (weight): ParametrizationList(\n (0): LoRAParametrization()\n )\n )\n )\n (1): ParametrizedLinear(\n in_features=2, out_features=3, bias=True\n (parametrizations): ModuleDict(\n (weight): ParametrizationList(\n (0): LoRAParametrization()\n )\n )\n )\n)"
},
"execution_count": 2,
"metadata": {},
@@ -99,7 +91,12 @@
"cell_type": "code",
"execution_count": 3,
"id": "5b649fe9",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-22T10:08:10.284568Z",
"start_time": "2023-06-22T10:08:10.282754Z"
}
},
"outputs": [
{
"name": "stdout",
@@ -123,7 +120,12 @@
"cell_type": "code",
"execution_count": 4,
"id": "c7d00069",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-22T10:08:10.716836Z",
"start_time": "2023-06-22T10:08:10.712183Z"
}
},
"outputs": [
{
"name": "stdout",
@@ -146,7 +148,12 @@
"cell_type": "code",
"execution_count": 5,
"id": "3e6e9bfe",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-22T10:08:11.166360Z",
"start_time": "2023-06-22T10:08:11.161815Z"
}
},
"outputs": [
{
"name": "stdout",
@@ -166,7 +173,12 @@
"cell_type": "code",
"execution_count": 6,
"id": "d02d2819",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-22T10:08:11.639767Z",
"start_time": "2023-06-22T10:08:11.633738Z"
}
},
"outputs": [],
"source": [
"# to tie the weights, we need to add lora to the embedding layer as well\n",
@@ -189,7 +201,11 @@
"execution_count": 7,
"id": "f34b1f1d",
"metadata": {
"lines_to_next_cell": 2
"lines_to_next_cell": 2,
"ExecuteTime": {
"end_time": "2023-06-22T10:08:12.113884Z",
"start_time": "2023-06-22T10:08:12.110388Z"
}
},
"outputs": [
{
@@ -214,7 +230,12 @@
"cell_type": "code",
"execution_count": 8,
"id": "34ada79d",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-22T10:08:12.539045Z",
"start_time": "2023-06-22T10:08:12.533231Z"
}
},
"outputs": [
{
"name": "stdout",
@@ -242,7 +263,12 @@
"cell_type": "code",
"execution_count": 9,
"id": "51c20159",
"metadata": {},
"metadata": {
"ExecuteTime": {
"end_time": "2023-06-22T10:08:12.965236Z",
"start_time": "2023-06-22T10:08:12.961597Z"
}
},
"outputs": [],
"source": [
"# we can put the logic of tying the weights in a function\n",
@@ -261,7 +287,11 @@
"execution_count": 10,
"id": "9d5cc0b4",
"metadata": {
"lines_to_next_cell": 0
"lines_to_next_cell": 0,
"ExecuteTime": {
"end_time": "2023-06-22T10:08:13.531372Z",
"start_time": "2023-06-22T10:08:13.528493Z"
}
},
"outputs": [],
"source": [
@@ -275,6 +305,15 @@
"# even after merging lora, the weights are still the same\n",
"assert torch.allclose(model[0].weight, model[1].weight)"
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false
}
}
],
"metadata": {
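The remainder of the notebook diff is collapsed above. Its subject is adding LoRA to embedding layers and then tying them to a linear head; a condensed sketch of the first half of that idea, assuming the config-dict interface carried over from upstream minLoRA (the rank value is arbitrary):

```python
from functools import partial

import torch
from torch import nn

from minloraplus import LoRAParametrization, add_lora, merge_lora

_ = torch.set_grad_enabled(False)

# Same toy model as in the notebook: an embedding feeding a linear layer.
model = nn.Sequential(
    nn.Embedding(3, 2),
    nn.Linear(2, 3),
)

# By default only nn.Linear gets LoRA; a custom config extends it to nn.Embedding as well.
lora_config = {
    nn.Embedding: {"weight": partial(LoRAParametrization.from_embedding, rank=2)},
    nn.Linear: {"weight": partial(LoRAParametrization.from_linear, rank=2)},
}
add_lora(model, lora_config=lora_config)

print(model)       # both layers now show a LoRAParametrization on their weight
merge_lora(model)  # bake the (currently zero) LoRA update back into plain layers
```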