Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,5 @@ boto3
azure_storage_blob==12.26.0
openpyxl==3.1.5
parameterized
sqlglot
pymemcache==4.0.0
4 changes: 2 additions & 2 deletions sql/engines/goinception.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def query_data_masking(self, instance, db_name=None, sql=""):
sql = f"""/*--user={user};--password={password};--host={host};--port={port};--masking=1;*/
inception_magic_start;
use `{db_name}`;
{sql}
{sql};
inception_magic_commit;"""
query_result = self.query(db_name=db_name, sql=sql)
# 有异常时主动抛出
Expand All @@ -218,7 +218,7 @@ def query_data_masking(self, instance, db_name=None, sql=""):
if print_info.get("errlevel") == 0 and print_info.get("errmsg") is None:
return json.loads(print_info["query_tree"])
else:
raise RuntimeError(f'Inception Error: print_info.get("errmsg")')
raise RuntimeError(f'Inception Error: {print_info.get("errmsg")}')

def get_rollback(self, workflow):
"""
Expand Down
6 changes: 3 additions & 3 deletions sql/engines/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,12 @@ def query_check(self, db_name=None, sql=""):
]
keyword_warning = ""
star_patter = r"(^|,|\s)\*(\s|\(|$)"
sql_whitelist = ["select", "sp_helptext"]
sql_whitelist = ["select", "sp_helptext", "with"]
# 根据白名单list拼接pattern语句
whitelist_pattern = "^" + "|^".join(sql_whitelist)
# 删除注释语句,进行语法判断,执行第一条有效sql
try:
sql = sql.format(sql, strip_comments=True)
sql = sqlparse.format(sql, strip_comments=True)
sql = sqlparse.split(sql)[0]
result["filtered_sql"] = sql.strip()
sql_lower = sql.lower()
Expand Down Expand Up @@ -365,7 +365,7 @@ def query_masking(self, db_name=None, sql="", resultset=None):
"""传入 sql语句, db名, 结果集,
返回一个脱敏后的结果集"""
# 仅对select语句脱敏
if re.match(r"^select", sql, re.I):
if re.match(r"^select|^with", sql, re.I):
filtered_result = brute_mask(self.instance, resultset)
filtered_result.is_masked = True
else:
Expand Down
13 changes: 8 additions & 5 deletions sql/engines/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sql.utils.sql_utils import get_syntax_type, remove_comments
from . import EngineBase
from .models import ResultSet, ReviewResult, ReviewSet
from sql.utils.data_masking import data_masking
from sql.utils.data_masking import data_masking, simple_column_mask
from common.config import SysConfig

logger = logging.getLogger("default")
Expand Down Expand Up @@ -557,14 +557,14 @@ def query_check(self, db_name=None, sql=""):
except IndexError:
result["bad_query"] = True
result["msg"] = "没有有效的SQL语句"
if re.match(r"^select|^show|^explain", sql, re.I) is None:
if re.match(r"^select|^show|^explain|^with", sql, re.I) is None:
result["bad_query"] = True
result["msg"] = "不支持的查询语法类型!"
if "*" in sql:
result["has_star"] = True
result["msg"] = "SQL语句中含有 * "
# select语句先使用Explain判断语法是否正确
if re.match(r"^select", sql, re.I):
# select和with语句先使用Explain判断语法是否正确
if re.match(r"^select|^with", sql, re.I):
explain_result = self.query(db_name=db_name, sql=f"explain {sql}")
if explain_result.error:
result["bad_query"] = True
Expand Down Expand Up @@ -620,6 +620,9 @@ def query_masking(self, db_name=None, sql="", resultset=None):
# 仅对select语句脱敏
if re.match(r"^select", sql, re.I):
mask_result = data_masking(self.instance, db_name, sql, resultset)
# 因goinception的支持问题,mysql的with语句脱敏使用simple_column_mask
elif re.match(r"^with", sql, re.I):
mask_result = simple_column_mask(self.instance, resultset)
else:
mask_result = resultset
return mask_result
Expand Down Expand Up @@ -661,7 +664,7 @@ def execute_check(self, db_name=None, sql=""):
# 获取提交类型
syntax_type = get_syntax_type(statement, parser=False, db_type="mysql")
# 禁用语句
if re.match(r"^select", statement.lower()):
if re.match(r"^select|^with", statement.lower()):
check_result.error_count += 1
row.stagestatus = "驳回不支持语句"
row.errlevel = 2
Expand Down
12 changes: 2 additions & 10 deletions sql/engines/odps.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,9 @@ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)

if not re.match(r"^select", sql, re.I):
if not re.match(r"^select|^with", sql, re.I):
result_set.error = str("仅支持ODPS查询语句")

# 存在limit,替换limit; 不存在,添加limit
if re.search("limit", sql):
sql = re.sub("limit.+(\d+)", "limit " + str(limit_num), sql)
else:
if sql.strip()[-1] == ";":
sql = sql[:-1]
sql = sql + " limit " + str(limit_num) + ";"

try:
conn = self.get_connection(db_name)
effect_row = conn.execute_sql(sql)
Expand All @@ -136,7 +128,7 @@ def query_check(self, db_name=None, sql=""):
# 查询语句的检查、注释去除、切分
result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
keyword_warning = ""
sql_whitelist = ["select"]
sql_whitelist = ["select", "with"]
# 根据白名单list拼接pattern语句
whitelist_pattern = re.compile("^" + "|^".join(sql_whitelist), re.IGNORECASE)
# 删除注释语句,进行语法判断,执行第一条有效sql
Expand Down
4 changes: 2 additions & 2 deletions sql/engines/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,8 @@ def query(
return result_set

def query_masking(self, db_name=None, sql="", resultset=None):
"""简单字段脱敏规则, 仅对select有效"""
if re.match(r"^select", sql, re.I):
"""简单字段脱敏规则, 仅对查询语句有效"""
if re.match(r"^select|^with", sql, re.I):
filtered_result = simple_column_mask(self.instance, resultset)
filtered_result.is_masked = True
else:
Expand Down
6 changes: 3 additions & 3 deletions sql/engines/pgsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def query_check(self, db_name=None, sql=""):
except IndexError:
result["bad_query"] = True
result["msg"] = "没有有效的SQL语句"
if re.match(r"^select|^explain", sql, re.I) is None:
if re.match(r"^select|^explain|^with", sql, re.I) is None:
result["bad_query"] = True
result["msg"] = "不支持的查询语法类型!"
if "*" in sql:
Expand Down Expand Up @@ -257,7 +257,7 @@ def filter_sql(self, sql="", limit_num=0):

def query_masking(self, db_name=None, sql="", resultset=None):
"""简单字段脱敏规则, 仅对select有效"""
if re.match(r"^select", sql, re.I):
if re.match(r"^select|^with", sql, re.I):
filtered_result = simple_column_mask(self.instance, resultset)
filtered_result.is_masked = True
else:
Expand All @@ -276,7 +276,7 @@ def execute_check(self, db_name=None, sql=""):
for statement in sqlparse.split(sql):
statement = sqlparse.format(statement, strip_comments=True)
# 禁用语句
if re.match(r"^select", statement.lower()):
if re.match(r"^select|^with", statement.lower()):
result = ReviewResult(
id=line,
errlevel=2,
Expand Down
1 change: 1 addition & 0 deletions sql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ class QueryLog(models.Model):
instance_name = models.CharField("实例名称", max_length=50)
db_name = models.CharField("数据库名称", max_length=64)
sqllog = models.TextField("执行的查询语句")
original_sql = models.TextField("原始查询语句")
effect_row = models.BigIntegerField("返回行数")
cost_time = models.CharField("执行耗时", max_length=10, default="")
# TODO 改为user 外键
Expand Down
5 changes: 4 additions & 1 deletion sql/offlinedownload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import datetime
import xml.etree.ElementTree as ET
import zipfile
from numpy import diag
import sqlparse
import time

Expand All @@ -22,6 +23,7 @@
from sql.storage import DynamicStorage
from sql.engines import get_engine
from common.config import SysConfig
from sql.utils.sql_utils import SqlglotUtils

logger = logging.getLogger("default")

Expand Down Expand Up @@ -134,7 +136,8 @@ def pre_count_check(self, workflow):
full_sql = sqlparse.format(full_sql, strip_comments=True)
full_sql = sqlparse.split(full_sql)[0]
sql = full_sql.strip()
count_sql = f"SELECT COUNT(*) FROM ({sql.rstrip(';')}) t"
dialect = SqlglotUtils.get_dialect(workflow.db_type)
count_sql = SqlglotUtils.wrap_query_with_count(sql, dialect)
clean_sql = sql.strip().lower()
instance = workflow
check_result = ReviewSet(full_sql=sql)
Expand Down
24 changes: 21 additions & 3 deletions sql/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from sql.utils.tasks import add_kill_conn_schedule, del_schedule
from .models import QueryLog, Instance
from sql.engines import get_engine
from sql.utils.sql_utils import SqlglotUtils

logger = logging.getLogger("default")

Expand All @@ -36,6 +37,7 @@ def query(request):
tb_name = request.POST.get("tb_name")
limit_num = int(request.POST.get("limit_num", 0))
schema_name = request.POST.get("schema_name", None)
is_offline_export = int(request.POST.get("is_offline_export", 0))
user = request.user

result = {"status": 0, "msg": "ok", "data": {}}
Expand Down Expand Up @@ -68,6 +70,7 @@ def query(request):
result["msg"] = query_check_info.get("msg")
return HttpResponse(json.dumps(result), content_type="application/json")
sql_content = query_check_info["filtered_sql"]
original_sql = sql_content.strip()

# 查询权限校验,并且获取limit_num
priv_check_info = query_priv_check(
Expand All @@ -82,9 +85,22 @@ def query(request):
return HttpResponse(json.dumps(result), content_type="application/json")
# explain的limit_num设置为0
limit_num = 0 if re.match(r"^explain", sql_content.lower()) else limit_num

# 对查询sql增加limit限制或者改写语句
sql_content = query_engine.filter_sql(sql=sql_content, limit_num=limit_num)
dialect = SqlglotUtils.get_dialect(instance.db_type)
if is_offline_export:
# 离线导出,统计总数,需要包装count查询语句
sql_content = SqlglotUtils.wrap_query_with_count(sql_content, dialect)
else:
# 页面查询,增加行数限制
# 支持sqlglot方言转换的,使用方言添加行数限制
if dialect:
sql_content = SqlglotUtils.add_limit_to_query(
sql_content, limit_num, dialect
)
else:
# 不支持sqlglot方言的,使用引擎filter_sql函数处理
sql_content = query_engine.filter_sql(
sql=sql_content, limit_num=limit_num
)

# 先获取查询连接,用于后面查询复用连接以及终止会话
query_engine.get_connection(db_name=db_name)
Expand Down Expand Up @@ -176,6 +192,7 @@ def query(request):
db_name=db_name,
instance_name=instance.instance_name,
sqllog=sql_content,
original_sql=original_sql,
effect_row=limit_num,
cost_time=query_result.query_time,
priv_check=priv_check,
Expand Down Expand Up @@ -273,6 +290,7 @@ def _querylog(request):
"instance_name",
"db_name",
"sqllog",
"original_sql",
"effect_row",
"cost_time",
"user_display",
Expand Down
2 changes: 1 addition & 1 deletion sql/templates/sqlexportsubmit.html
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,6 @@
}

var sqlContent = sqlContent.trim().replace(/;$/, '');
var sqlContent = 'select count(1) from (' + sqlContent + '\n) t'

//提交请求
$.ajax({
Expand All @@ -540,6 +539,7 @@
schema_name: $("#schema_name").val(),
tb_name: $("#table_name").val(),
sql_content: sqlContent,
is_offline_export: 1,
limit_num: $("#limit_num").val()
},
complete: function () {
Expand Down
Loading
Loading