
Commit

feat: Working on The Finals
Flowtter committed Jan 27, 2024
1 parent 60d799f commit d95f094
Showing 6 changed files with 164 additions and 21 deletions.
23 changes: 13 additions & 10 deletions crispy-api/api/__init__.py
@@ -20,16 +20,19 @@

ENCODERS_BY_TYPE[ObjectId] = str

neural_network = NeuralNetwork(GAME)

if GAME == SupportedGames.OVERWATCH:
    neural_network.load(os.path.join(ASSETS, "overwatch.npy"))
elif GAME == SupportedGames.VALORANT:
    neural_network.load(os.path.join(ASSETS, "valorant.npy"))
elif GAME == SupportedGames.CSGO2:
    neural_network.load(os.path.join(ASSETS, "csgo2.npy"))
else:
    raise ValueError(f"game {GAME} not supported")
neural_network = None

if GAME != SupportedGames.THEFINALS:
    neural_network = NeuralNetwork(GAME)

    if GAME == SupportedGames.OVERWATCH:
        neural_network.load(os.path.join(ASSETS, "overwatch.npy"))
    elif GAME == SupportedGames.VALORANT:
        neural_network.load(os.path.join(ASSETS, "valorant.npy"))
    elif GAME == SupportedGames.CSGO2:
        neural_network.load(os.path.join(ASSETS, "csgo2.npy"))
    else:
        raise ValueError(f"game {GAME} not supported")


logging.getLogger("PIL").setLevel(logging.ERROR)
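As an aside (not part of this commit), the per-game weight selection above could also be written as a table lookup; a minimal sketch, assuming the same ASSETS layout and SupportedGames members. The Finals is skipped because it has no trained network yet and relies on OCR instead (see video.py further down).

WEIGHTS_BY_GAME = {
    SupportedGames.OVERWATCH: "overwatch.npy",
    SupportedGames.VALORANT: "valorant.npy",
    SupportedGames.CSGO2: "csgo2.npy",
}

neural_network = None
if GAME != SupportedGames.THEFINALS:
    # GAME is already validated in config.py, so a plain lookup is enough here.
    neural_network = NeuralNetwork(GAME)
    neural_network.load(os.path.join(ASSETS, WEIGHTS_BY_GAME[GAME]))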
3 changes: 3 additions & 0 deletions crispy-api/api/config.py
@@ -1,6 +1,7 @@
import json
import os

import easyocr
from starlette.config import Config

from api.tools.enums import SupportedGames
@@ -61,3 +62,5 @@
    raise KeyError("game not found in settings.json")
if GAME.upper() not in [game.name for game in SupportedGames]:
    raise ValueError(f"game {GAME} not supported")

READER = easyocr.Reader(["fr", "en"], gpu=True, verbose=False)
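The module-level READER added here is an EasyOCR reader for French and English. For orientation, a short sketch of how its output is consumed elsewhere in this commit (the image path is hypothetical): readtext returns one (bounding_box, text, confidence) tuple per detected word, which is why the code below indexes OCR hits with text[1] or word[1].

from api.config import READER

results = READER.readtext("usernames/00000042.bmp")  # hypothetical frame path
for bbox, text, confidence in results:
    print(text, confidence)  # e.g. a username read from the frame and its OCR confidence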
98 changes: 89 additions & 9 deletions crispy-api/api/models/highlight.py
@@ -22,28 +22,34 @@
class Box:
    def __init__(
        self,
        offset_x: int,
        x: int,
        y: int,
        width: int,
        height: int,
        shift_x: int,
        stretch: bool,
        from_center: bool = True,
    ) -> None:
"""
:param offset_x: Offset in pixels from the center of the video to the left
:param x: Offset in pixels from the left of the video or from the center if use_offset is enabled
:param y: Offset in pixels from the top of the video
:param width: Width of the box in pixels
:param height: Height of the box in pixels
:param shift_x: Shift the box by a certain amount of pixels to the right
:param stretch: Stretch the box to fit the video
:param use_offset: If enabled, x will be from the center of the video, else it will be from the left (usef)
example:
If you want to create a box at 50 px from the center on x, but shifted by 20px to the right
you would do:
Box(50, 0, 100, 100, 20)
"""
        half = 720 if stretch else 960
        if from_center:
            half = 720 if stretch else 960
            self.x = half - x + shift_x
        else:
            self.x = x + shift_x

        self.x = half - offset_x + shift_x
        self.y = y
        self.width = width
        self.height = height
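A worked example of the constructor above, using coordinates that appear in calls elsewhere in this diff: with from_center left at its default, x is measured back from the horizontal center (960 px when stretch is False, per the constant above), while from_center=False makes it a plain offset from the left edge.

kill_icon = Box(50, 925, 100, 100, 20, False)            # from_center defaults to True
# kill_icon.x == 960 - 50 + 20 == 930, kill_icon.y == 925

killfeed = Box(1500, 75, 250, 115, 0, False, from_center=False)
# killfeed.x == 1500 + 0 == 1500, killfeed.y == 75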
@@ -93,19 +99,24 @@ async def extract_images(
        post_process: Callable,
        coordinates: Box,
        framerate: int = 4,
        save_path: str = "images",
        force_extract: bool = False,
    ) -> bool:
        """
        Extracts images from a video at a given framerate
        :param post_process: Function to apply to each image
        :param coordinates: Coordinates of the box to extract
        :param framerate: Framerate to extract the images
        :param save_path: Path to save the images
        """
        if self.images_path:
        if self.images_path and not force_extract:
            return False
        images_path = os.path.join(self.directory, "images")
        images_path = os.path.join(self.directory, save_path)

        if not os.path.exists(images_path):
            print("creating images path at", images_path)
            os.mkdir(images_path)
        (
            ffmpeg.input(self.path)
@@ -124,8 +135,9 @@ async def extract_images(

            post_process(im).save(im_path)

        self.update({"images_path": images_path})
        self.save()
        if save_path == "images":
            self.update({"images_path": images_path})
            self.save()

        return True
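A hedged sketch of how the two new parameters combine, mirroring the calls added later in this diff (the highlight object and the post-processing callables are assumed names): force_extract bypasses the images_path short-circuit, and only the default save_path records images_path on the model.

# First pass: regular detection crops; this records highlight.images_path.
await highlight.extract_images(post_process, Box(50, 925, 100, 100, 20, False), framerate=4)

# Second pass: a different crop written to "usernames/"; force_extract re-runs the
# extraction even though images_path is already set, and images_path is left untouched.
await highlight.extract_images(
    post_process_usernames,
    Box(20, 800, 200, 120, 0, False, from_center=False),
    framerate=4,
    save_path="usernames",
    force_extract=True,
)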

@@ -220,6 +232,72 @@ def post_process(image: Image) -> Image:
            post_process, Box(50, 925, 100, 100, 20, stretch), framerate=framerate
        )

    async def extract_the_finals_images(
        self, framerate: int = 4, stretch: bool = False
    ) -> bool:
        def is_color_close(
            pixel: Tuple[int, int, int],
            expected: Tuple[int, int, int],
            threshold: int = 100,
        ) -> bool:
            distance: int = (
                sum((pixel[i] - expected[i]) ** 2 for i in range(len(pixel))) ** 0.5
            )
            return distance < threshold
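            # Illustration (not part of the commit): for the killfeed blue (12, 145, 201)
            # and a sampled pixel (30, 150, 210), the distance is
            # sqrt(18**2 + 5**2 + 9**2) ~= 20.7, well under the threshold of 100,
            # so that pixel counts as "close" and is kept.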

        def post_process_killfeed(image: Image) -> Image:
            r, g, b = image.split()
            for x in range(image.width):
                for y in range(image.height):
                    if not is_color_close(
                        (r.getpixel((x, y)), g.getpixel((x, y)), b.getpixel((x, y))),
                        (12, 145, 201),
                    ):
                        r.putpixel((x, y), 0)
                        b.putpixel((x, y), 0)
                        g.putpixel((x, y), 0)

            im = ImageOps.grayscale(Image.merge("RGB", (r, g, b)))

            final = Image.new("RGB", (250, 115))
            final.paste(im, (0, 0))
            return final

        killfeed_state = await self.extract_images(
            post_process_killfeed,
            Box(1500, 75, 250, 115, 0, stretch, from_center=False),
            framerate=framerate,
        )

        def post_process(image: Image) -> Image:
            r, g, b = image.split()
            for x in range(image.width):
                for y in range(image.height):
                    if not is_color_close(
                        (r.getpixel((x, y)), g.getpixel((x, y)), b.getpixel((x, y))),
                        (255, 255, 255),
                    ):
                        r.putpixel((x, y), 0)
                        b.putpixel((x, y), 0)
                        g.putpixel((x, y), 0)

            im = ImageOps.grayscale(Image.merge("RGB", (r, g, b)))

            final = Image.new("RGB", (200, 120))
            final.paste(im, (0, 0))
            return final

        return (
            await self.extract_images(
                post_process,
                Box(20, 800, 200, 120, 0, stretch, from_center=False),
                framerate=framerate,
                save_path="usernames",
                force_extract=True,
            )
            and killfeed_state
        )
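The per-pixel getpixel/putpixel loops above are easy to read but slow on full frames. As a hedged aside (not part of the commit), the same masking can be expressed with NumPy broadcasting, reusing the project's existing numpy and Pillow dependencies:

import numpy as np
from PIL import Image, ImageOps

def keep_pixels_close_to(image: Image.Image, expected, threshold: int = 100) -> Image.Image:
    """Zero every pixel whose RGB distance to `expected` is at least `threshold`, then grayscale."""
    arr = np.array(image.convert("RGB"), dtype=np.int32)
    distance = np.sqrt(((arr - np.array(expected)) ** 2).sum(axis=-1))
    arr[distance >= threshold] = 0
    return ImageOps.grayscale(Image.fromarray(arr.astype(np.uint8)))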

    async def extract_images_from_game(
        self, game: SupportedGames, framerate: int = 4, stretch: bool = False
    ) -> bool:
@@ -229,8 +307,10 @@ async def extract_images_from_game(
            return await self.extract_valorant_images(framerate, stretch)
        elif game == SupportedGames.CSGO2:
            return await self.extract_csgo2_images(framerate, stretch)
        elif game == SupportedGames.THEFINALS:
            return await self.extract_the_finals_images(framerate, stretch)
        else:
            raise NotImplementedError
            raise NotImplementedError(f"game {game} not supported")

    def recompile(self) -> bool:
        from api.tools.utils import sanitize_dict
1 change: 1 addition & 0 deletions crispy-api/api/tools/enums.py
@@ -5,3 +5,4 @@ class SupportedGames(str, Enum):
    VALORANT = "valorant"
    OVERWATCH = "overwatch"
    CSGO2 = "csgo2"
    THEFINALS = "thefinals"
28 changes: 27 additions & 1 deletion crispy-api/api/tools/setup.py
@@ -1,12 +1,13 @@
import logging
import os
import shutil
from collections import Counter
from typing import List

import ffmpeg
from PIL import Image

from api.config import SESSION, SILENCE_PATH, STRETCH
from api.config import READER, SESSION, SILENCE_PATH, STRETCH
from api.models.filter import Filter
from api.models.highlight import Highlight
from api.models.music import Music
@@ -104,6 +105,31 @@ async def handle_highlights(

    Highlight.update_many({}, {"$set": {"job_id": None}})

    if game == SupportedGames.THEFINALS:
        path = os.path.join(highlight.directory, "usernames")
        for highlight in new_highlights:
            images = os.listdir(path)
            usernames = [""] * 2
            usernames_histogram: Counter = Counter()

            for i in range(0, len(images), framerate):
                image = images[i]
                image_path = os.path.join(path, image)
                result = READER.readtext(image_path)
                for text in result:
                    if text[1].isnumeric():
                        continue
                    usernames_histogram[text[1]] += 1
                two_best = usernames_histogram.most_common(2)
                if two_best[0][1] >= 10 and two_best[1][1] >= 10:
                    usernames = [
                        usernames_histogram.most_common(2)[0][0],
                        usernames_histogram.most_common(2)[1][0],
                    ]
                    break
            highlight.update({"usernames": usernames})
            highlight.save()

    return new_highlights
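For clarity, a small worked example of the histogram logic above (all names and counts are hypothetical): OCR hits are tallied across the sampled frames, and the loop stops once the two most common non-numeric strings have each been seen at least 10 times.

from collections import Counter

usernames_histogram = Counter({"PlayerOne": 14, "PlayerTwo": 11, "ELIMINATED": 3})
two_best = usernames_histogram.most_common(2)  # [("PlayerOne", 14), ("PlayerTwo", 11)]
if two_best[0][1] >= 10 and two_best[1][1] >= 10:
    usernames = [two_best[0][0], two_best[1][0]]  # ["PlayerOne", "PlayerTwo"]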


32 changes: 31 additions & 1 deletion crispy-api/api/tools/video.py
@@ -6,9 +6,11 @@
import numpy as np
from PIL import Image

from api.config import GAME, READER
from api.models.highlight import Highlight
from api.models.segment import Segment
from api.tools.AI.network import NeuralNetwork
from api.tools.enums import SupportedGames

logger = logging.getLogger("uvicorn")

@@ -54,6 +56,34 @@ def _create_query_array(
    return queries


def _get_the_finals_query_array(highlight: Highlight) -> List[int]:
    usernames = highlight.usernames
    images = os.listdir(highlight.images_path)
    images.sort()
    queries = []

    for i, image in enumerate(images):
        image_path = os.path.join(highlight.images_path, image)

        text = READER.readtext(image_path)
        for word in text:
            if word[1] not in usernames:
                queries.append(i)
                break

    return queries
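To illustrate the idea (values are hypothetical): a frame index becomes a query whenever OCR reads a word in the kill-feed crop that is not one of the highlight's own usernames, i.e. something new appeared there, which is treated as kill-feed activity.

usernames = ["PlayerOne", "PlayerTwo"]                 # assumed highlight.usernames
ocr_words_per_frame = {
    0: ["PlayerOne"],                                  # only the player's own name
    1: ["PlayerOne", "KnockedOut"],                    # an unexpected word -> query
    2: [],                                             # nothing readable
}
queries = [
    i for i, words in ocr_words_per_frame.items()
    if any(word not in usernames for word in words)
]
assert queries == [1]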


def _get_query_array(
    neural_network: NeuralNetwork, highlight: Highlight, confidence: float
) -> List[int]:
    if neural_network:
        return _create_query_array(neural_network, highlight, confidence)
    if GAME == SupportedGames.THEFINALS:
        return _get_the_finals_query_array(highlight)
    raise ValueError(f"No neural network for game {GAME} and no custom query array")


def _normalize_queries(
    queries: List[int], frames_before: int, frames_after: int
) -> List[Tuple[int, int]]:
@@ -135,7 +165,7 @@ async def extract_segments(
    :return: list of segments
    """
    queries = _create_query_array(neural_network, highlight, confidence)
    queries = _get_query_array(neural_network, highlight, confidence)
    normalized = _normalize_queries(queries, frames_before, frames_after)
    processed = _post_process_query_array(normalized, offset, framerate)
    segments = await highlight.extract_segments(processed)
