Skip to content

Commit 61fa06e

Browse files
committed
feat: Add CSGO-2
1 parent 81d35ae commit 61fa06e

File tree

18 files changed

+193
-92
lines changed

18 files changed

+193
-92
lines changed

README.md

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ It uses a neural network to detect highlights in the video-game frames.\
1515

1616
# Supported games
1717

18-
Currently it supports **[Valorant](https://playvalorant.com/)**, **[Overwatch](https://playoverwatch.com/)**.
18+
Currently it supports **[Valorant](https://playvalorant.com/)**, **[Overwatch](https://playoverwatch.com/)** and **[CSGO2](https://www.counter-strike.net/cs2)**.
1919

2020
# Usage
2121

@@ -59,7 +59,63 @@ The following settings are adjustable:
5959
- second-before: Seconds of gameplay included before the highlight.
6060
- second-after: Seconds of gameplay included after the highlight.
6161
- second-between-kills: Transition time between highlights. If the time between two highlights is less than this value, the both highlights will be merged.
62-
- game: Chosen game (either "valorant" or "overwatch")
62+
- game: Chosen game (either "valorant", "overwatch" or "csgo2")
63+
64+
### Recommended settings
65+
66+
I recommend you to use the trials and errors method to find the best settings for your videos.\
67+
Here are some settings that I found to work well for me:
68+
69+
#### Valorant
70+
71+
```json
72+
{
73+
"neural-network": {
74+
"confidence": 0.8
75+
},
76+
"clip": {
77+
"framerate": 8,
78+
"second-before": 4,
79+
"second-after": 0.5,
80+
"second-between-kills": 3
81+
},
82+
"game": "valorant"
83+
}
84+
```
85+
86+
#### Overwatch
87+
88+
```json
89+
{
90+
"neural-network": {
91+
"confidence": 0.6
92+
},
93+
"clip": {
94+
"framerate": 8,
95+
"second-before": 4,
96+
"second-after": 3,
97+
"second-between-kills": 5
98+
},
99+
"game": "overwatch"
100+
}
101+
```
102+
103+
#### CSGO2
104+
105+
```json
106+
{
107+
"neural-network": {
108+
"confidence": 0.7
109+
},
110+
"clip": {
111+
"framerate": 8,
112+
"second-before": 4,
113+
"second-after": 1,
114+
"second-between-kills": 3
115+
},
116+
"game": "csgo2"
117+
}
118+
```
63119

64120
## Run
65121

@@ -139,3 +195,7 @@ Now `pre-commit` will run on every `git commit`.
139195

140196
- `cd crispy-frontend && yarn && yarn dev`
141197
- `cd crispy-backend && pip install -Ir requirements-dev.txt && python -m api`
198+
199+
## Test
200+
201+
- `cd crispy-api && pytest`

crispy-api/api/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,24 @@
1111
from montydb import MontyClient, set_storage
1212
from pydantic.json import ENCODERS_BY_TYPE
1313

14-
from api.config import DATABASE_PATH, DEBUG, GAME, MUSICS, VIDEOS
14+
from api.config import DATABASE_PATH, DEBUG, FRAMERATE, GAME, MUSICS, VIDEOS
1515
from api.tools.AI.network import NeuralNetwork
1616
from api.tools.enums import SupportedGames
1717
from api.tools.filters import apply_filters # noqa
1818
from api.tools.setup import handle_highlights, handle_musics
1919

2020
ENCODERS_BY_TYPE[ObjectId] = str
2121

22+
neural_network = NeuralNetwork(GAME)
23+
2224
if GAME == SupportedGames.OVERWATCH:
23-
neural_network = NeuralNetwork([10000, 120, 15, 2])
2425
neural_network.load("./assets/overwatch.npy")
2526
elif GAME == SupportedGames.VALORANT:
26-
neural_network = NeuralNetwork([4000, 120, 15, 2], 0.01)
2727
neural_network.load("./assets/valorant.npy")
28+
elif GAME == SupportedGames.CSGO2:
29+
neural_network.load("./assets/csgo2.npy")
30+
else:
31+
raise ValueError(f"game {GAME} not supported")
2832

2933

3034
logging.getLogger("PIL").setLevel(logging.ERROR)
@@ -62,7 +66,7 @@ def is_tool_installed(ffmpeg_tool: str) -> None:
6266
@app.on_event("startup")
6367
async def setup_crispy() -> None:
6468
await handle_musics(MUSICS)
65-
await handle_highlights(VIDEOS, GAME, framerate=8)
69+
await handle_highlights(VIDEOS, GAME, framerate=FRAMERATE)
6670

6771

6872
@app.exception_handler(HTTPException)

crispy-api/api/__main__.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,35 @@
11
import argparse
22
import asyncio
3+
import os
4+
import sys
35

46
import uvicorn
57

68
from api import init_database
7-
from api.config import DEBUG, HOST, PORT
9+
from api.config import (
10+
DATASET_CSV_PATH,
11+
DATASET_CSV_TEST_PATH,
12+
DEBUG,
13+
HOST,
14+
NETWORK_OUTPUTS_PATH,
15+
PORT,
16+
)
17+
from api.tools.AI.trainer import Trainer, test, train
818
from api.tools.dataset import create_dataset
919
from api.tools.enums import SupportedGames
1020

1121
_parser = argparse.ArgumentParser()
22+
# Dataset
1223
_parser.add_argument("--dataset", action="store_true")
24+
25+
# Trainer
26+
_parser.add_argument("--train", help="Train the network", action="store_true")
27+
_parser.add_argument("--test", help="Test the network", action="store_true")
28+
_parser.add_argument("--epoch", help="Number of epochs", type=int, default=1000)
29+
_parser.add_argument("--load", help="Load a trained network", action="store_true")
30+
_parser.add_argument("--path", help="Path to the network", type=str)
31+
32+
# Game
1333
_parser.add_argument(
1434
"--game", type=str, choices=[game.value for game in SupportedGames]
1535
)
@@ -26,11 +46,34 @@ async def generate_dataset(game: SupportedGames) -> None:
2646

2747

2848
if __name__ == "__main__":
29-
if not _args.dataset:
49+
if not _args.dataset and not _args.train and not _args.test:
3050
uvicorn.run("api:app", host=HOST, port=PORT, reload=DEBUG, proxy_headers=True)
3151
else:
3252
game = SupportedGames(_args.game)
33-
if not game:
34-
raise ValueError("Game not supported")
53+
if _args.dataset:
54+
if not game:
55+
raise ValueError("Game not supported")
56+
57+
asyncio.run(generate_dataset(game))
58+
else:
59+
trainer = Trainer(game, 0.01)
60+
61+
if _args.load:
62+
trainer.load(_args.path)
63+
else:
64+
trainer.initialize_weights()
65+
66+
print(trainer)
67+
if _args.train:
68+
if not os.path.exists(NETWORK_OUTPUTS_PATH):
69+
os.makedirs(NETWORK_OUTPUTS_PATH)
70+
train(
71+
_args.epoch, trainer, DATASET_CSV_PATH, True, NETWORK_OUTPUTS_PATH
72+
)
73+
74+
if _args.test:
75+
if not _args.load and not _args.train:
76+
print("You need to load a trained network")
77+
sys.exit(1)
3578

36-
asyncio.run(generate_dataset(game))
79+
sys.exit(not test(trainer, DATASET_CSV_TEST_PATH))

crispy-api/api/config.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
ASSETS = "assets"
1616
SILENCE_PATH = os.path.join(ASSETS, "silence.mp3")
17-
DOT_PATH = os.path.join(ASSETS, "dot.png")
17+
VALORANT_MASK_PATH = os.path.join(ASSETS, "valorant-mask.png")
18+
CSGO2_MASK_PATH = os.path.join(ASSETS, "csgo2-mask.png")
1819

1920
BACKUP = "backup"
2021

@@ -24,7 +25,10 @@
2425
MUSICS = os.path.join(RESOURCES, "musics")
2526

2627
DATASET_PATH = "dataset"
27-
DATASET_VALUES_PATH = "dataset-values.json"
28+
DATASET_VALUES_PATH = os.path.join(DATASET_PATH, "dataset-values.json")
29+
DATASET_CSV_PATH = os.path.join(DATASET_PATH, "result.csv")
30+
DATASET_CSV_TEST_PATH = os.path.join(DATASET_PATH, "test.csv")
31+
NETWORK_OUTPUTS_PATH = "outputs"
2832

2933
DATABASE_PATH = ".data"
3034

crispy-api/api/models/highlight.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
from mongo_thingy import Thingy
88
from PIL import Image, ImageFilter, ImageOps
99

10-
from api.config import BACKUP, DOT_PATH
10+
from api.config import BACKUP, CSGO2_MASK_PATH, VALORANT_MASK_PATH
1111
from api.models.filter import Filter
1212
from api.models.segment import Segment
1313
from api.tools.audio import silence_if_no_audio
1414
from api.tools.enums import SupportedGames
1515
from api.tools.ffmpeg import merge_videos
1616

1717
logger = logging.getLogger("uvicorn")
18+
valorant_mask = Image.open(VALORANT_MASK_PATH)
19+
csgo2_mask = Image.open(CSGO2_MASK_PATH)
1820

1921

2022
class Highlight(Thingy):
@@ -130,9 +132,7 @@ def _apply_filter_and_do_operations(
130132

131133
image = image.crop((1, 1, image.width - 2, image.height - 2))
132134

133-
dot = Image.open(DOT_PATH)
134-
135-
image.paste(dot, (0, 0), dot)
135+
image.paste(valorant_mask, (0, 0), valorant_mask)
136136

137137
left = image.crop((0, 0, 25, 60))
138138
right = image.crop((95, 0, 120, 60))
@@ -162,13 +162,31 @@ def post_process(image: Image) -> Image:
162162
post_process, (899, 801, 122, 62), framerate=framerate
163163
)
164164

165+
async def extract_csgo2_images(self, framerate: int = 4) -> bool:
166+
def post_process(image: Image) -> Image:
167+
image = ImageOps.grayscale(
168+
image.filter(ImageFilter.FIND_EDGES).filter(
169+
ImageFilter.EDGE_ENHANCE_MORE
170+
)
171+
)
172+
final = Image.new("RGB", (100, 100))
173+
final.paste(image, (0, 0))
174+
final.paste(csgo2_mask, (0, 0), csgo2_mask)
175+
return final
176+
177+
return await self.extract_images(
178+
post_process, (930, 925, 100, 100), framerate=framerate
179+
)
180+
165181
async def extract_images_from_game(
166182
self, game: SupportedGames, framerate: int = 4
167183
) -> bool:
168184
if game == SupportedGames.OVERWATCH:
169185
return await self.extract_overwatch_images(framerate)
170186
elif game == SupportedGames.VALORANT:
171187
return await self.extract_valorant_images(framerate)
188+
elif game == SupportedGames.CSGO2:
189+
return await self.extract_csgo2_images(framerate)
172190
else:
173191
raise NotImplementedError
174192

crispy-api/api/tools/AI/network.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,22 @@
33
import numpy as np
44
import scipy.special
55

6+
from api.tools.enums import SupportedGames
7+
8+
NetworkResolution = {
9+
SupportedGames.VALORANT: [4000, 120, 15, 2],
10+
SupportedGames.OVERWATCH: [10000, 120, 15, 2],
11+
SupportedGames.CSGO2: [10000, 120, 15, 2],
12+
}
13+
614

715
class NeuralNetwork:
816
"""
917
Neural network to predict if a kill is on the image
1018
"""
1119

12-
def __init__(self, nodes: List[int], learning_rate: float = 0.01) -> None:
13-
self.nodes = nodes
20+
def __init__(self, game: SupportedGames, learning_rate: float = 0.01) -> None:
21+
self.nodes = NetworkResolution[game]
1422
self.learning_rate = learning_rate
1523
self.weights: List[Any] = []
1624
self.activation_function = lambda x: scipy.special.expit(x)
@@ -55,7 +63,7 @@ def _train(self, inputs: List[float], targets: Any) -> Tuple[int, int, int]:
5563
for i in range(len(self.nodes) - 1 - 1, 0, -1):
5664
errors.insert(0, np.dot(self.weights[i].T, errors[0]))
5765

58-
# ten times more likely to be not be kill
66+
# five times more likely to be not be kill
5967
# so we mitigate the error
6068
if expected == 0:
6169
errors = [e / 5 for e in errors]

0 commit comments

Comments
 (0)