Skip to content

Commit

Permalink
draft for picture description models
Browse files Browse the repository at this point in the history
Signed-off-by: Michele Dolfi <[email protected]>
  • Loading branch information
dolfim-ibm committed Nov 6, 2024
1 parent 6c22cba commit e1cba8a
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 2 deletions.
48 changes: 46 additions & 2 deletions docling/datamodel/pipeline_options.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from enum import Enum
from pathlib import Path
from typing import List, Literal, Optional, Union
from typing import Annotated, Any, Dict, List, Literal, Optional, Union

from pydantic import BaseModel, ConfigDict, Field
from pydantic import AnyUrl, BaseModel, ConfigDict, Field


class TableFormerMode(str, Enum):
Expand Down Expand Up @@ -61,6 +61,46 @@ class TesseractOcrOptions(OcrOptions):
)


class PicDescBaseOptions(BaseModel):
kind: str
batch_size: int = 8
scale: float = 2

bitmap_area_threshold: float = (
0.2 # percentage of the area for a bitmap to processed with the models
)


class PicDescApiOptions(PicDescBaseOptions):
kind: Literal["api"] = "api"

url: AnyUrl = AnyUrl("")
headers: Dict[str, str] = {}
params: Dict[str, Any] = {}
timeout: float = 20

llm_prompt: str = ""
provenance: str = ""


class PicDescVllmOptions(PicDescBaseOptions):
kind: Literal["vllm"] = "vllm"

# For more example parameters see https://docs.vllm.ai/en/latest/getting_started/examples/offline_inference_vision_language.html

# Parameters for LLaVA-1.6/LLaVA-NeXT
llm_name: str = "llava-hf/llava-v1.6-mistral-7b-hf"
llm_prompt: str = "[INST] <image>\nDescribe the image in details. [/INST]"
llm_extra: Dict[str, Any] = dict(max_model_len=8192)

# Parameters for Phi-3-Vision
# llm_name: str = "microsoft/Phi-3-vision-128k-instruct"
# llm_prompt: str = "<|user|>\n<|image_1|>\nDescribe the image in details.<|end|>\n<|assistant|>\n"
# llm_extra: Dict[str, Any] = dict(max_num_seqs=5, trust_remote_code=True)

sampling_params: Dict[str, Any] = dict(max_tokens=64, seed=42)


class PipelineOptions(BaseModel):
create_legacy_output: bool = (
True # This defautl will be set to False on a future version of docling
Expand All @@ -71,11 +111,15 @@ class PdfPipelineOptions(PipelineOptions):
artifacts_path: Optional[Union[Path, str]] = None
do_table_structure: bool = True # True: perform table structure extraction
do_ocr: bool = True # True: perform OCR, replace programmatic PDF text
do_picture_description: bool = False

table_structure_options: TableStructureOptions = TableStructureOptions()
ocr_options: Union[EasyOcrOptions, TesseractCliOcrOptions, TesseractOcrOptions] = (
Field(EasyOcrOptions(), discriminator="kind")
)
picture_description_options: Annotated[
Union[PicDescApiOptions, PicDescVllmOptions], Field(discriminator="kind")
] = PicDescApiOptions() # TODO: needs defaults or optional

images_scale: float = 1.0
generate_page_images: bool = False
Expand Down
99 changes: 99 additions & 0 deletions docling/models/pic_description_api_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import base64
import io
import logging
from typing import List, Optional

import httpx
from docling_core.types.doc import PictureItem
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)
from pydantic import BaseModel, ConfigDict

from docling.datamodel.pipeline_options import PicDescApiOptions
from docling.models.pic_description_base_model import PictureDescriptionBaseModel

_log = logging.getLogger(__name__)


class ChatMessage(BaseModel):
role: str
content: str


class ResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: str


class ResponseUsage(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int


class ApiResponse(BaseModel):
model_config = ConfigDict(
protected_namespaces=(),
)

id: str
model: Optional[str] = None # returned bu openai
choices: List[ResponseChoice]
created: int
usage: ResponseUsage


class PictureDescriptionApiModel(PictureDescriptionBaseModel):

def __init__(self, enabled: bool, options: PicDescApiOptions):
super().__init__(enabled=enabled, options=options)
self.options: PicDescApiOptions

def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
assert picture.image is not None

img_io = io.BytesIO()
picture.image.pil_image.save(img_io, "PNG")

image_base64 = base64.b64encode(img_io.getvalue()).decode("utf-8")

messages = [
{
"role": "user",
"content": [
{
"type": "text",
"text": self.options.llm_prompt,
},
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_base64}"},
},
],
}
]

payload = {
"messages": messages,
**self.options.params,
}

r = httpx.post(
str(self.options.url),
headers=self.options.headers,
json=payload,
timeout=self.options.timeout,
)
if not r.is_success:
_log.error(f"Error calling the API. Reponse was {r.text}")
r.raise_for_status()

api_resp = ApiResponse.model_validate_json(r.text)
generated_text = api_resp.choices[0].message.content.strip()

return PictureDescriptionData(
provenance=self.options.provenance,
text=generated_text,
)
46 changes: 46 additions & 0 deletions docling/models/pic_description_base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import logging
from pathlib import Path
from typing import Any, Iterable

