diff --git a/premerge/advisor/advisor.py b/premerge/advisor/advisor.py index 898084226..a1cd2e082 100644 --- a/premerge/advisor/advisor.py +++ b/premerge/advisor/advisor.py @@ -1,13 +1,28 @@ +import sqlite3 + import flask from flask import Flask + import advisor_lib advisor_blueprint = flask.Blueprint("advisor", __name__) +def _get_db(): + if "db" not in flask.g: + flask.g.db = advisor_lib.setup_db(flask.current_app.config["DB_PATH"]) + return flask.g.db + + +def _close_db(exception): + db = flask.g.pop("db", None) + if db is not None: + db.close() + + @advisor_blueprint.route("/upload", methods=["POST"]) def upload(): - advisor_lib.upload_failures(flask.request.json) + advisor_lib.upload_failures(flask.request.json, _get_db()) return flask.Response(status=204) @@ -16,7 +31,10 @@ def explain(): return advisor_lib.explain_failures(flask.request.json) -def create_app(): +def create_app(db_path: str): app = Flask(__name__) app.register_blueprint(advisor_blueprint) + app.teardown_appcontext(_close_db) + with app.app_context(): + app.config["DB_PATH"] = db_path return app diff --git a/premerge/advisor/advisor_lib.py b/premerge/advisor/advisor_lib.py index 5e7092172..fdae414ac 100644 --- a/premerge/advisor/advisor_lib.py +++ b/premerge/advisor/advisor_lib.py @@ -1,4 +1,6 @@ from typing import TypedDict +import sqlite3 +import logging class TestFailure(TypedDict): @@ -12,8 +14,39 @@ class FailureExplanation(TypedDict): reason: str | None -def upload_failures(test_failures: list[TestFailure]): - pass +class FailureUpload(TypedDict): + source_type: str + base_commit_sha: str + source_id: str + failures: list[TestFailure] + + +def setup_db(db_path: str) -> sqlite3.Connection: + connection = sqlite3.connect(db_path) + tables = connection.execute("SELECT name from sqlite_master").fetchall() + if "failures" not in tables: + logging.info("Did not find failures table, creating.") + connection.execute( + "CREATE TABLE failures(source_type, base_commit_sha, source_id, test_file, failure_message)" + ) + connection.commit() + return connection + + +def upload_failures(failure_info: FailureUpload, db_connection: sqlite3.Connection): + failures = [] + for failure in failure_info["failures"]: + failures.append( + ( + failure_info["source_type"], + failure_info["base_commit_sha"], + failure_info["source_id"], + failure["name"], + failure["message"], + ) + ) + db_connection.executemany("INSERT INTO failures VALUES(?, ?, ?, ?, ?)", failures) + db_connection.commit() def explain_failures(test_failures: list[TestFailure]) -> list[FailureExplanation]: diff --git a/premerge/advisor/advisor_lib_test.py b/premerge/advisor/advisor_lib_test.py index 4d79e2dd5..4010e61e1 100644 --- a/premerge/advisor/advisor_lib_test.py +++ b/premerge/advisor/advisor_lib_test.py @@ -1,12 +1,49 @@ import unittest +import tempfile import advisor_lib class AdvisorLibTest(unittest.TestCase): + def setUp(self): + self.db_file = tempfile.NamedTemporaryFile() + self.db_connection = advisor_lib.setup_db(self.db_file.name) + + def tearDown(self): + self.db_connection.close() + self.db_file.close() + def test_upload_failures(self): - failures = [{"name": "a.ll", "message": "failed"}] - advisor_lib.upload_failures(failures) + failure_info = { + "source_type": "buildbot", + "base_commit_sha": "8d29a3bb6f3d92d65bf5811b53bf42bf63685359", + "source_id": "10000", + "failures": [ + {"name": "a.ll", "message": "failed in way 1"}, + {"name": "b.ll", "message": "failed in way 2"}, + ], + } + advisor_lib.upload_failures(failure_info, self.db_connection) + failures = self.db_connection.execute("SELECT * from failures").fetchall() + self.assertListEqual( + failures, + [ + ( + "buildbot", + "8d29a3bb6f3d92d65bf5811b53bf42bf63685359", + "10000", + "a.ll", + "failed in way 1", + ), + ( + "buildbot", + "8d29a3bb6f3d92d65bf5811b53bf42bf63685359", + "10000", + "b.ll", + "failed in way 2", + ), + ], + ) def test_explain_failures(self): failures = [{"name": "a.ll", "message": "failed"}] diff --git a/premerge/advisor/integration_test.py b/premerge/advisor/integration_test.py index c369d0f98..75055c822 100644 --- a/premerge/advisor/integration_test.py +++ b/premerge/advisor/integration_test.py @@ -1,24 +1,34 @@ import unittest +import tempfile import advisor class AdvisorIntegrationTest(unittest.TestCase): def setUp(self): - self.app = advisor.create_app() + self.db_file = tempfile.NamedTemporaryFile() + self.app = advisor.create_app(self.db_file.name) self.client = self.app.test_client() def tearDown(self): - pass + self.db_file.close() def test_upload_failures(self): - failures = [{"name": "a.ll", "message": "failed"}] - result = self.client.post("/upload", json=failures) + failure_info = { + "source_type": "buildbot", + "base_commit_sha": "8d29a3bb6f3d92d65bf5811b53bf42bf63685359", + "source_id": "10000", + "failures": [ + {"name": "a.ll", "message": "failed in way 1"}, + ], + } + result = self.client.post("/upload", json=failure_info) self.assertEqual(result.status_code, 204) def test_explain_failures(self): failures = [{"name": "a.ll", "message": "failed"}] result = self.client.get("/explain", json=failures) + self.assertEqual(result.status_code, 200) self.assertListEqual( result.json, [{"name": "a.ll", "explained": False, "reason": None}] ) diff --git a/premerge/advisor/server.py b/premerge/advisor/server.py index 37dd631f8..9ef014dc2 100644 --- a/premerge/advisor/server.py +++ b/premerge/advisor/server.py @@ -1,6 +1,8 @@ +import os + import advisor if __name__ == "__main__": - app = advisor.create_app() + app = advisor.create_app(os.environ["ADVISOR_DB_PATH"]) app.run(host="0.0.0.0", port=5000)