Commit affffbf

jamjamjon committed Sep 21, 2024
1 parent a06b0d8 commit affffbf
Showing 4 changed files with 109 additions and 120 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -37,7 +37,7 @@
- **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10)
- **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
- **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569)
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World)
- **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)

<details>
<summary>Click to expand Supported Models</summary>
@@ -71,6 +71,9 @@
| [MODNet](https://github.com/ZHKKKe/MODNet) | Image Matting | [demo](examples/modnet) |||||
| [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) | Open-Set Detection With Language | [demo](examples/grounding-dino) ||| | |
| [Sapiens](https://github.com/facebookresearch/sapiens/tree/main) | Body Part Segmentation | [demo](examples/sapiens) ||| | |
| [Florence2](https://arxiv.org/abs/2311.06242) | A Variety of Vision Tasks | [demo](examples/florence2) ||| | |



</details>

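For orientation before the code diffs: the Florence2 path added here takes a batch of images plus a list of tasks and returns results keyed by task (the BTreeMap<Task, Vec<Y>> signature in src/models/florence2.rs below). A minimal calling sketch, assuming a Default-style constructor and a run_with_tasks entry point — both illustrative names, not APIs confirmed by this commit:

use usls::{models::Florence2, Task};

fn main() -> anyhow::Result<()> {
    let mut model = Florence2::default(); // assumed constructor
    let images = [image::open("assets/dog.jpg")?]; // any DynamicImage batch
    let tasks = [Task::Caption(0), Task::Ocr, Task::ObjectDetection]; // Caption's detail level is assumed

    // Assumed entry point; the diff confirms the shape
    // (&[DynamicImage], &[Task]) -> Result<BTreeMap<Task, Vec<Y>>>.
    let ys = model.run_with_tasks(&images, &tasks)?;
    for (task, y) in ys.iter() {
        println!("{:?}: {} result(s)", task, y.len());
    }
    Ok(())
}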
6 changes: 1 addition & 5 deletions examples/florence2/main.rs
@@ -232,15 +232,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
Task::ReferringExpressionSegmentation(_) => {
let annotator = annotator
.clone()
.with_polygons_alpha(200)
.with_saveout("Referring-Expression-Segmentation");
annotator.annotate(&xs, ys_);
}
Task::RegionToSegmentation(..) => {
let annotator = annotator
.clone()
.with_polygons_alpha(200)
.with_saveout("Region-To-Segmentation");
let annotator = annotator.clone().with_saveout("Region-To-Segmentation");
annotator.annotate(&xs, ys_);
}
Task::OcrWithRegion => {
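The change above drops the per-task with_polygons_alpha(200) override and collapses the annotator setup into one chained call, presumably deferring to the annotator's default polygon alpha. The clone-and-configure pattern in isolation, as a fragment using only the names visible in this diff (the Default constructor is an assumption):

let base = Annotator::default(); // assumption: a Default-style constructor
// Each task clones the shared base and sets its own save-out directory,
// leaving the base configuration untouched.
let annotator = base.clone().with_saveout("Region-To-Segmentation");
annotator.annotate(&xs, ys_); // xs and ys_ as in the surrounding match arm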
2 changes: 1 addition & 1 deletion src/lib.rs
@@ -3,7 +3,7 @@
//! - **YOLO Models**: [YOLOv5](https://github.com/ultralytics/yolov5), [YOLOv6](https://github.com/meituan/YOLOv6), [YOLOv7](https://github.com/WongKinYiu/yolov7), [YOLOv8](https://github.com/ultralytics/ultralytics), [YOLOv9](https://github.com/WongKinYiu/yolov9), [YOLOv10](https://github.com/THU-MIG/yolov10)
//! - **SAM Models**: [SAM](https://github.com/facebookresearch/segment-anything), [SAM2](https://github.com/facebookresearch/segment-anything-2), [MobileSAM](https://github.com/ChaoningZhang/MobileSAM), [EdgeSAM](https://github.com/chongzhou96/EdgeSAM), [SAM-HQ](https://github.com/SysCV/sam-hq), [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM)
//! - **Vision Models**: [RTDETR](https://arxiv.org/abs/2304.08069), [RTMO](https://github.com/open-mmlab/mmpose/tree/main/projects/rtmo), [DB](https://arxiv.org/abs/1911.08947), [SVTR](https://arxiv.org/abs/2205.00159), [Depth-Anything-v1-v2](https://github.com/LiheYoung/Depth-Anything), [DINOv2](https://github.com/facebookresearch/dinov2), [MODNet](https://github.com/ZHKKKe/MODNet), [Sapiens](https://arxiv.org/abs/2408.12569)
//! - **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World)
//! - **Vision-Language Models**: [CLIP](https://github.com/openai/CLIP), [BLIP](https://arxiv.org/abs/2201.12086), [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [YOLO-World](https://github.com/AILab-CVC/YOLO-World), [Florence2](https://arxiv.org/abs/2311.06242)
//!
//! # Examples
//!
216 changes: 103 additions & 113 deletions src/models/florence2.rs
@@ -17,9 +17,9 @@ pub struct Florence2 {
pub encoder: OrtEngine,
pub decoder: OrtEngine,
pub decoder_merged: OrtEngine,
pub height: MinOptMax,
pub width: MinOptMax,
pub batch: MinOptMax,
height: MinOptMax,
width: MinOptMax,
batch: MinOptMax,
tokenizer: Tokenizer,
max_length: usize,
quantizer: Quantizer,
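The shape fields above lose their pub visibility, yet later hunks still call self.batch(), so read-only accessors presumably accompany this change. A sketch of what they might look like; the MinOptMax opt() accessor is an assumption, not shown in this commit:

impl Florence2 {
    // Assumed getters backing the self.batch() calls later in this diff.
    pub fn batch(&self) -> usize {
        self.batch.opt() as usize
    }
    pub fn height(&self) -> usize {
        self.height.opt() as usize
    }
    pub fn width(&self) -> usize {
        self.width.opt() as usize
    }
}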
@@ -97,28 +97,27 @@ impl Florence2 {
xs: &[DynamicImage],
tasks: &[Task],
) -> Result<BTreeMap<Task, Vec<Y>>> {
// encode batch images
let mut ys: BTreeMap<Task, Vec<Y>> = BTreeMap::new();

// encode images
let image_embeddings = self.encode_images(xs)?;

// note: the length of xs is not always equal to batch size
self.batch.update(xs.len() as isize);

// tasks loop
// build pb
let pb = build_progress_bar(
tasks.len() as u64,
" Working On",
None,
crate::PROGRESS_BAR_STYLE_CYAN_2,
)?;

let mut ys: BTreeMap<Task, Vec<Y>> = BTreeMap::new();
// tasks
for task in tasks.iter() {
// update pb
pb.inc(1);
pb.set_message(format!("{:?}", task));

let mut ys_task: Vec<Y> = Vec::new();

// construct prompt and encode
let input_ids = self
.encode_prompt(task)?
@@ -129,103 +128,98 @@
// run
let texts = self.run_batch(&image_embeddings, &text_embeddings)?;

// postprocess
for batch in 0..self.batch() {
// image size
let image_width = xs[batch].width() as usize;
let image_height = xs[batch].height() as usize;

// texts cleanup
let text = texts[batch]
.as_str()
.replace("<s>", "")
.replace("</s>", "")
.replace("<pad>", "");

// cope with each task
if let Task::Caption(_) | Task::Ocr = task {
ys_task.push(Y::default().with_texts(&[text]));
} else {
let elems = Self::loc_parse(&text)?;
match task {
Task::RegionToCategory(..) | Task::RegionToDescription(..) => {
let text = elems[0][0].clone(); // extract text only
ys_task.push(Y::default().with_texts(&[text]));
}
Task::ObjectDetection
| Task::OpenSetDetection(_)
| Task::DenseRegionCaption
| Task::CaptionToPhraseGrounding(_) => {
let y_bboxes: Vec<Bbox> = elems
.par_iter()
.enumerate()
.flat_map(|(i, elem)| {
let name = &elem[0];
let y_bboxes: Vec<Bbox> = Self::process_bboxes(
&elem[1..],
&self.quantizer,
image_width,
image_height,
Some((name, i)),
);
y_bboxes
})
.collect();

ys_task.push(Y::default().with_bboxes(&y_bboxes));
}
Task::RegionProposal => {
let y_bboxes: Vec<Bbox> = Self::process_bboxes(
&elems[0],
&self.quantizer,
image_width,
image_height,
None,
);

ys_task.push(Y::default().with_bboxes(&y_bboxes));
}

Task::ReferringExpressionSegmentation(_)
| Task::RegionToSegmentation(..) => {
let points = Self::process_polygons(
&elems[0],
&self.quantizer,
image_width,
image_height,
);

ys_task.push(Y::default().with_polygons(&[
Polygon::default().with_points(&points).with_id(0),
]));
}
Task::OcrWithRegion => {
let y_polygons: Vec<Polygon> = elems
.par_iter()
.enumerate()
.map(|(i, elem)| {
let text = &elem[0];
let points = Self::process_polygons(
&elem[1..],
&self.quantizer,
image_width,
image_height,
);

Polygon::default()
.with_name(text)
.with_points(&points)
.with_id(i as _)
})
.collect();

ys_task.push(Y::default().with_polygons(&y_polygons));
}

_ => anyhow::bail!("Unsupported Florence2 task."),
};
}
}
// tasks iteration
let ys_task = (0..self.batch())
.into_par_iter()
.map(|batch| {
// image size
let image_width = xs[batch].width() as usize;
let image_height = xs[batch].height() as usize;

// texts cleanup
let text = texts[batch]
.as_str()
.replace("<s>", "")
.replace("</s>", "")
.replace("<pad>", "");

// postprocess
let mut y = Y::default();
if let Task::Caption(_) | Task::Ocr = task {
y = y.with_texts(&[text]);
} else {
let elems = Self::loc_parse(&text)?;
match task {
Task::RegionToCategory(..) | Task::RegionToDescription(..) => {
let text = elems[0][0].clone();
y = y.with_texts(&[text]);
}
Task::ObjectDetection
| Task::OpenSetDetection(_)
| Task::DenseRegionCaption
| Task::CaptionToPhraseGrounding(_) => {
let y_bboxes: Vec<Bbox> = elems
.par_iter()
.enumerate()
.flat_map(|(i, elem)| {
Self::process_bboxes(
&elem[1..],
&self.quantizer,
image_width,
image_height,
Some((&elem[0], i)),
)
})
.collect();
y = y.with_bboxes(&y_bboxes);
}
Task::RegionProposal => {
let y_bboxes: Vec<Bbox> = Self::process_bboxes(
&elems[0],
&self.quantizer,
image_width,
image_height,
None,
);
y = y.with_bboxes(&y_bboxes);
}
Task::ReferringExpressionSegmentation(_)
| Task::RegionToSegmentation(..) => {
let points = Self::process_polygons(
&elems[0],
&self.quantizer,
image_width,
image_height,
);
y = y.with_polygons(&[Polygon::default()
.with_points(&points)
.with_id(0)]);
}
Task::OcrWithRegion => {
let y_polygons: Vec<Polygon> = elems
.par_iter()
.enumerate()
.map(|(i, elem)| {
let points = Self::process_polygons(
&elem[1..],
&self.quantizer,
image_width,
image_height,
);
Polygon::default()
.with_name(&elem[0])
.with_points(&points)
.with_id(i as _)
})
.collect();
y = y.with_polygons(&y_polygons);
}
_ => anyhow::bail!("Unsupported Florence2 task."),
};
}
Ok(y)
})
.collect::<Result<Vec<Y>>>()?;

ys.insert(task.clone(), ys_task);
}
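The rewrite above replaces a sequential per-batch loop that pushed into ys_task with rayon's into_par_iter().map(...).collect::<Result<Vec<Y>>>(). Collecting parallel results into a Result short-circuits on the first Err, which is what lets the parallel closure keep the loop's original ?-style error handling (anyhow::bail! for an unsupported task). A self-contained sketch of that pattern:

use rayon::prelude::*;

// Parallel fallible map: each element yields Ok or Err, and collecting into
// Result aborts with the first Err instead of losing it.
fn square_roots(xs: &[f64]) -> Result<Vec<f64>, String> {
    (0..xs.len())
        .into_par_iter()
        .map(|i| {
            if xs[i] < 0.0 {
                return Err(format!("negative input at index {}", i));
            }
            Ok(xs[i].sqrt())
        })
        .collect::<Result<Vec<f64>, String>>()
}

fn main() {
    assert_eq!(square_roots(&[1.0, 4.0, 9.0]), Ok(vec![1.0, 2.0, 3.0]));
    assert!(square_roots(&[1.0, -4.0]).is_err());
}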
@@ -264,19 +258,14 @@

let encoder_k0 = decoder_outputs[3].clone();
let encoder_v0 = decoder_outputs[4].clone();

let encoder_k1 = decoder_outputs[7].clone();
let encoder_v1 = decoder_outputs[8].clone();

let encoder_k2 = decoder_outputs[11].clone();
let encoder_v2 = decoder_outputs[12].clone();

let encoder_k3 = decoder_outputs[15].clone();
let encoder_v3 = decoder_outputs[16].clone();

let encoder_k4 = decoder_outputs[19].clone();
let encoder_v4 = decoder_outputs[20].clone();

let encoder_k5 = decoder_outputs[23].clone();
let encoder_v5 = decoder_outputs[24].clone();
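The cached encoder keys and values above come from fixed output slots: for decoder layer i, the key is output 4 * i + 3 and the value is output 4 * i + 4 (so layer 0 reads slots 3 and 4, layer 5 reads 23 and 24). An equivalent loop over the same decoder_outputs indexing, shown only to make the stride explicit:

let num_layers = 6; // layers 0..=5, matching encoder_k0..encoder_v5 above
let mut encoder_kv = Vec::with_capacity(num_layers);
for i in 0..num_layers {
    let k = decoder_outputs[4 * i + 3].clone();
    let v = decoder_outputs[4 * i + 4].clone();
    encoder_kv.push((k, v));
}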

@@ -285,8 +274,9 @@

// save last batch tokens
let mut last_tokens: Vec<f32> = vec![0.; self.batch()];

let mut logits_sampler = LogitsSampler::new();

// generate
for _ in 0..self.max_length {
let logits = &decoder_outputs["logits"];
let decoder_k0 = &decoder_outputs[1];
Expand All @@ -302,7 +292,7 @@ impl Florence2 {
let decoder_k5 = &decoder_outputs[21];
let decoder_v5 = &decoder_outputs[22];

// Decode each token for each batch
// decode each token for each batch
for (i, logit) in logits.axis_iter(Axis(0)).enumerate() {
if !finished[i] {
let token_id = logits_sampler.decode(
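The loop above runs at most max_length decoding steps, samples one token per batch element, and uses the finished flags to freeze elements that have already emitted their end token. A stripped-down, runnable sketch of that control flow, with a canned token source standing in for logits_sampler.decode and an assumed end-of-sequence id:

fn main() {
    const EOS: u32 = 2; // assumption: the </s> token id
    // Canned per-element "model outputs" standing in for the real sampler.
    let canned: Vec<Vec<u32>> = vec![vec![5, 9, EOS], vec![7, EOS]];
    let batch = canned.len();
    let max_length = 8;
    let mut finished = vec![false; batch];
    let mut generated: Vec<Vec<u32>> = vec![Vec::new(); batch];

    for step in 0..max_length {
        for i in 0..batch {
            if finished[i] {
                continue; // this element already hit EOS
            }
            let token_id = *canned[i].get(step).unwrap_or(&EOS);
            if token_id == EOS {
                finished[i] = true;
            } else {
                generated[i].push(token_id);
            }
        }
        if finished.iter().all(|&f| f) {
            break; // whole batch done before max_length
        }
    }
    assert_eq!(generated, vec![vec![5, 9], vec![7]]);
}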