from docling_core.types.doc import (
DoclingDocument,
NodeItem,
PictureClassificationClass,
PictureItem,
)
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)

from docling.datamodel.pipeline_options import PicDescBaseOptions
from docling.models.base_model import BaseEnrichmentModel


class PictureDescriptionBaseModel(BaseEnrichmentModel):

def __init__(self, enabled: bool, options: PicDescBaseOptions):
self.enabled = enabled
self.options = options
self.provenance = "TODO"

def is_processable(self, doc: DoclingDocument, element: NodeItem) -> bool:
# TODO: once the image classifier is active, we can differentiate among image types
return self.enabled and isinstance(element, PictureItem)

def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
raise NotImplemented

def __call__(
self, doc: DoclingDocument, element_batch: Iterable[NodeItem]
) -> Iterable[Any]:
if not self.enabled:
return

for element in element_batch:
assert isinstance(element, PictureItem)
assert element.image is not None

annotation = self._annotate_image(element)
element.annotations.append(annotation)

yield element
59 changes: 59 additions & 0 deletions docling/models/pic_description_vllm_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import json
from typing import List

from docling_core.types.doc import PictureItem
from docling_core.types.doc.document import ( # TODO: move import to docling_core.types.doc
PictureDescriptionData,
)

from docling.datamodel.pipeline_options import PicDescVllmOptions
from docling.models.pic_description_base_model import PictureDescriptionBaseModel


class PictureDescriptionVllmModel(PictureDescriptionBaseModel):

def __init__(self, enabled: bool, options: PicDescVllmOptions):
super().__init__(enabled=enabled, options=options)
self.options: PicDescVllmOptions

if self.enabled:
raise NotImplemented

if self.enabled:
try:
from vllm import LLM, SamplingParams # type: ignore
except ImportError:
raise ImportError(
"VLLM is not installed. Please install Docling with the required extras `pip install docling[vllm]`."
)

self.sampling_params = SamplingParams(**self.options.sampling_params) # type: ignore
self.llm = LLM(model=self.options.llm_name, **self.options.llm_extra) # type: ignore

# Generate a stable hash from the extra parameters
def create_hash(t):
return ""

params_hash = create_hash(
json.dumps(self.options.llm_extra, sort_keys=True)
+ json.dumps(self.options.sampling_params, sort_keys=True)
)
self.provenance = f"{self.options.llm_name}-{params_hash[:8]}"

def _annotate_image(self, picture: PictureItem) -> PictureDescriptionData:
assert picture.image is not None

from vllm import RequestOutput

inputs = [
{
"prompt": self.options.llm_prompt,
"multi_modal_data": {"image": picture.image.pil_image},
}
]
outputs: List[RequestOutput] = self.llm.generate( # type: ignore
inputs, sampling_params=self.sampling_params # type: ignore
)

generated_text = outputs[0].outputs[0].text
return PictureDescriptionData(provenance=self.provenance, text=generated_text)
29 changes: 29 additions & 0 deletions docling/pipeline/standard_pdf_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from docling.datamodel.pipeline_options import (
EasyOcrOptions,
PdfPipelineOptions,
PicDescApiOptions,
PicDescVllmOptions,
TesseractCliOcrOptions,
TesseractOcrOptions,
)
Expand All @@ -23,6 +25,9 @@
PagePreprocessingModel,
PagePreprocessingOptions,
)
from docling.models.pic_description_api_model import PictureDescriptionApiModel
from docling.models.pic_description_base_model import PictureDescriptionBaseModel
from docling.models.pic_description_vllm_model import PictureDescriptionVllmModel
from docling.models.table_structure_model import TableStructureModel
from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel
from docling.models.tesseract_ocr_model import TesseractOcrModel
Expand Down Expand Up @@ -83,8 +88,15 @@ def __init__(self, pipeline_options: PdfPipelineOptions):
PageAssembleModel(options=PageAssembleOptions(keep_images=keep_images)),
]

# Picture description model
if (pic_desc_model := self.get_pic_description_model()) is None:
raise RuntimeError(
f"The specified picture description kind is not supported: {pipeline_options.picture_description_options.kind}."
)

self.enrichment_pipe = [
# Other models working on `NodeItem` elements in the DoclingDocument
pic_desc_model,
]

@staticmethod
Expand Down Expand Up @@ -120,6 +132,23 @@ def get_ocr_model(self) -> Optional[BaseOcrModel]:
)
return None

def get_pic_description_model(self) -> Optional[PictureDescriptionBaseModel]:
if isinstance(
self.pipeline_options.picture_description_options, PicDescApiOptions
):
return PictureDescriptionApiModel(
enabled=self.pipeline_options.do_picture_description,
options=self.pipeline_options.picture_description_options,
)
elif isinstance(
self.pipeline_options.picture_description_options, PicDescVllmOptions
):
return PictureDescriptionVllmModel(
enabled=self.pipeline_options.do_picture_description,
options=self.pipeline_options.picture_description_options,
)
return None

def initialize_page(self, conv_res: ConversionResult, page: Page) -> Page:
with TimeRecorder(conv_res, "page_init"):
page._backend = conv_res.input._backend.load_page(page.page_no) # type: ignore
Expand Down

0 comments on commit e1cba8a

Please sign in to comment.