From 0cf607d13aeec246c73ad98556012e7421d68ea4 Mon Sep 17 00:00:00 2001 From: yunjinqi Date: Tue, 14 May 2024 21:27:21 +0800 Subject: [PATCH] fix format --- wtpy/CodeHelper.py | 9 +- wtpy/ContractMgr.py | 95 ++-- wtpy/CtaContext.py | 364 ++++++------- wtpy/ExtModuleDefs.py | 101 ++-- wtpy/ExtToolDefs.py | 32 +- wtpy/HftContext.py | 321 ++++++------ wtpy/ProductMgr.py | 43 +- wtpy/SelContext.py | 279 +++++----- wtpy/SessionMgr.py | 78 +-- wtpy/StrategyDefs.py | 242 ++++----- wtpy/WtBtEngine.py | 292 ++++++----- wtpy/WtCoreDefs.py | 458 ++++++++-------- wtpy/WtDataDefs.py | 285 +++++----- wtpy/WtDtEngine.py | 62 +-- wtpy/WtDtServo.py | 75 +-- wtpy/WtEngine.py | 266 +++++----- wtpy/WtMsgQue.py | 35 +- wtpy/WtUtilDefs.py | 7 +- wtpy/__init__.py | 24 +- wtpy/apps/WtBtAnalyst.py | 898 +++++++++++++++---------------- wtpy/apps/WtCCLoader.py | 154 +++--- wtpy/apps/WtCtaGAOptimizer.py | 82 +-- wtpy/apps/WtCtaOptimizer.py | 305 ++++++----- wtpy/apps/WtHftOptimizer.py | 275 +++++----- wtpy/apps/WtHotPicker.py | 422 +++++++-------- wtpy/monitor/DataMgr.py | 378 +++++++------- wtpy/monitor/EventReceiver.py | 50 +- wtpy/monitor/PushSvr.py | 47 +- wtpy/monitor/WatchDog.py | 159 +++--- wtpy/monitor/WtBtMon.py | 163 +++--- wtpy/monitor/WtBtSnooper.py | 333 ++++++------ wtpy/monitor/WtLogger.py | 29 +- wtpy/monitor/WtMonSvr.py | 103 ++-- wtpy/monitor/__init__.py | 3 +- wtpy/wrapper/ContractLoader.py | 42 +- wtpy/wrapper/PlatformHelper.py | 18 +- wtpy/wrapper/TraderDumper.py | 77 +-- wtpy/wrapper/WtBtWrapper.py | 914 ++++++++++++++++---------------- wtpy/wrapper/WtDtHelper.py | 219 ++++---- wtpy/wrapper/WtDtServoApi.py | 71 +-- wtpy/wrapper/WtDtWrapper.py | 61 ++- wtpy/wrapper/WtExecApi.py | 27 +- wtpy/wrapper/WtMQWrapper.py | 30 +- wtpy/wrapper/WtWrapper.py | 927 +++++++++++++++++---------------- wtpy/wrapper/__init__.py | 5 +- 45 files changed, 4605 insertions(+), 4255 deletions(-) diff --git a/wtpy/CodeHelper.py b/wtpy/CodeHelper.py index 719273b8..7e517513 100644 --- a/wtpy/CodeHelper.py +++ b/wtpy/CodeHelper.py @@ -1,9 +1,10 @@ import re + class CodeHelper: - + @staticmethod - def isStdChnFutOptCode(stdCode:str) -> bool: + def isStdChnFutOptCode(stdCode: str) -> bool: pattern = re.compile("^[A-Z]+.[A-z]+\\d{4}.(C|P).\\d+$") if re.match(pattern, stdCode) is not None: return True @@ -11,7 +12,7 @@ def isStdChnFutOptCode(stdCode:str) -> bool: return False @staticmethod - def stdCodeToStdCommID(stdCode:str) -> str: + def stdCodeToStdCommID(stdCode: str) -> str: ay = stdCode.split(".") if not CodeHelper.isStdChnFutOptCode(stdCode): return ay[0] + "." + ay[1] @@ -24,4 +25,4 @@ def stdCodeToStdCommID(stdCode:str) -> str: elif exchg == 'CFFEX': return exchg + "." + pid else: - return exchg + "." + pid + '_o' \ No newline at end of file + return exchg + "." + pid + '_o' diff --git a/wtpy/ContractMgr.py b/wtpy/ContractMgr.py index efad3d54..dbedc131 100644 --- a/wtpy/ContractMgr.py +++ b/wtpy/ContractMgr.py @@ -4,40 +4,41 @@ from .ProductMgr import ProductMgr, ProductInfo + class ContractInfo: def __init__(self): - self.exchg:str = '' #交易所 - self.code:str = '' #合约代码 - self.name:str = '' #合约名称 - self.product:str = '' #品种代码 - self.stdCode:str = '' #标准代码 + self.exchg: str = '' # 交易所 + self.code: str = '' # 合约代码 + self.name: str = '' # 合约名称 + self.product: str = '' # 品种代码 + self.stdCode: str = '' # 标准代码 - self.isOption:bool = False # 是否期权合约 - self.underlying:str = '' # underlying - self.strikePrice:float = 0 # 行权价 - self.underlyingScale:float = 0 # 放大倍数 - self.isCall:bool = True # 是否看涨期权 + self.isOption: bool = False # 是否期权合约 + self.underlying: str = '' # underlying + self.strikePrice: float = 0 # 行权价 + self.underlyingScale: float = 0 # 放大倍数 + self.isCall: bool = True # 是否看涨期权 - self.openDate:int = 19000101 # 上市日期 - self.expireDate:int = 20991231 # 到期日 + self.openDate: int = 19000101 # 上市日期 + self.expireDate: int = 20991231 # 到期日 - self.longMarginRatio:float = 0 # 多头保证金率 - self.shortMarginRatio:float = 0 # 空头保证金率 + self.longMarginRatio: float = 0 # 多头保证金率 + self.shortMarginRatio: float = 0 # 空头保证金率 class ContractMgr: - def __init__(self, prodMgr:ProductMgr = None): + def __init__(self, prodMgr: ProductMgr = None): self.__contracts__ = dict() - self.__underlyings__ = dict() # 期权专用 - self.__products__ = dict() # 期权专用 + self.__underlyings__ = dict() # 期权专用 + self.__products__ = dict() # 期权专用 self.__prod_mgr__ = prodMgr - def load(self, fname:str): - ''' + def load(self, fname: str): + """ 从文件加载品种信息 - ''' + """ f = open(fname, 'rb') content = f.read() f.close() @@ -71,8 +72,8 @@ def load(self, fname:str): cInfo.shortMarginRatio = float(cObj["shortmarginratio"]) if "product" in cObj: - cInfo.product = cObj["product"] - #股票标准代码为SSE.000001,期货标准代码为SHFE.rb.2010 + cInfo.product = cObj["product"] + # 股票标准代码为SSE.000001,期货标准代码为SHFE.rb.2010 if cInfo.code[:len(cInfo.product)] == cInfo.product: month = cInfo.code[len(cInfo.product):] if len(month) < 4: @@ -107,7 +108,7 @@ def load(self, fname:str): if "option" in cObj: oObj = cObj["option"] cInfo.isOption = True - cInfo.isCall = (int(oObj["optiontype"])==49) + cInfo.isCall = (int(oObj["optiontype"]) == 49) cInfo.underlying = oObj["underlying"] cInfo.strikePrice = float(oObj["strikeprice"]) cInfo.underlyingScale = float(oObj["underlyingscale"]) @@ -120,61 +121,59 @@ def load(self, fname:str): self.__underlyings__[stdUnderlying].append(cInfo.stdCode) - def getContractInfo(self, stdCode:str, uDate:int = 0) -> ContractInfo: - ''' + def getContractInfo(self, stdCode: str, uDate: int = 0) -> ContractInfo: + """ 获取合约信息 @stdCode 合约代码,格式如SHFE.rb.2305 - ''' + """ if stdCode not in self.__contracts__: return None - - cInfo:ContractInfo = self.__contracts__[stdCode] + + cInfo: ContractInfo = self.__contracts__[stdCode] if uDate != 0 and (cInfo.openDate > uDate or cInfo.expireDate < uDate): return None - + return cInfo - def getTotalCodes(self, uDate:int = 0) -> list: - ''' + def getTotalCodes(self, uDate: int = 0) -> list: + """ 获取全部合约代码列表 @uDate 交易日, 格式如20210101 - ''' + """ codes = list() for code in self.__contracts__: - cInfo:ContractInfo = self.__contracts__[code] - if uDate == 0 or (cInfo.openDate <= uDate and cInfo.expireDate >= uDate): + cInfo: ContractInfo = self.__contracts__[code] + if uDate == 0 or (cInfo.openDate <= uDate <= cInfo.expireDate): codes.append(self.__contracts__[code].stdCode) return codes - - def getCodesByUnderlying(self, underlying:str, uDate:int = 0) -> list: - ''' + + def getCodesByUnderlying(self, underlying: str, uDate: int = 0) -> list: + """ 根据underlying读取合约列表 @underlying 格式如CFFEX.IM2304 @uDate 交易日, 格式如20210101 - ''' + """ ret = list() if underlying in self.__underlyings__: codes = self.__underlyings__[underlying] for code in codes: - cInfo:ContractInfo = self.__contracts__[code] - if uDate == 0 or (cInfo.openDate <= uDate and cInfo.expireDate >= uDate): + cInfo: ContractInfo = self.__contracts__[code] + if uDate == 0 or (cInfo.openDate <= uDate <= cInfo.expireDate): ret.append(self.__contracts__[code].stdCode) return ret - - def getCodesByProduct(self, stdPID:str, uDate:int = 0) -> list: - ''' + + def getCodesByProduct(self, stdPID: str, uDate: int = 0) -> list: + """ 根据品种代码读取合约列表 @stdPID 品种代码, 格式如SHFE.rb @uDate 交易日, 格式如20210101 - ''' + """ ret = list() if stdPID in self.__products__: codes = self.__products__[stdPID] for code in codes: - cInfo:ContractInfo = self.__contracts__[code] - if uDate == 0 or (cInfo.openDate <= uDate and cInfo.expireDate >= uDate): + cInfo: ContractInfo = self.__contracts__[code] + if uDate == 0 or (cInfo.openDate <= uDate <= cInfo.expireDate): ret.append(self.__contracts__[code].stdCode) return ret - - diff --git a/wtpy/CtaContext.py b/wtpy/CtaContext.py index 131516f3..1ff76ce8 100644 --- a/wtpy/CtaContext.py +++ b/wtpy/CtaContext.py @@ -5,24 +5,25 @@ from wtpy.WtCoreDefs import WTSBarStruct, WTSTickStruct from ctypes import POINTER + class CtaContext: - ''' + """ Context是策略可以直接访问的唯一对象 策略所有的接口都通过Context对象调用 Context类包括以下几类接口: 1、时间接口(日期、时间等), 接口格式如: stra_xxx 2、数据接口(K线、财务等), 接口格式如: stra_xxx 3、下单接口(设置目标仓位、直接下单等), 接口格式如: stra_xxx - ''' - - def __init__(self, id:int, stra, wrapper, engine): - self.__stra_info__ = stra #策略对象, 对象基类BaseStrategy.py - self.__wrapper__ = wrapper #底层接口转换器 - self.__id__ = id #策略ID - self.__bar_cache__ = dict() #K线缓存 - self.__tick_cache__ = dict() #tTick缓存, 每次都重新去拉取, 这个只做中转用, 不在python里维护副本 - self.__sname__ = stra.name() - self.__engine__ = engine #交易环境 + """ + + def __init__(self, id: int, stra, wrapper, engine): + self.__stra_info__ = stra # 策略对象, 对象基类BaseStrategy.py + self.__wrapper__ = wrapper # 底层接口转换器 + self.__id__ = id # 策略ID + self.__bar_cache__ = dict() # K线缓存 + self.__tick_cache__ = dict() # tTick缓存, 每次都重新去拉取, 这个只做中转用, 不在python里维护副本 + self.__sname__ = stra.name() + self.__engine__ = engine # 交易环境 self.__pos_cache__ = None self.__alias__() @@ -30,15 +31,15 @@ def __init__(self, id:int, stra, wrapper, engine): @property def id(self): return self.__id__ - + @property def is_backtest(self): return self.__engine__.is_backtest - + def __alias__(self): - ''' + """ 接口函数别名 - ''' + """ self.enter_long = self.stra_enter_long self.enter_short = self.stra_enter_short self.exit_long = self.stra_exit_long @@ -72,80 +73,80 @@ def __alias__(self): self.sub_bar_events = self.stra_sub_bar_events pass - def write_indicator(self, tag:str, time:int, data:dict): - ''' + def write_indicator(self, tag: str, time: int, data: dict): + """ 输出指标数据 @tag 指标标签 @time 输出时间 @data 输出的指标数据, dict类型, 会转成json以后保存 - ''' + """ self.__engine__.write_indicator(self.__stra_info__.name(), tag, time, data) def on_init(self): - ''' + """ 初始化, 一般用于系统启动的时候 - ''' + """ self.__stra_info__.on_init(self) - def on_session_begin(self, curTDate:int): - ''' + def on_session_begin(self, curTDate: int): + """ 交易日开始事件 @curTDate 交易日, 格式为20210220 - ''' + """ self.__stra_info__.on_session_begin(self, curTDate) - def on_session_end(self, curTDate:int): - ''' + def on_session_end(self, curTDate: int): + """ 交易日结束事件 @curTDate 交易日, 格式为20210220 - ''' + """ self.__stra_info__.on_session_end(self, curTDate) def on_backtest_end(self): - ''' + """ 回测结束事件 - ''' + """ self.__stra_info__.on_backtest_end(self) - def on_getticks(self, stdCode:str, newTicks:WtNpTicks): + def on_getticks(self, stdCode: str, newTicks: WtNpTicks): key = stdCode self.__tick_cache__[key] = newTicks - def on_getpositions(self, stdCode:str, qty:float, frozen:float): + def on_getpositions(self, stdCode: str, qty: float, frozen: float): if len(stdCode) == 0: return self.__pos_cache__[stdCode] = qty - def on_getbars(self, stdCode:str, period:str, npBars:WtNpKline): + def on_getbars(self, stdCode: str, period: str, npBars: WtNpKline): key = "%s#%s" % (stdCode, period) self.__bar_cache__[key] = npBars - def on_condition_triggered(self, stdCode:str, target:float, price:float, usertag:str): - self.__stra_info__.on_condition_triggered(self,stdCode, target, price, usertag) + def on_condition_triggered(self, stdCode: str, target: float, price: float, usertag: str): + self.__stra_info__.on_condition_triggered(self, stdCode, target, price, usertag) - def on_tick(self, stdCode:str, newTick:POINTER(WTSTickStruct)): - ''' + def on_tick(self, stdCode: str, newTick: POINTER(WTSTickStruct)): + """ tick回调事件响应 - ''' + """ self.__stra_info__.on_tick(self, stdCode, newTick.contents.to_dict()) - def on_bar(self, stdCode:str, period:str, newBar:POINTER(WTSBarStruct)): - ''' + def on_bar(self, stdCode: str, period: str, newBar: POINTER(WTSBarStruct)): + """ K线闭合事件响应 @stdCode 品种代码 @period K线基础周期 @times 周期倍数 @newBar 最新K线 - ''' + """ key = "%s#%s" % (stdCode, period) if key not in self.__bar_cache__: return - + try: self.__stra_info__.on_bar(self, stdCode, period, newBar.contents.to_dict()) except ValueError as ve: @@ -154,122 +155,122 @@ def on_bar(self, stdCode:str, period:str, newBar:POINTER(WTSBarStruct)): return def on_calculate(self): - ''' + """ 策略重算回调 主K线闭合才会触发该回调接口 - ''' + """ self.__stra_info__.on_calculate(self) def on_calculate_done(self): - ''' + """ 重算结束回调 只有在异步模式下才会触发, 目前主要针对强化学习的训练场景, 需要在重算以后将智能体的信号传递给底层 - ''' + """ self.__stra_info__.on_calculate_done(self) - def stra_log_text(self, message:str, level:int = 1): - ''' + def stra_log_text(self, message: str, level: int = 1): + """ 输出日志 @level 日志级别, 0-debug, 1-info, 2-warn, 3-error @message 消息内容, 最大242字符 - ''' + """ self.__wrapper__.cta_log_text(self.__id__, level, message[:242]) def stra_get_tdate(self) -> int: - ''' + """ 获取当前交易日 @return int, 格式如20180513 - ''' + """ return self.__wrapper__.cta_get_tdate() - + def stra_get_date(self) -> int: - ''' + """ 获取当前日期 @return int, 格式如20180513 - ''' + """ return self.__wrapper__.cta_get_date() - def stra_get_position_avgpx(self, stdCode:str = "") -> float: - ''' + def stra_get_position_avgpx(self, stdCode: str = "") -> float: + """ 获取当前持仓均价 @stdCode 合约代码 @return 持仓均价 - ''' + """ return self.__wrapper__.cta_get_position_avgpx(self.__id__, stdCode) - def stra_get_position_profit(self, stdCode:str = "") -> float: - ''' + def stra_get_position_profit(self, stdCode: str = "") -> float: + """ 获取持仓浮动盈亏 @stdCode 合约代码, 为None时读取全部品种的浮动盈亏 @return 浮动盈亏 - ''' + """ return self.__wrapper__.cta_get_position_profit(self.__id__, stdCode) - def stra_get_fund_data(self, flag:int = 0) -> float: - ''' + def stra_get_fund_data(self, flag: int = 0) -> float: + """ 获取资金数据 @flag 0-动态权益, 1-总平仓盈亏, 2-总浮动盈亏, 3-总手续费 @return 资金数据 - ''' + """ return self.__wrapper__.cta_get_fund_data(self.__id__, flag) def stra_get_time(self) -> int: - ''' + """ 获取当前时间, 24小时制, 精确到分 @return int, 格式如1231 - ''' + """ return self.__wrapper__.cta_get_time() - def stra_get_price(self, stdCode:str) -> float: - ''' + def stra_get_price(self, stdCode: str) -> float: + """ 获取最新价格, 一般在获取了K线以后再获取该价格 @return 最新价格 - ''' + """ return self.__wrapper__.cta_get_price(stdCode) - def stra_get_day_price(self, stdCode:str, flag:int = 0) -> float: - ''' + def stra_get_day_price(self, stdCode: str, flag: int = 0) -> float: + """ 获取当日价格 @flag 价格标记, 0-开盘价, 1-最高价, 2-最低价, 3-最新价 @return 最新价格 - ''' + """ return self.__wrapper__.cta_get_day_price(stdCode, flag) def stra_get_all_position(self) -> dict: - ''' + """ 获取全部持仓 - ''' - self.__pos_cache__ = dict() # + """ + self.__pos_cache__ = dict() # self.__wrapper__.cta_get_all_position(self.__id__) return self.__pos_cache__ - def stra_prepare_bars(self, stdCode:str, period:str, count:int, isMain:bool = False): - ''' + def stra_prepare_bars(self, stdCode: str, period: str, count: int, isMain: bool = False): + """ 准备历史K线 一般在on_init调用 @stdCode 合约代码 @period K线周期, 如m3/d7 @count 要拉取的K线条数 @isMain 是否是主K线 - ''' + """ key = "%s#%s" % (stdCode, period) if key in self.__bar_cache__: - #这里做一个数据长度处理 + # 这里做一个数据长度处理 return self.__bar_cache__[key] - + self.__wrapper__.cta_get_bars(self.__id__, stdCode, period, count, isMain) - def stra_get_bars(self, stdCode:str, period:str, count:int, isMain:bool = False) -> WtNpKline: - ''' + def stra_get_bars(self, stdCode: str, period: str, count: int, isMain: bool = False) -> WtNpKline: + """ 获取历史K线 @stdCode 合约代码 @period K线周期, 如m3/d7 @count 要拉取的K线条数 @isMain 是否是主K线 - ''' + """ key = "%s#%s" % (stdCode, period) - cnt = self.__wrapper__.cta_get_bars(self.__id__, stdCode, period, count, isMain) + cnt = self.__wrapper__.cta_get_bars(self.__id__, stdCode, period, count, isMain) if cnt == 0: return None @@ -277,290 +278,293 @@ def stra_get_bars(self, stdCode:str, period:str, count:int, isMain:bool = False) return npBars - def stra_get_ticks(self, stdCode:str, count:int) -> WtNpTicks: - ''' + def stra_get_ticks(self, stdCode: str, count: int) -> WtNpTicks: + """ 获取tick数据 @stdCode 合约代码 @count 要拉取的tick数量 - ''' + """ cnt = self.__wrapper__.cta_get_ticks(self.__id__, stdCode, count) if cnt == 0: return None - + np_ticks = self.__tick_cache__[stdCode] return np_ticks - def stra_sub_ticks(self, stdCode:str): - ''' + def stra_sub_ticks(self, stdCode: str): + """ 订阅实时行情 获取K线和tick数据的时候会自动订阅, 这里只需要订阅额外要检测的品种即可 @stdCode 合约代码 - ''' + """ self.__wrapper__.cta_sub_ticks(self.__id__, stdCode) - def stra_sub_bar_events(self, stdCode:str, perriod:str): - ''' + def stra_sub_bar_events(self, stdCode: str, perriod: str): + """ 订阅K线事件, 订阅以后on_bar会触发, 一般在on_init调用 @stdCode 合约代码 @period K线周期, 如m3/d7 - ''' + """ self.__wrapper__.cta_sub_bar_events(self.__id__, stdCode, perriod) - def stra_get_position(self, stdCode:str, bonlyvalid:bool = False, usertag:str = "") -> float: - ''' + def stra_get_position(self, stdCode: str, bonlyvalid: bool = False, usertag: str = "") -> float: + """ 读取当前仓位 @stdCode 合约/股票代码 @bonlyvalid 只读可用持仓, 默认为False @usertag 入场标记 @return 正为多仓, 负为空仓 - ''' + """ return self.__wrapper__.cta_get_position(self.__id__, stdCode, bonlyvalid, usertag) - def stra_set_position(self, stdCode:str, qty:float, usertag:str = "", limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def stra_set_position(self, stdCode: str, qty: float, usertag: str = "", limitprice: float = 0.0, + stopprice: float = 0.0): + """ 设置仓位 @stdCode 合约/股票代码 @qty 目标仓位, 正为多仓, 负为空仓 @return 设置结果TRUE/FALSE - ''' + """ self.__wrapper__.cta_set_position(self.__id__, stdCode, qty, usertag, limitprice, stopprice) - - def stra_enter_long(self, stdCode:str, qty:float, usertag:str = "", limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def stra_enter_long(self, stdCode: str, qty: float, usertag: str = "", limitprice: float = 0.0, + stopprice: float = 0.0): + """ 多仓进场, 如果有空仓, 则平空再开多 @stdCode 品种代码 @qty 数量 @limitprice 限价, 默认为0 @stopprice 止价, 默认为0 - ''' + """ self.__wrapper__.cta_enter_long(self.__id__, stdCode, qty, usertag, limitprice, stopprice) - def stra_exit_long(self, stdCode:str, qty:float, usertag:str = "", limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def stra_exit_long(self, stdCode: str, qty: float, usertag: str = "", limitprice: float = 0.0, + stopprice: float = 0.0): + """ 多仓出场, 如果剩余多仓不够, 则全部平掉即可 @stdCode 品种代码 @qty 数量 @limitprice 限价, 默认为0 @stopprice 止价, 默认为0 - ''' + """ self.__wrapper__.cta_exit_long(self.__id__, stdCode, qty, usertag, limitprice, stopprice) - def stra_enter_short(self, stdCode:str, qty:float, usertag:str = "", limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def stra_enter_short(self, stdCode: str, qty: float, usertag: str = "", limitprice: float = 0.0, + stopprice: float = 0.0): + """ 空仓进场, 如果有多仓, 则平多再开空 @stdCode 品种代码 @qty 数量 @limitprice 限价, 默认为0 @stopprice 止价, 默认为0 - ''' + """ self.__wrapper__.cta_enter_short(self.__id__, stdCode, qty, usertag, limitprice, stopprice) - def stra_exit_short(self, stdCode:str, qty:float, usertag:str = "", limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def stra_exit_short(self, stdCode: str, qty: float, usertag: str = "", limitprice: float = 0.0, + stopprice: float = 0.0): + """ 空仓出场, 如果剩余空仓不够, 则全部平掉即可 @stdCode 品种代码 @qty 数量 @limitprice 限价, 默认为0 @stopprice 止价, 默认为0 - ''' + """ self.__wrapper__.cta_exit_short(self.__id__, stdCode, qty, usertag, limitprice, stopprice) - def stra_get_last_entrytime(self, stdCode:str) -> int: - ''' + def stra_get_last_entrytime(self, stdCode: str) -> int: + """ 获取当前持仓最后一次进场时间 @stdCode 品种代码 @return 返回最后一次开仓的时间, 格式如201903121047 - ''' + """ return self.__wrapper__.cta_get_last_entertime(self.__id__, stdCode) - def stra_get_last_entrytag(self, stdCode:str) -> str: - ''' + def stra_get_last_entrytag(self, stdCode: str) -> str: + """ 获取当前持仓最后一次进场标记 @stdCode 品种代码 @return 返回最后一次开仓标记 - ''' + """ return self.__wrapper__.cta_get_last_entertag(self.__id__, stdCode) - def stra_get_last_exittime(self, stdCode:str) -> int: - ''' + def stra_get_last_exittime(self, stdCode: str) -> int: + """ 获取当前持仓最后一次出场时间 @stdCode 品种代码 @return 返回最后一次开仓的时间, 格式如201903121047 - ''' + """ return self.__wrapper__.cta_get_last_exittime(self.__id__, stdCode) - def stra_get_first_entrytime(self, stdCode:str) -> int: - ''' + def stra_get_first_entrytime(self, stdCode: str) -> int: + """ 获取当前持仓第一次进场时间 @stdCode 品种代码 @return 返回最后一次开仓的时间, 格式如201903121047 - ''' + """ return self.__wrapper__.cta_get_first_entertime(self.__id__, stdCode) - - def user_save_data(self, key:str, val): - ''' + def user_save_data(self, key: str, val): + """ 保存用户数据 @key 数据id @val 数据值, 可以直接转换成str的数据均可 - ''' + """ self.__wrapper__.cta_save_user_data(self.__id__, key, str(val)) - def user_load_data(self, key:str, defVal = None, vType = float): - ''' + def user_load_data(self, key: str, defVal=None, vType=float): + """ 读取用户数据 @key 数据id @defVal 默认数据, 如果找不到则返回改数据, 默认为None @return 返回值, 默认处理为float数据 - ''' + """ ret = self.__wrapper__.cta_load_user_data(self.__id__, key, "") if ret == "": return defVal return vType(ret) - def stra_get_detail_profit(self, stdCode:str, usertag:str, flag:int = 0) -> float: - ''' + def stra_get_detail_profit(self, stdCode: str, usertag: str, flag: int = 0) -> float: + """ 获取指定标记的持仓的盈亏 @stdCode 合约代码 @usertag 进场标记 @flag 盈亏记号, 0-浮动盈亏, 1-最大浮盈, -1-最大亏损(负数), 2-最高浮动价格, -2-最低浮动价格 @return 盈亏 - ''' + """ return self.__wrapper__.cta_get_detail_profit(self.__id__, stdCode, usertag, flag) - def stra_get_detail_cost(self, stdCode:str, usertag:str) -> float: - ''' + def stra_get_detail_cost(self, stdCode: str, usertag: str) -> float: + """ 获取指定标记的持仓的开仓价 @stdCode 合约代码 @usertag 进场标记 @return 开仓价 - ''' + """ return self.__wrapper__.cta_get_detail_cost(self.__id__, stdCode, usertag) - def stra_get_detail_entertime(self, stdCode:str, usertag:str) -> int: - ''' + def stra_get_detail_entertime(self, stdCode: str, usertag: str) -> int: + """ 获取指定标记的持仓的进场时间 @stdCode 合约代码 @usertag 进场标记 @return 进场时间, 格式如201907260932 - ''' + """ return self.__wrapper__.cta_get_detail_entertime(self.__id__, stdCode, usertag) - + def stra_get_all_codes(self) -> list: - ''' + """ 获取全部合约代码列表 - ''' + """ if self.__engine__ is None: return [] return self.__engine__.getAllCodes() - - def stra_get_codes_by_product(self, stdPID:str) -> list: - ''' + + def stra_get_codes_by_product(self, stdPID: str) -> list: + """ 根据品种代码读取合约列表 @stdPID 品种代码,格式如SHFE.rb - ''' + """ if self.__engine__ is None: return [] return self.__engine__.getCodesByProduct(stdPID) - - def stra_get_codes_by_underlying(self, underlying:str) -> list: - ''' + + def stra_get_codes_by_underlying(self, underlying: str) -> list: + """ 根据underlying读取合约列表 @underlying 格式如CFFEX.IM2304 - ''' + """ if self.__engine__ is None: return [] return self.__engine__.getCodesByUnderlying(underlying) - def stra_get_comminfo(self, stdCode:str) -> ProductInfo: - ''' + def stra_get_comminfo(self, stdCode: str) -> ProductInfo: + """ 获取品种详情 @stdCode 合约代码如SHFE.ag.HOT, 或者品种代码如SHFE.ag @return 品种信息, 结构请参考ProductMgr中的ProductInfo - ''' + """ if self.__engine__ is None: return None return self.__engine__.getProductInfo(stdCode) - - def stra_get_contract(self, stdCode:str) -> ContractInfo: - ''' + + def stra_get_contract(self, stdCode: str) -> ContractInfo: + """ 获取合约详情,回测框架下支持不够完善,慎用! @stdCode 合约代码如SHFE.ag.2302 @return 品种信息, 结构请参考ContractMgr中的ContractInfo - ''' + """ if self.__engine__ is None: return None return self.__engine__.getContractInfo(stdCode) - def stra_get_rawcode(self, stdCode:str): - ''' + def stra_get_rawcode(self, stdCode: str): + """ 获取分月合约代码 @stdCode 连续合约代码如SHFE.ag.HOT @return 品种信息,结构请参考ProductMgr中的ProductInfo - ''' + """ if self.__engine__ is None: return "" return self.__engine__.getRawStdCode(stdCode) - def stra_get_sessinfo(self, stdCode:str) -> SessionInfo: - ''' + def stra_get_sessinfo(self, stdCode: str) -> SessionInfo: + """ 获取交易时段详情 @stdCode 合约代码如SHFE.ag.HOT, 或者品种代码如SHFE.ag @return 品种信息, 结构请参考SessionMgr中的SessionInfo - ''' + """ if self.__engine__ is None: return None return self.__engine__.getSessionByCode(stdCode) - def set_chart_kline(self, stdCode:str, period:str): - ''' + def set_chart_kline(self, stdCode: str, period: str): + """ 设置图表K线 @stdCode 合约代码 @period K线周期 - ''' + """ self.__wrapper__.cta_set_chart_kline(self.__id__, stdCode, period) - def add_chart_mark(self, price:float, icon:str, tag:str = 'Notag'): - ''' + def add_chart_mark(self, price: float, icon: str, tag: str = 'Notag'): + """ 添加图表标记 @price 价格, 决定图标出现的位置 @icon 图标, 系统一定的图标ID @tag 标签, 自定义的 - ''' + """ self.__wrapper__.cta_add_chart_mark(self.__id__, price, icon, tag) - def register_index(self, idxName:str, idxType:int = 1): - ''' + def register_index(self, idxName: str, idxType: int = 1): + """ 注册指标, on_init调用 @idxName 指标名 @idxType 指标类型, 0-主图指标, 1-副图指标 - ''' + """ self.__wrapper__.cta_register_index(self.__id__, idxName, idxType) - def register_index_line(self, idxName:str, lineName:str, lineType:int = 0) -> bool: - ''' + def register_index_line(self, idxName: str, lineName: str, lineType: int = 0) -> bool: + """ 注册指标线, on_init调用 @idxName 指标名称 @lineName 线名称 @lineType 线型, 0-曲线, 1-柱子 - ''' + """ return self.__wrapper__.cta_register_index_line(self.__id__, idxName, lineName, lineType) - def add_index_baseline(self, idxName:str, lineName:str, value:float) -> bool: - ''' + def add_index_baseline(self, idxName: str, lineName: str, value: float) -> bool: + """ 添加基准线, on_init调用 @idxName 指标名称 @lineName 线名称 @value 基准线数值 - ''' + """ return self.__wrapper__.cta_add_index_baseline(self.__id__, idxName, lineName, value) - def set_index_value(self, idxName:str, lineName:str, val:float) -> bool: - ''' + def set_index_value(self, idxName: str, lineName: str, val: float) -> bool: + """ 设置指标值, 只有在oncalc的时候才生效 @idxName 指标名称 @lineName 线名称 - ''' + """ return self.__wrapper__.cta_set_index_value(self.__id__, idxName, lineName, val) diff --git a/wtpy/ExtModuleDefs.py b/wtpy/ExtModuleDefs.py index 81bc1921..f58b55c8 100644 --- a/wtpy/ExtModuleDefs.py +++ b/wtpy/ExtModuleDefs.py @@ -1,14 +1,13 @@ - - class BaseExtParser: - ''' + """ 扩展行情接入模块基类 - ''' - def __init__(self, id:str): - ''' + """ + + def __init__(self, id: str): + """ 构造函数 @id 解析器ID - ''' + """ self.__id__ = id return @@ -16,56 +15,56 @@ def id(self) -> str: return self.__id__ def init(self, engine): - ''' + """ 初始化 - ''' + """ self.__engine__ = engine return def connect(self): - ''' + """ 开始连接 - ''' + """ return def disconnect(self): - ''' + """ 断开连接 - ''' + """ return def release(self): - ''' + """ 释放,一般是进程退出时调用 - ''' + """ return - def subscribe(self, fullCode:str): - ''' + def subscribe(self, fullCode: str): + """ 订阅实时行情 @fullCode 合约代码,格式如CFFEX.IF2106 - ''' + """ return - def unsubscribe(self, fullCode:str): - ''' + def unsubscribe(self, fullCode: str): + """ 退订实时行情 @fullCode 合约代码,格式如CFFEX.IF2106 - ''' + """ return class BaseExtExecuter: - ''' + """ 扩展执行器基类 - ''' + """ - def __init__(self, id:str, scale:float): - ''' + def __init__(self, id: str, scale: float): + """ 构造函数 @id 执行器ID @scale 数量放大倍数 - ''' + """ self.__id__ = id self.__scale__ = scale self.__targets__ = dict() @@ -73,16 +72,16 @@ def __init__(self, id:str, scale:float): def id(self): return self.__id__ - + def init(self): return - def set_position(self, stdCode:str, targetPos:float): - ''' + def set_position(self, stdCode: str, targetPos: float): + """ 设置目标部位 @stdCode 合约代码,期货格式为CFFEX.IF.2106 @targetPos 目标仓位,浮点数 - ''' + """ # 确定原来的目标仓位 oldPos = 0 @@ -93,74 +92,76 @@ def set_position(self, stdCode:str, targetPos:float): self.__targets__[stdCode] = targetPos return + class BaseExtDataLoader: def __init__(self): pass - def load_final_his_bars(self, stdCode:str, period:str, feeder) -> bool: - ''' + def load_final_his_bars(self, stdCode: str, period: str, feeder) -> bool: + """ 加载最终历史K线(回测、实盘) 该接口一般用于加载外部处理好的复权数据、主力合约数据 @stdCode 合约代码,格式如CFFEX.IF.2106 @period 周期,m1/m5/d1 @feeder 回调函数,feed_raw_bars(bars:POINTER(WTSBarStruct), count:int) - ''' + """ return False - def load_raw_his_bars(self, stdCode:str, period:str, feeder) -> bool: - ''' + def load_raw_his_bars(self, stdCode: str, period: str, feeder) -> bool: + """ 加载未加工的历史K线(回测、实盘) 该接口一般用于加载原始的K线数据,如未复权数据和分月合约数据 @stdCode 合约代码,格式如CFFEX.IF.2106 @period 周期,m1/m5/d1 @feeder 回调函数,feed_raw_bars(bars:POINTER(WTSBarStruct), count:int) - ''' + """ return False - def load_his_ticks(self, stdCode:str, uDate:int, feeder) -> bool: - ''' + def load_his_ticks(self, stdCode: str, uDate: int, feeder) -> bool: + """ 加载历史K线(只在回测有效,实盘只提供当日落地的) @stdCode 合约代码,格式如CFFEX.IF.2106 @uDate 日期,格式如yyyymmdd @feeder 回调函数,feed_raw_bars(bars:POINTER(WTSTickStruct), count:int) - ''' + """ return False - def load_adj_factors(self, stdCode:str = "", feeder = None) -> bool: - ''' + def load_adj_factors(self, stdCode: str = "", feeder=None) -> bool: + """ 加载的权因子 @stdCode 合约代码,格式如CFFEX.IF.2106,如果stdCode为空,则是加载全部除权数据,如果stdCode不为空,则按需加载 @feeder 回调函数,feed_adj_factors(stdCode:str, dates:list, factors:list) - ''' + """ return False + class BaseExtDataDumper: - def __init__(self, id:str): + def __init__(self, id: str): self.__id__ = id def id(self): return self.__id__ - def dump_his_bars(self, stdCode:str, period:str, bars, count:int) -> bool: - ''' + def dump_his_bars(self, stdCode: str, period: str, bars, count: int) -> bool: + """ 加载历史K线(回测、实盘) @stdCode 合约代码,格式如CFFEX.IF.2106 @period 周期,m1/m5/d1 @bars 回调函数,WTSBarStruct的指针 @count 数据条数 - ''' + """ return True - def dump_his_ticks(self, stdCode:str, uDate:int, ticks, count:int) -> bool: - ''' + def dump_his_ticks(self, stdCode: str, uDate: int, ticks, count: int) -> bool: + """ 加载历史K线(只在回测有效,实盘只提供当日落地的) @stdCode 合约代码,格式如CFFEX.IF.2106 @uDate 日期,格式如yyyymmdd @ticks 回调函数,WTSTickStruct的指针 @count 数据条数 - ''' - return True \ No newline at end of file + """ + return True diff --git a/wtpy/ExtToolDefs.py b/wtpy/ExtToolDefs.py index a58932cb..e5a23037 100644 --- a/wtpy/ExtToolDefs.py +++ b/wtpy/ExtToolDefs.py @@ -2,43 +2,46 @@ import time from threading import Thread + def fileToJson(filename, encoding="utf-8"): f = open(filename, 'r') content = f.read() f.close() try: return json.loads(content) - except: + except Exception as e: + print(e) return None + class BaseIndexWriter: - ''' + """ 基础指标输出工具 - ''' + """ def __init__(self): return - def write_indicator(self, id:str, tag:str, time:int, data:dict): - ''' + def write_indicator(self, id: str, tag: str, time: int, data: dict): + """ 将指标数据出 @id 指标ID @tag 数据标记 @time 指标时间 @data 数据对象, 一个dict - ''' + """ raise Exception("Basic writer cannot output index data to any media") class BaseDataReporter: - ''' + """ 数据报告器 - ''' - TaskReportRTData = 1 - TaskReportSettleData = 2 - TaskReportInitData = 3 + """ + TaskReportRTData = 1 + TaskReportSettleData = 2 + TaskReportInitData = 3 - def __init__(self, id:str): + def __init__(self, id: str): self.__inited__ = False self.__id__ = id return @@ -49,7 +52,7 @@ def init(self): self.__tasks__ = list() self.__stopped__ = False - #读取策略标记 + # 读取策略标记 filename = "./generated/marker.json" obj = fileToJson(filename) if obj is not None: @@ -88,7 +91,7 @@ def __task_loop__(self): time.sleep(1) continue else: - taskid = self.__tasks__.pop(0) + taskid = self.__tasks__.pop(0) if taskid == self.TaskReportRTData: self.__do_report_rt_data__() elif taskid == self.TaskReportSettleData: @@ -127,4 +130,3 @@ def report_init_data(self): self.__tasks__.append(self.TaskReportInitData) if self.__thrd_task__ is None: self.__start__() - \ No newline at end of file diff --git a/wtpy/HftContext.py b/wtpy/HftContext.py index 49ca28bf..d533c174 100644 --- a/wtpy/HftContext.py +++ b/wtpy/HftContext.py @@ -3,27 +3,28 @@ from wtpy.WtCoreDefs import WTSBarStruct, WTSOrdDtlStruct, WTSOrdQueStruct, WTSTickStruct, WTSTransStruct from wtpy.WtDataDefs import WtNpKline, WtNpOrdDetails, WtNpOrdQueues, WtNpTicks, WtNpTransactions + class HftContext: - ''' + """ Context是策略可以直接访问的唯一对象 策略所有的接口都通过Context对象调用 Context类包括以下几类接口: 1、时间接口(日期、时间等),接口格式如:stra_xxx 2、数据接口(K线、财务等),接口格式如:stra_xxx 3、下单接口(设置目标仓位、直接下单等),接口格式如:stra_xxx - ''' - - def __init__(self, id:int, stra, wrapper, engine): - self.__stra_info__ = stra #策略对象,对象基类BaseStrategy.py - self.__wrapper__ = wrapper #底层接口转换器 - self.__id__ = id #策略ID - self.__bar_cache__ = dict() #K线缓存 - self.__tick_cache__ = dict() #Tick缓存 - self.__ordque_cache__ = dict() #委托队列缓存 - self.__orddtl_cache__ = dict() #逐笔委托缓存 - self.__trans_cache__ = dict() #逐笔成交缓存 - self.__sname__ = stra.name() - self.__engine__ = engine #交易环境 + """ + + def __init__(self, id: int, stra, wrapper, engine): + self.__stra_info__ = stra # 策略对象,对象基类BaseStrategy.py + self.__wrapper__ = wrapper # 底层接口转换器 + self.__id__ = id # 策略ID + self.__bar_cache__ = dict() # K线缓存 + self.__tick_cache__ = dict() # Tick缓存 + self.__ordque_cache__ = dict() # 委托队列缓存 + self.__orddtl_cache__ = dict() # 逐笔委托缓存 + self.__trans_cache__ = dict() # 逐笔成交缓存 + self.__sname__ = stra.name() + self.__engine__ = engine # 交易环境 self.is_backtest = self.__engine__.is_backtest @@ -32,80 +33,80 @@ def id(self): return self.__id__ def on_init(self): - ''' + """ 初始化,一般用于系统启动的时候 - ''' + """ self.__stra_info__.on_init(self) - def on_session_begin(self, curTDate:int): - ''' + def on_session_begin(self, curTDate: int): + """ 交易日开始事件 @curTDate 交易日,格式为20210220 - ''' + """ self.__stra_info__.on_session_begin(self, curTDate) - def on_session_end(self, curTDate:int): - ''' + def on_session_end(self, curTDate: int): + """ 交易日结束事件 @curTDate 交易日,格式为20210220 - ''' + """ self.__stra_info__.on_session_end(self, curTDate) def on_backtest_end(self): - ''' + """ 回测结束事件 - ''' + """ self.__stra_info__.on_backtest_end(self) - def on_getticks(self, stdCode:str, newTicks:WtNpTicks): + def on_getticks(self, stdCode: str, newTicks: WtNpTicks): self.__tick_cache__[stdCode] = newTicks - def on_getbars(self, stdCode:str, period:str, npBars:WtNpKline): + def on_getbars(self, stdCode: str, period: str, npBars: WtNpKline): key = "%s#%s" % (stdCode, period) self.__bar_cache__[key] = npBars - def on_tick(self, stdCode:str, newTick:POINTER(WTSTickStruct)): - ''' + def on_tick(self, stdCode: str, newTick: POINTER(WTSTickStruct)): + """ tick回调事件响应 - ''' + """ self.__stra_info__.on_tick(self, stdCode, newTick.contents.to_dict()) - def on_order_queue(self, stdCode:str, newOrdQue:POINTER(WTSOrdQueStruct)): - ''' + def on_order_queue(self, stdCode: str, newOrdQue: POINTER(WTSOrdQueStruct)): + """ 委托队列回调事件响应 - ''' + """ self.__stra_info__.on_order_queue(self, stdCode, newOrdQue.contents.to_dict()) - def on_get_order_queue(self, stdCode:str, newOdrQues:WtNpOrdQueues): - ''' + def on_get_order_queue(self, stdCode: str, newOdrQues: WtNpOrdQueues): + """ 委托队列数据获取事件响应 - ''' + """ self.__ordque_cache__[stdCode] = newOdrQues - def on_order_detail(self, stdCode:str, newOrdDtl:POINTER(WTSOrdDtlStruct)): - ''' + def on_order_detail(self, stdCode: str, newOrdDtl: POINTER(WTSOrdDtlStruct)): + """ 逐笔委托回调事件响应 - ''' + """ self.__stra_info__.on_order_detail(self, stdCode, newOrdDtl.contents.to_dict()) - def on_get_order_detail(self, stdCode:str, newOrdDtls:WtNpOrdDetails): - ''' + def on_get_order_detail(self, stdCode: str, newOrdDtls: WtNpOrdDetails): + """ 逐笔委托数据获取事件响应 - ''' + """ self.__orddtl_cache__[stdCode] = newOrdDtls - def on_transaction(self, stdCode:str, newTrans:POINTER(WTSTransStruct)): - ''' + def on_transaction(self, stdCode: str, newTrans: POINTER(WTSTransStruct)): + """ 逐笔成交回调事件响应 - ''' + """ self.__stra_info__.on_transaction(self, stdCode, newTrans.contents.to_dict()) - def on_get_transaction(self, stdCode:str, newTranses:WtNpTransactions): - ''' + def on_get_transaction(self, stdCode: str, newTranses: WtNpTransactions): + """ 逐笔成交数据获取事件响应 - ''' + """ key = stdCode self.__trans_cache__[key] = newTranses @@ -115,31 +116,32 @@ def on_channel_ready(self): def on_channel_lost(self): self.__stra_info__.on_channel_lost(self) - def on_entrust(self, localid:int, stdCode:str, bSucc:bool, msg:str, userTag:str): + def on_entrust(self, localid: int, stdCode: str, bSucc: bool, msg: str, userTag: str): self.__stra_info__.on_entrust(self, localid, stdCode, bSucc, msg, userTag) - def on_position(self, stdCode:str, isLong:bool, prevol:float, preavail:float, newvol:float, newavail:float): + def on_position(self, stdCode: str, isLong: bool, prevol: float, preavail: float, newvol: float, newavail: float): self.__stra_info__.on_position(self, stdCode, isLong, prevol, preavail, newvol, newavail) - def on_order(self, localid:int, stdCode:str, isBuy:bool, totalQty:float, leftQty:float, price:float, isCanceled:bool, userTag:str): + def on_order(self, localid: int, stdCode: str, isBuy: bool, totalQty: float, leftQty: float, price: float, + isCanceled: bool, userTag: str): self.__stra_info__.on_order(self, localid, stdCode, isBuy, totalQty, leftQty, price, isCanceled, userTag) - def on_trade(self, localid:int, stdCode:str, isBuy:bool, qty:float, price:float, userTag:str): + def on_trade(self, localid: int, stdCode: str, isBuy: bool, qty: float, price: float, userTag: str): self.__stra_info__.on_trade(self, localid, stdCode, isBuy, qty, price, userTag) - - def on_bar(self, stdCode:str, period:str, newBar:POINTER(WTSBarStruct)): - ''' + + def on_bar(self, stdCode: str, period: str, newBar: POINTER(WTSBarStruct)): + """ K线闭合事件响应 @stdCode 品种代码 @period K线基础周期 @times 周期倍数 @newBar 最新K线 - ''' + """ key = "%s#%s" % (stdCode, period) if key not in self.__bar_cache__: return - + try: self.__stra_info__.on_bar(self, stdCode, period, newBar.contents.to_dict()) except ValueError as ve: @@ -147,248 +149,247 @@ def on_bar(self, stdCode:str, period:str, newBar:POINTER(WTSBarStruct)): else: return - def stra_log_text(self, message:str, level:int = 1): - ''' + def stra_log_text(self, message: str, level: int = 1): + """ 输出日志 @level 日志级别,0-debug,1-info,2-warn,3-error @message 消息内容,最大242字符 - ''' + """ self.__wrapper__.hft_log_text(self.__id__, level, message[:242]) - + def stra_get_date(self): - ''' + """ 获取当前日期 @return int,格式如20180513 - ''' + """ return self.__wrapper__.hft_get_date() def stra_get_time(self): - ''' + """ 获取当前时间,24小时制,精确到分 @return int,格式如1231 - ''' + """ return self.__wrapper__.hft_get_time() def stra_get_secs(self): - ''' + """ 获取当前秒数,精确到毫秒 @return int,格式如1231 - ''' + """ return self.__wrapper__.hft_get_secs() def stra_get_price(self, stdCode): - ''' + """ 获取最新价格,一般在获取了K线以后再获取该价格 @return 最新价格 - ''' + """ return self.__wrapper__.hft_get_price(stdCode) - - def stra_prepare_bars(self, stdCode:str, period:str, count:int): - ''' + + def stra_prepare_bars(self, stdCode: str, period: str, count: int): + """ 准备历史K线 一般在on_init调用 @stdCode 合约代码 @period K线周期, 如m3/d7 @count 要拉取的K线条数 @isMain 是否是主K线 - ''' + """ self.__wrapper__.hft_get_bars(self.__id__, stdCode, period, count) - def stra_get_bars(self, stdCode:str, period:str, count:int) -> WtNpKline: - ''' + def stra_get_bars(self, stdCode: str, period: str, count: int) -> WtNpKline: + """ 获取历史K线 @stdCode 合约代码 @period K线周期,如m3/d7 @count 要拉取的K线条数 @isMain 是否是主K线 - ''' + """ key = "%s#%s" % (stdCode, period) - cnt = self.__wrapper__.hft_get_bars(self.__id__, stdCode, period, count) + cnt = self.__wrapper__.hft_get_bars(self.__id__, stdCode, period, count) if cnt == 0: return None return self.__bar_cache__[key] - def stra_get_ticks(self, stdCode:str, count:int) -> WtNpTicks: - ''' + def stra_get_ticks(self, stdCode: str, count: int) -> WtNpTicks: + """ 获取tick数据 @stdCode 合约代码 @count 要拉取的tick数量 - ''' + """ cnt = self.__wrapper__.hft_get_ticks(self.__id__, stdCode, count) if cnt == 0: return None - + return self.__tick_cache__[stdCode] - def stra_get_order_queue(self, stdCode:str, count:int) -> WtNpOrdQueues: - ''' + def stra_get_order_queue(self, stdCode: str, count: int) -> WtNpOrdQueues: + """ 获取委托队列数据 @stdCode 合约代码 @count 要拉取的tick数量 - ''' + """ cnt = self.__wrapper__.hft_get_ordque(self.__id__, stdCode, count) if cnt == 0: return None - - return self.__ordque_cache__[stdCode] - def stra_get_order_detail(self, stdCode:str, count:int) -> WtNpOrdDetails: - ''' + return self.__ordque_cache__[stdCode] + + def stra_get_order_detail(self, stdCode: str, count: int) -> WtNpOrdDetails: + """ 获取逐笔委托数据 @stdCode 合约代码 @count 要拉取的tick数量 - ''' + """ if stdCode in self.__orddtl_cache__: - #这里做一个数据长度处理 + # 这里做一个数据长度处理 return self.__orddtl_cache__[stdCode] cnt = self.__wrapper__.hft_get_orddtl(self.__id__, stdCode, count) if cnt == 0: return None - + return self.__orddtl_cache__[stdCode] - def stra_get_transaction(self, stdCode:str, count:int) -> WtNpTransactions: - ''' + def stra_get_transaction(self, stdCode: str, count: int) -> WtNpTransactions: + """ 获取逐笔成交数据 @stdCode 合约代码 @count 要拉取的tick数量 - ''' + """ cnt = self.__wrapper__.hft_get_trans(self.__id__, stdCode, count) if cnt == 0: return None - + return self.__trans_cache__[stdCode] - def stra_get_position(self, stdCode:str, bonlyvalid:bool = False): - ''' + def stra_get_position(self, stdCode: str, bonlyvalid: bool = False): + """ 读取当前仓位 @stdCode 合约/股票代码 @return 正为多仓,负为空仓 - ''' + """ return self.__wrapper__.hft_get_position(self.__id__, stdCode, bonlyvalid) - def stra_get_position_profit(self, stdCode:str = ""): - ''' + def stra_get_position_profit(self, stdCode: str = ""): + """ 读取指定持仓的浮动盈亏 @stdCode 合约/股票代码 @return 指定持仓的浮动盈亏 - ''' + """ return self.__wrapper__.hft_get_position_profit(self.__id__, stdCode) - def stra_get_position_avgpx(self, stdCode:str = ""): - ''' + def stra_get_position_avgpx(self, stdCode: str = ""): + """ 读取指定持仓的持仓均价 @stdCode 合约/股票代码 @return 指定持仓的浮动盈亏 - ''' + """ return self.__wrapper__.hft_get_position_avgpx(self.__id__, stdCode) - def stra_get_undone(self, stdCode:str): + def stra_get_undone(self, stdCode: str): return self.__wrapper__.hft_get_undone(self.__id__, stdCode) - - def user_save_data(self, key:str, val): - ''' + def user_save_data(self, key: str, val): + """ 保存用户数据 @key 数据id @val 数据值,可以直接转换成str的数据均可 - ''' + """ self.__wrapper__.hft_save_user_data(self.__id__, key, str(val)) - def user_load_data(self, key:str, defVal = None, vType = float): - ''' + def user_load_data(self, key: str, defVal=None, vType=float): + """ 读取用户数据 @key 数据id @defVal 默认数据,如果找不到则返回改数据,默认为None @return 返回值,默认处理为float数据 - ''' + """ ret = self.__wrapper__.hft_load_user_data(self.__id__, key, "") if ret == "": return defVal return vType(ret) - def stra_get_rawcode(self, stdCode:str): - ''' + def stra_get_rawcode(self, stdCode: str): + """ 获取分月合约代码 @stdCode 连续合约代码如SHFE.ag.HOT @return 品种信息,结构请参考ProductMgr中的ProductInfo - ''' + """ if self.__engine__ is None: return "" return self.__engine__.getRawStdCode(stdCode) - def stra_get_comminfo(self, stdCode:str): - ''' + def stra_get_comminfo(self, stdCode: str): + """ 获取品种详情 @stdCode 合约代码如SHFE.ag.HOT,或者品种代码如SHFE.ag @return 品种信息,结构请参考ProductMgr中的ProductInfo - ''' + """ if self.__engine__ is None: return None return self.__engine__.getProductInfo(stdCode) - - def stra_get_sessinfo(self, stdCode:str) -> SessionInfo: - ''' + + def stra_get_sessinfo(self, stdCode: str) -> SessionInfo: + """ 获取交易时段详情 @stdCode 合约代码如SHFE.ag.HOT,或者品种代码如SHFE.ag @return 品种信息,结构请参考SessionMgr中的SessionInfo - ''' + """ if self.__engine__ is None: return None return self.__engine__.getSessionByCode(stdCode) - def stra_sub_ticks(self, stdCode:str): - ''' + def stra_sub_ticks(self, stdCode: str): + """ 订阅实时行情数据 获取K线和tick数据的时候会自动订阅,这里只需要订阅额外要检测的品种即可 @stdCode 品种代码 - ''' + """ self.__wrapper__.hft_sub_ticks(self.__id__, stdCode) - def stra_sub_order_queue(self, stdCode:str): - ''' + def stra_sub_order_queue(self, stdCode: str): + """ 订阅实时委托队列数据 @id 策略ID @stdCode 品种代码 - ''' + """ self.__wrapper__.hft_sub_order_queue(self.__id__, stdCode) - - def stra_sub_order_detail(self, stdCode:str): - ''' + + def stra_sub_order_detail(self, stdCode: str): + """ 订阅逐笔委托数据 @id 策略ID @stdCode 品种代码 - ''' + """ self.__wrapper__.hft_sub_order_detail(self.__id__, stdCode) - - def stra_sub_transaction(self, stdCode:str): - ''' + + def stra_sub_transaction(self, stdCode: str): + """ 订阅逐笔成交数据 @id 策略ID @stdCode 品种代码 - ''' + """ self.__wrapper__.hft_sub_transaction(self.__id__, stdCode) - def stra_cancel(self, localid:int): - ''' + def stra_cancel(self, localid: int): + """ 撤销指定订单 @id 策略ID @localid 下单时返回的本地订单号 - ''' + """ return self.__wrapper__.hft_cancel(self.__id__, localid) - def stra_cancel_all(self, stdCode:str, isBuy:bool): - ''' + def stra_cancel_all(self, stdCode: str, isBuy: bool): + """ 撤销指定品种的全部买入订单or卖出订单 @id 策略ID @stdCode 品种代码 @isBuy 买入or卖出 - ''' + """ idstr = self.__wrapper__.hft_cancel_all(self.__id__, stdCode, isBuy) if len(idstr) == 0: return list() @@ -399,66 +400,66 @@ def stra_cancel_all(self, stdCode:str, isBuy:bool): localids.append(int(localid)) return localids - def stra_buy(self, stdCode:str, price:float, qty:float, userTag:str = "", flag:int = 0): - ''' + def stra_buy(self, stdCode: str, price: float, qty: float, userTag: str = "", flag: int = 0): + """ 买入指令 @id 策略ID @stdCode 品种代码 @price 买入价格, 0为市价 @qty 买入数量 @flag 下单标志, 0-normal, 1-fak, 2-fok - ''' + """ idstr = self.__wrapper__.hft_buy(self.__id__, stdCode, price, qty, userTag, flag) if len(idstr) == 0: return list() - + ids = idstr.split(",") localids = list() for localid in ids: localids.append(int(localid)) return localids - def stra_sell(self, stdCode:str, price:float, qty:float, userTag:str = "", flag:int = 0): - ''' + def stra_sell(self, stdCode: str, price: float, qty: float, userTag: str = "", flag: int = 0): + """ 卖出指令 @id 策略ID @stdCode 品种代码 @price 卖出价格, 0为市价 @qty 卖出数量 @flag 下单标志, 0-normal, 1-fak, 2-fok - ''' + """ idstr = self.__wrapper__.hft_sell(self.__id__, stdCode, price, qty, userTag, flag) if len(idstr) == 0: return list() - + ids = idstr.split(",") localids = list() for localid in ids: localids.append(int(localid)) return localids - + def stra_get_all_codes(self) -> list: - ''' + """ 获取全部合约代码列表 - ''' + """ if self.__engine__ is None: return [] return self.__engine__.getAllCodes() - - def stra_get_codes_by_product(self, stdPID:str) -> list: - ''' + + def stra_get_codes_by_product(self, stdPID: str) -> list: + """ 根据品种代码读取合约列表 @stdPID 品种代码,格式如SHFE.rb - ''' + """ if self.__engine__ is None: return [] return self.__engine__.getCodesByProduct(stdPID) - - def stra_get_codes_by_underlying(self, underlying:str) -> list: - ''' + + def stra_get_codes_by_underlying(self, underlying: str) -> list: + """ 根据underlying读取合约列表 @underlying 格式如CFFEX.IM2304 - ''' + """ if self.__engine__ is None: return [] return self.__engine__.getCodesByUnderlying(underlying) diff --git a/wtpy/ProductMgr.py b/wtpy/ProductMgr.py index 28a28f31..c850f8b8 100644 --- a/wtpy/ProductMgr.py +++ b/wtpy/ProductMgr.py @@ -3,33 +3,36 @@ import os import chardet + class ProductInfo: - ''' + """ 品种信息 - ''' + """ def __init__(self): - self.exchg = '' #交易所 - self.product = '' #品种代码 - self.name = '' #品种名称 - self.session = '' #交易时段名 - self.pricetick = 0 #价格变动单位 - self.volscale = 1 #数量乘数 - self.minlots = 1 #最小交易数量 - self.lotstick = 1 #交易数量变动单位 + self.exchg = '' # 交易所 + self.product = '' # 品种代码 + self.name = '' # 品种名称 + self.session = '' # 交易时段名 + self.pricetick = 0 # 价格变动单位 + self.volscale = 1 # 数量乘数 + self.minlots = 1 # 最小交易数量 + self.lotstick = 1 # 交易数量变动单位 + class ProductMgr: - ''' + """ 品种信息管理器 - ''' + """ + def __init__(self): self.__products__ = dict() return - def load(self, fname:str): - ''' + def load(self, fname: str): + """ 从文件加载品种信息 - ''' + """ if not os.path.exists(fname): return f = open(fname, 'rb') @@ -63,15 +66,15 @@ def load(self, fname:str): key = "%s.%s" % (exchg, pid) self.__products__[key] = pInfo - - def addProductInfo(self, key:str, pInfo:ProductInfo): + + def addProductInfo(self, key: str, pInfo: ProductInfo): self.__products__[key] = pInfo - def getProductInfo(self, pid:str) -> ProductInfo: - #pid形式可能为SHFE.ag.HOT,或者SHFE.ag.1912,或者SHFE.ag + def getProductInfo(self, pid: str) -> ProductInfo: + # pid形式可能为SHFE.ag.HOT,或者SHFE.ag.1912,或者SHFE.ag items = pid.split(".") key = items[0] + "." + items[1] if key not in self.__products__: return None - return self.__products__[key] \ No newline at end of file + return self.__products__[key] diff --git a/wtpy/SelContext.py b/wtpy/SelContext.py index 5599e5b3..6b403fb8 100644 --- a/wtpy/SelContext.py +++ b/wtpy/SelContext.py @@ -2,38 +2,39 @@ from wtpy.WtCoreDefs import WTSBarStruct, WTSTickStruct from wtpy.WtDataDefs import WtNpKline, WtNpTicks + class SelContext: - ''' + """ Context是策略可以直接访问的唯一对象 策略所有的接口都通过Context对象调用 Context类包括以下几类接口: 1、时间接口(日期、时间等),接口格式如:stra_xxx 2、数据接口(K线、财务等),接口格式如:stra_xxx 3、下单接口(设置目标仓位、直接下单等),接口格式如:stra_xxx - ''' - - def __init__(self, id:int, stra, wrapper, engine): - self.__stra_info__ = stra #策略对象,对象基类BaseStrategy.py - self.__wrapper__ = wrapper #底层接口转换器 - self.__id__ = id #策略ID - self.__bar_cache__ = dict() #K线缓存 - self.__tick_cache__ = dict() #tTick缓存,每次都重新去拉取,这个只做中转用,不在python里维护副本 - self.__sname__ = stra.name() - self.__engine__ = engine #交易环境 + """ + + def __init__(self, id: int, stra, wrapper, engine): + self.__stra_info__ = stra # 策略对象,对象基类BaseStrategy.py + self.__wrapper__ = wrapper # 底层接口转换器 + self.__id__ = id # 策略ID + self.__bar_cache__ = dict() # K线缓存 + self.__tick_cache__ = dict() # tTick缓存,每次都重新去拉取,这个只做中转用,不在python里维护副本 + self.__sname__ = stra.name() + self.__engine__ = engine # 交易环境 self.__pos_cache__ = None self.is_backtest = self.__engine__.is_backtest self.__alias__() - + @property def id(self): return self.__id__ - + def __alias__(self): - ''' + """ 接口函数别名 - ''' + """ self.get_all_position = self.stra_get_all_position self.get_bars = self.stra_get_bars self.get_comminfo = self.stra_get_comminfo @@ -63,73 +64,73 @@ def __alias__(self): pass def write_indicator(self, tag, time, data): - ''' + """ 输出指标数据 @tag 指标标签 @time 输出时间 @data 输出的指标数据,dict类型,会转成json以后保存 - ''' + """ self.__engine__.write_indicator(self.__stra_info__.name(), tag, time, data) def on_init(self): - ''' + """ 初始化,一般用于系统启动的时候 - ''' + """ self.__stra_info__.on_init(self) - def on_session_begin(self, curTDate:int): - ''' + def on_session_begin(self, curTDate: int): + """ 交易日开始事件 @curTDate 交易日,格式为20210220 - ''' + """ self.__stra_info__.on_session_begin(self, curTDate) - def on_session_end(self, curTDate:int): - ''' + def on_session_end(self, curTDate: int): + """ 交易日结束事件 @curTDate 交易日,格式为20210220 - ''' + """ self.__stra_info__.on_session_end(self, curTDate) def on_backtest_end(self): - ''' + """ 回测结束事件 - ''' + """ self.__stra_info__.on_backtest_end(self) - def on_getticks(self, stdCode:str, newTicks:WtNpTicks): + def on_getticks(self, stdCode: str, newTicks: WtNpTicks): key = stdCode self.__tick_cache__[key] = newTicks - def on_getpositions(self, stdCode:str, qty:float, frozen:float): + def on_getpositions(self, stdCode: str, qty: float, frozen: float): if len(stdCode) == 0: return self.__pos_cache__[stdCode] = qty - def on_getbars(self, stdCode:str, period:str, npBars:WtNpKline): + def on_getbars(self, stdCode: str, period: str, npBars: WtNpKline): key = "%s#%s" % (stdCode, period) self.__bar_cache__[key] = npBars - def on_tick(self, stdCode:str, newTick:POINTER(WTSTickStruct)): + def on_tick(self, stdCode: str, newTick: POINTER(WTSTickStruct)): self.__stra_info__.on_tick(self, stdCode, newTick.contents.to_dict()) - def on_bar(self, stdCode:str, period:str, newBar:POINTER(WTSBarStruct)): - ''' + def on_bar(self, stdCode: str, period: str, newBar: POINTER(WTSBarStruct)): + """ K线闭合事件响应 @stdCode 品种代码 @period K线基础周期 @times 周期倍数 @newBar 最新K线 - ''' + """ key = "%s#%s" % (stdCode, period) if key not in self.__bar_cache__: return - + try: self.__stra_info__.on_bar(self, stdCode, period, newBar.contents.to_dict()) except ValueError as ve: @@ -143,296 +144,296 @@ def on_calculate(self): def on_calculate_done(self): self.__stra_info__.on_calculate_done(self) - def stra_log_text(self, message:str, level:int = 1): - ''' + def stra_log_text(self, message: str, level: int = 1): + """ 输出日志 @level 日志级别,0-debug,1-info,2-warn,3-error @message 消息内容,最大242字符 - ''' + """ self.__wrapper__.sel_log_text(self.__id__, level, message[:242]) - + def stra_get_tdate(self) -> int: - ''' + """ 获取当前交易日 @return int, 格式如20180513 - ''' + """ return self.__wrapper__.sel_get_tdate() - + def stra_get_date(self): - ''' + """ 获取当前日期 @return int,格式如20180513 - ''' + """ return self.__wrapper__.sel_get_date() - - def stra_get_position_avgpx(self, stdCode:str = "") -> float: - ''' + + def stra_get_position_avgpx(self, stdCode: str = "") -> float: + """ 获取当前持仓均价 @stdCode 合约代码 @return 持仓均价 - ''' + """ return self.__wrapper__.sel_get_position_avgpx(self.__id__, stdCode) - def stra_get_position_profit(self, stdCode:str = "") -> float: - ''' + def stra_get_position_profit(self, stdCode: str = "") -> float: + """ 获取持仓浮动盈亏 @stdCode 合约代码, 为None时读取全部品种的浮动盈亏 @return 浮动盈亏 - ''' + """ return self.__wrapper__.sel_get_position_profit(self.__id__, stdCode) - def stra_get_fund_data(self, flag:int = 0) -> float: - ''' + def stra_get_fund_data(self, flag: int = 0) -> float: + """ 获取资金数据 @flag 0-动态权益, 1-总平仓盈亏, 2-总浮动盈亏, 3-总手续费 @return 资金数据 - ''' + """ return self.__wrapper__.sel_get_fund_data(self.__id__, flag) def stra_get_time(self): - ''' + """ 获取当前时间,24小时制,精确到分 @return int,格式如1231 - ''' + """ return self.__wrapper__.sel_get_time() def stra_get_price(self, stdCode): - ''' + """ 获取最新价格,一般在获取了K线以后再获取该价格 @return 最新价格 - ''' + """ return self.__wrapper__.sel_get_price(stdCode) - - def stra_get_day_price(self, stdCode:str, flag:int = 0) -> float: - ''' + + def stra_get_day_price(self, stdCode: str, flag: int = 0) -> float: + """ 获取当日价格 @flag 价格标记, 0-开盘价, 1-最高价, 2-最低价, 3-最新价 @return 最新价格 - ''' + """ return self.__wrapper__.sel_get_day_price(stdCode, flag) def stra_get_all_position(self): - ''' + """ 获取全部持仓 - ''' - self.__pos_cache__ = dict() # + """ + self.__pos_cache__ = dict() # self.__wrapper__.sel_get_all_position(self.__id__) return self.__pos_cache__ - - def stra_prepare_bars(self, stdCode:str, period:str, count:int): - ''' + + def stra_prepare_bars(self, stdCode: str, period: str, count: int): + """ 准备历史K线 一般在on_init调用 @stdCode 合约代码 @period K线周期, 如m3/d7 @count 要拉取的K线条数 @isMain 是否是主K线 - ''' + """ self.__wrapper__.sel_get_bars(self.__id__, stdCode, period, count) - def stra_get_bars(self, stdCode:str, period:str, count:int) -> WtNpKline: - ''' + def stra_get_bars(self, stdCode: str, period: str, count: int) -> WtNpKline: + """ 获取历史K线 @stdCode 合约代码 @period K线周期, 如m3/d7 @count 要拉取的K线条数 @isMain 是否是主K线 - ''' + """ key = "%s#%s" % (stdCode, period) # 每次都重新构造,不然onbar处理会更麻烦 - cnt = self.__wrapper__.sel_get_bars(self.__id__, stdCode, period, count) + cnt = self.__wrapper__.sel_get_bars(self.__id__, stdCode, period, count) if cnt == 0: return None npBars = self.__bar_cache__[key] return npBars - - def stra_get_ticks(self, stdCode:str, count:int) -> WtNpTicks: - ''' + + def stra_get_ticks(self, stdCode: str, count: int) -> WtNpTicks: + """ 获取tick数据 @stdCode 合约代码 @count 要拉取的tick数量 - ''' + """ self.__tick_cache__[stdCode] = WtNpTicks() cnt = self.__wrapper__.sel_get_ticks(self.__id__, stdCode, count) if cnt == 0: return None - + np_ticks = self.__tick_cache__[stdCode] return np_ticks - def stra_sub_ticks(self, stdCode:str): - ''' + def stra_sub_ticks(self, stdCode: str): + """ 订阅实时行情 @stdCode 合约代码 - ''' + """ self.__wrapper__.sel_sub_ticks(stdCode) - def stra_get_position(self, stdCode:str, bonlyvalid:bool = False, usertag:str = "") -> float: - ''' + def stra_get_position(self, stdCode: str, bonlyvalid: bool = False, usertag: str = "") -> float: + """ 读取当前仓位 @stdCode 合约/股票代码 @usertag 入场标记 @return 正为多仓,负为空仓 - ''' + """ return self.__wrapper__.sel_get_position(self.__id__, stdCode, bonlyvalid, usertag) - def stra_set_position(self, stdCode:str, qty:float, usertag:str = ""): - ''' + def stra_set_position(self, stdCode: str, qty: float, usertag: str = ""): + """ 设置仓位 @stdCode 合约/股票代码 @qty 目标仓位,正为多仓,负为空仓 @return 设置结果TRUE/FALSE - ''' + """ self.__wrapper__.sel_set_position(self.__id__, stdCode, qty, usertag) - def stra_get_last_entrytime(self, stdCode:str) -> int: - ''' + def stra_get_last_entrytime(self, stdCode: str) -> int: + """ 获取当前持仓最后一次进场时间 @stdCode 品种代码 @return 返回最后一次开仓的时间, 格式如201903121047 - ''' + """ return self.__wrapper__.sel_get_last_entertime(self.__id__, stdCode) - def stra_get_last_entrytag(self, stdCode:str) -> str: - ''' + def stra_get_last_entrytag(self, stdCode: str) -> str: + """ 获取当前持仓最后一次进场标记 @stdCode 品种代码 @return 返回最后一次开仓标记 - ''' + """ return self.__wrapper__.sel_get_last_entertag(self.__id__, stdCode) - def stra_get_last_exittime(self, stdCode:str) -> int: - ''' + def stra_get_last_exittime(self, stdCode: str) -> int: + """ 获取当前持仓最后一次出场时间 @stdCode 品种代码 @return 返回最后一次开仓的时间, 格式如201903121047 - ''' + """ return self.__wrapper__.sel_get_last_exittime(self.__id__, stdCode) - def stra_get_first_entrytime(self, stdCode:str) -> int: - ''' + def stra_get_first_entrytime(self, stdCode: str) -> int: + """ 获取当前持仓第一次进场时间 @stdCode 品种代码 @return 返回最后一次开仓的时间, 格式如201903121047 - ''' + """ return self.__wrapper__.sel_get_first_entertime(self.__id__, stdCode) - - def user_save_data(self, key:str, val): - ''' + + def user_save_data(self, key: str, val): + """ 保存用户数据 @key 数据id @val 数据值,可以直接转换成str的数据均可 - ''' + """ self.__wrapper__.sel_save_user_data(self.__id__, key, str(val)) - def user_load_data(self, key:str, defVal = None, vType = float): - ''' + def user_load_data(self, key: str, defVal=None, vType=float): + """ 读取用户数据 @key 数据id @defVal 默认数据,如果找不到则返回改数据,默认为None @return 返回值,默认处理为float数据 - ''' + """ ret = self.__wrapper__.sel_load_user_data(self.__id__, key, "") if ret == "": return defVal return vType(ret) - - def stra_get_detail_profit(self, stdCode:str, usertag:str, flag:int = 0) -> float: - ''' + + def stra_get_detail_profit(self, stdCode: str, usertag: str, flag: int = 0) -> float: + """ 获取指定标记的持仓的盈亏 @stdCode 合约代码 @usertag 进场标记 @flag 盈亏记号, 0-浮动盈亏, 1-最大浮盈, -1-最大亏损(负数), 2-最高浮动价格, -2-最低浮动价格 @return 盈亏 - ''' + """ return self.__wrapper__.sel_get_detail_profit(self.__id__, stdCode, usertag, flag) - def stra_get_detail_cost(self, stdCode:str, usertag:str) -> float: - ''' + def stra_get_detail_cost(self, stdCode: str, usertag: str) -> float: + """ 获取指定标记的持仓的开仓价 @stdCode 合约代码 @usertag 进场标记 @return 开仓价 - ''' + """ return self.__wrapper__.sel_get_detail_cost(self.__id__, stdCode, usertag) - def stra_get_detail_entertime(self, stdCode:str, usertag:str) -> int: - ''' + def stra_get_detail_entertime(self, stdCode: str, usertag: str) -> int: + """ 获取指定标记的持仓的进场时间 @stdCode 合约代码 @usertag 进场标记 @return 进场时间, 格式如201907260932 - ''' + """ return self.__wrapper__.sel_get_detail_entertime(self.__id__, stdCode, usertag) - - def stra_get_comminfo(self, stdCode:str): - ''' + + def stra_get_comminfo(self, stdCode: str): + """ 获取品种详情 @stdCode 合约代码如SHFE.ag.HOT,或者品种代码如SHFE.ag @return 品种信息,结构请参考ProductMgr中的ProductInfo - ''' + """ if self.__engine__ is None: return None return self.__engine__.getProductInfo(stdCode) - def stra_get_rawcode(self, stdCode:str): - ''' + def stra_get_rawcode(self, stdCode: str): + """ 获取分月合约代码 @stdCode 连续合约代码如SHFE.ag.HOT @return 品种信息,结构请参考ProductMgr中的ProductInfo - ''' + """ if self.__engine__ is None: return "" return self.__engine__.getRawStdCode(stdCode) - def stra_get_sessioninfo(self, stdCode:str): - ''' + def stra_get_sessioninfo(self, stdCode: str): + """ 获取品种详情 @stdCode 合约代码如SHFE.ag.HOT,或者品种代码如SHFE.ag @return 品种信息,结构请参考ProductMgr中的ProductInfo - ''' + """ if self.__engine__ is None: return None return self.__engine__.getSessionByCode(stdCode) - def stra_get_contract(self, stdCode:str): - ''' + def stra_get_contract(self, stdCode: str): + """ 获取品种详情 @stdCode 合约代码如SHFE.ag.HOT,或者品种代码如SHFE.ag @return 品种信息,结构请参考ProductMgr中的ProductInfo - ''' + """ if self.__engine__ is None: return None return self.__engine__.getContractInfo(stdCode) def stra_get_all_codes(self) -> list: - ''' + """ 获取全部合约代码列表 - ''' + """ if self.__engine__ is None: return [] return self.__engine__.getAllCodes() - - def stra_get_codes_by_product(self, stdPID:str) -> list: - ''' + + def stra_get_codes_by_product(self, stdPID: str) -> list: + """ 根据品种代码读取合约列表 @stdPID 品种代码,格式如SHFE.rb - ''' + """ if self.__engine__ is None: return [] return self.__engine__.getCodesByProduct(stdPID) - - def stra_get_codes_by_underlying(self, underlying:str) -> list: - ''' + + def stra_get_codes_by_underlying(self, underlying: str) -> list: + """ 根据underlying读取合约列表 @underlying 格式如CFFEX.IM2304 - ''' + """ if self.__engine__ is None: return [] return self.__engine__.getCodesByUnderlying(underlying) diff --git a/wtpy/SessionMgr.py b/wtpy/SessionMgr.py index 8129fdd4..08893522 100644 --- a/wtpy/SessionMgr.py +++ b/wtpy/SessionMgr.py @@ -3,12 +3,14 @@ import yaml import chardet + class SectionInfo: def __init__(self): self.stime = 0 self.etime = 0 + class SessionInfo: def __init__(self): @@ -20,9 +22,9 @@ def __init__(self): self.totalMins = 0 def toString(self): - ''' + """ 将SessionInfo转换成json字符串 - ''' + """ obj = dict() obj["name"] = self.name obj["offset"] = self.offset @@ -40,28 +42,27 @@ def toString(self): return json.dumps(obj, ensure_ascii=True) - - def offsetTime(self, rawTime:int): - curMinute = math.floor(rawTime/100)*60 + rawTime%100 + def offsetTime(self, rawTime: int): + curMinute = math.floor(rawTime / 100) * 60 + rawTime % 100 curMinute += self.offset if curMinute >= 1440: curMinute -= 1440 elif curMinute < 0: curMinute += 1440 - - return math.floor(curMinute/60)*100 + curMinute%60 - def originalTime(self, offTime:int): - curMinute = math.floor(offTime/100)*60 + offTime%100 + return math.floor(curMinute / 60) * 100 + curMinute % 60 + + def originalTime(self, offTime: int): + curMinute = math.floor(offTime / 100) * 60 + offTime % 100 curMinute -= self.offset if curMinute >= 1440: curMinute -= 1440 elif curMinute < 0: curMinute += 1440 - - return math.floor(curMinute/60)*100 + curMinute%60 - def getOpenTime(self, bOffset:bool = False): + return math.floor(curMinute / 60) * 100 + curMinute % 60 + + def getOpenTime(self, bOffset: bool = False): if len(self.sections) == 0: return 0 @@ -71,7 +72,7 @@ def getOpenTime(self, bOffset:bool = False): else: return opentm - def getCloseTime(self, bOffset:bool = False): + def getCloseTime(self, bOffset: bool = False): if len(self.sections) == 0: return 0 @@ -84,58 +85,58 @@ def getCloseTime(self, bOffset:bool = False): def getTradingMins(self): if len(self.sections) == 0: return 0 - + if self.totalMins == 0: for sec in self.sections: s = sec.stime e = sec.etime h = math.floor(e / 100) - math.floor(s / 100) - m = (e%100) - (s%100) - self.totalMins += (h*60 + m) + m = (e % 100) - (s % 100) + self.totalMins += (h * 60 + m) return self.totalMins def getTradingSecs(self): - return self.getTradingMins()*60 + return self.getTradingMins() * 60 - def getSectionIndex(self, rawTime:int) -> int: + def getSectionIndex(self, rawTime: int) -> int: offTime = self.offsetTime(rawTime) for idx in range(len(self.sections)): sec = self.sections[idx] - if sec.stime <= offTime and offTime <= sec.etime: + if sec.stime <= offTime <= sec.etime: return idx - + return -1 - def isLastOfSection(self, rawTime:int): + def isLastOfSection(self, rawTime: int): offTime = self.offsetTime(rawTime) for sec in self.sections: if sec.etime == offTime: return True - + return False - def isInTradingTime(self, rawTime:int, bStrict:bool = False): + def isInTradingTime(self, rawTime: int, bStrict: bool = False): mins = self.timeToMinutes(rawTime) if mins == -1: return False if bStrict and self.isLastOfSection(rawTime): return False - + return True - def isFirstOfSection(self, rawTime:int): + def isFirstOfSection(self, rawTime: int): offTime = self.offsetTime(rawTime) for sec in self.sections: if sec.stime == offTime: return True - + return False - def timeToMinutes(self, rawTime:int): + def timeToMinutes(self, rawTime: int): if len(self.sections) == 0: return -1 @@ -144,30 +145,30 @@ def timeToMinutes(self, rawTime:int): bFound = False offset = 0 for sec in self.sections: - if sec.stime <= offTime and offTime <= sec.etime: + if sec.stime <= offTime <= sec.etime: hour = math.floor(offTime / 100) - math.floor(sec.stime / 100) minute = offTime % 100 - sec.stime % 100 - offset += hour*60 + minute + offset += hour * 60 + minute bFound = True break else: hour = math.floor(sec.etime / 100) - math.floor(sec.stime / 100) minute = sec.etime % 100 - sec.stime % 100 - offset += hour*60 + minute + offset += hour * 60 + minute if not bFound: return -1 return offset - def minutesToTime(self, minutes:int, bHeadFirst:bool = False): + def minutesToTime(self, minutes: int, bHeadFirst: bool = False): if len(self.sections) == 0: return -1 offset = minutes for sec in self.sections: - startMin = math.floor(sec.stime / 100)*60 + sec.stime % 100 - stopMin = math.floor(sec.etime / 100)*60 + sec.etime % 100 + startMin = math.floor(sec.stime / 100) * 60 + sec.stime % 100 + stopMin = math.floor(sec.etime / 100) * 60 + sec.etime % 100 if not bHeadFirst: if startMin + offset >= stopMin: @@ -192,14 +193,14 @@ def minutesToTime(self, minutes:int, bHeadFirst:bool = False): return self.getCloseTime() + class SessionMgr: def __init__(self): self.__sessions__ = dict() return - - def load(self, fname:str): + def load(self, fname: str): f = open(fname, 'rb') content = f.read() f.close() @@ -210,7 +211,7 @@ def load(self, fname:str): sessions_dict = yaml.full_load(content) else: sessions_dict = json.loads(content) - for sid in sessions_dict: + for sid in sessions_dict: if sid in self.__sessions__: continue @@ -232,9 +233,8 @@ def load(self, fname:str): self.__sessions__[sid] = sInfo - - def getSession(self, sid:str) -> SessionInfo: + def getSession(self, sid: str) -> SessionInfo: if sid not in self.__sessions__: return None - return self.__sessions__[sid] \ No newline at end of file + return self.__sessions__[sid] diff --git a/wtpy/StrategyDefs.py b/wtpy/StrategyDefs.py index b9419583..4cf23979 100644 --- a/wtpy/StrategyDefs.py +++ b/wtpy/StrategyDefs.py @@ -1,62 +1,61 @@ from wtpy import CtaContext, SelContext, HftContext + class BaseCtaStrategy: - ''' + """ CTA策略基础类,所有的策略都从该类派生 包含了策略的基本开发框架 - ''' - def __init__(self, name:str): + """ + + def __init__(self, name: str): self.__name__ = name - - + def name(self) -> str: return self.__name__ - - def on_init(self, context:CtaContext): - ''' + def on_init(self, context: CtaContext): + """ 策略初始化,启动的时候调用 用于加载自定义数据 @context 策略运行上下文 - ''' + """ return - def on_session_begin(self, context:CtaContext, curTDate:int): - ''' + def on_session_begin(self, context: CtaContext, curTDate: int): + """ 交易日开始事件 @curTDate 交易日,格式为20210220 - ''' + """ return - def on_session_end(self, context:CtaContext, curTDate:int): - ''' + def on_session_end(self, context: CtaContext, curTDate: int): + """ 交易日结束事件 @curTDate 交易日,格式为20210220 - ''' + """ return - - def on_calculate(self, context:CtaContext): - ''' + + def on_calculate(self, context: CtaContext): + """ K线闭合时调用,一般作为策略的核心计算模块 @context 策略运行上下文 - ''' + """ return - def on_calculate_done(self, context:CtaContext): - ''' + def on_calculate_done(self, context: CtaContext): + """ K线闭合时调用,一般作为策略的核心计算模块 @context 策略运行上下文 - ''' + """ return - - def on_tick(self, context:CtaContext, stdCode:str, newTick:dict): - ''' + def on_tick(self, context: CtaContext, stdCode: str, newTick: dict): + """ 逐笔数据进来时调用 生产环境中,每笔行情进来就直接调用 回测环境中,是模拟的逐笔数据 @@ -64,30 +63,30 @@ def on_tick(self, context:CtaContext, stdCode:str, newTick:dict): @context 策略运行上下文 @stdCode 合约代码 @newTick 最新逐笔 - ''' + """ return - def on_bar(self, context:CtaContext, stdCode:str, period:str, newBar:dict): - ''' + def on_bar(self, context: CtaContext, stdCode: str, period: str, newBar: dict): + """ K线闭合时回调 @context 策略上下文 @stdCode 合约代码 @period K线周期 @newBar 最新闭合的K线 - ''' + """ return - def on_backtest_end(self, context:CtaContext): - ''' + def on_backtest_end(self, context: CtaContext): + """ 回测结束时回调,只在回测框架下会触发 @context 策略上下文 - ''' + """ return - def on_condition_triggered(self, context:CtaContext, stdCode:str, target:float, price:float, usertag:str): - ''' + def on_condition_triggered(self, context: CtaContext, stdCode: str, target: float, price: float, usertag: str): + """ 条件单触发回调 @context 策略上下文 @@ -95,124 +94,124 @@ def on_condition_triggered(self, context:CtaContext, stdCode:str, target:float, @target 触发以后的最终目标仓位 @price 触发价格 @usertag 用户标记 - ''' + """ return + class BaseHftStrategy: - ''' + """ HFT策略基础类,所有的策略都从该类派生 包含了策略的基本开发框架 - ''' - def __init__(self, name:str): + """ + + def __init__(self, name: str): self.__name__ = name - - + def name(self) -> str: return self.__name__ - - def on_init(self, context:HftContext): - ''' + def on_init(self, context: HftContext): + """ 策略初始化,启动的时候调用 用于加载自定义数据 @context 策略运行上下文 - ''' + """ return - - def on_session_begin(self, context:HftContext, curTDate:int): - ''' + + def on_session_begin(self, context: HftContext, curTDate: int): + """ 交易日开始事件 @curTDate 交易日,格式为20210220 - ''' + """ return - def on_session_end(self, context:HftContext, curTDate:int): - ''' + def on_session_end(self, context: HftContext, curTDate: int): + """ 交易日结束事件 @curTDate 交易日,格式为20210220 - ''' + """ return - def on_backtest_end(self, context:CtaContext): - ''' + def on_backtest_end(self, context: CtaContext): + """ 回测结束时回调,只在回测框架下会触发 @context 策略上下文 - ''' + """ return - def on_tick(self, context:HftContext, stdCode:str, newTick:dict): - ''' + def on_tick(self, context: HftContext, stdCode: str, newTick: dict): + """ Tick数据进来时调用 @context 策略运行上下文 @stdCode 合约代码 @newTick 最新Tick - ''' + """ return - def on_order_detail(self, context:HftContext, stdCode:str, newOrdQue:dict): - ''' + def on_order_detail(self, context: HftContext, stdCode: str, newOrdQue: dict): + """ 逐笔委托数据进来时调用 @context 策略运行上下文 @stdCode 合约代码 @newOrdQue 最新逐笔委托 - ''' + """ return - def on_order_queue(self, context:HftContext, stdCode:str, newOrdQue:dict): - ''' + def on_order_queue(self, context: HftContext, stdCode: str, newOrdQue: dict): + """ 委托队列数据进来时调用 @context 策略运行上下文 @stdCode 合约代码 @newOrdQue 最新委托队列 - ''' + """ return - def on_transaction(self, context:HftContext, stdCode:str, newTrans:dict): - ''' + def on_transaction(self, context: HftContext, stdCode: str, newTrans: dict): + """ 逐笔成交数据进来时调用 @context 策略运行上下文 @stdCode 合约代码 @newTrans 最新逐笔成交 - ''' + """ return - def on_bar(self, context:HftContext, stdCode:str, period:str, newBar:dict): - ''' + def on_bar(self, context: HftContext, stdCode: str, period: str, newBar: dict): + """ K线闭合时回调 @context 策略上下文 @stdCode 合约代码 @period K线周期 @newBar 最新闭合的K线 - ''' + """ return - def on_channel_ready(self, context:HftContext): - ''' + def on_channel_ready(self, context: HftContext): + """ 交易通道就绪通知 @context 策略上下文 - ''' + """ return - def on_channel_lost(self, context:HftContext): - ''' + def on_channel_lost(self, context: HftContext): + """ 交易通道丢失通知 @context 策略上下文 - ''' + """ return - def on_entrust(self, context:HftContext, localid:int, stdCode:str, bSucc:bool, msg:str, userTag:str): - ''' + def on_entrust(self, context: HftContext, localid: int, stdCode: str, bSucc: bool, msg: str, userTag: str): + """ 下单结果回报 @context 策略上下文 @@ -220,11 +219,12 @@ def on_entrust(self, context:HftContext, localid:int, stdCode:str, bSucc:bool, m @stdCode 合约代码 @bSucc 下单结果 @mes 下单结果描述 - ''' + """ return - def on_order(self, context:HftContext, localid:int, stdCode:str, isBuy:bool, totalQty:float, leftQty:float, price:float, isCanceled:bool, userTag:str): - ''' + def on_order(self, context: HftContext, localid: int, stdCode: str, isBuy: bool, totalQty: float, leftQty: float, + price: float, isCanceled: bool, userTag: str): + """ 订单回报 @context 策略上下文 @localid 本地订单id @@ -234,11 +234,12 @@ def on_order(self, context:HftContext, localid:int, stdCode:str, isBuy:bool, tot @leftQty 剩余数量 @price 下单价格 @isCanceled 是否已撤单 - ''' + """ return - def on_trade(self, context:HftContext, localid:int, stdCode:str, isBuy:bool, qty:float, price:float, userTag:str): - ''' + def on_trade(self, context: HftContext, localid: int, stdCode: str, isBuy: bool, qty: float, price: float, + userTag: str): + """ 成交回报 @context 策略上下文 @@ -246,11 +247,12 @@ def on_trade(self, context:HftContext, localid:int, stdCode:str, isBuy:bool, qty @isBuy 是否买入 @qty 成交数量 @price 成交价格 - ''' + """ return - def on_position(self, context:HftContext, stdCode:str, isLong:bool, prevol:float, preavail:float, newvol:float, newavail:float): - ''' + def on_position(self, context: HftContext, stdCode: str, isLong: bool, prevol: float, preavail: float, + newvol: float, newavail: float): + """ 初始持仓回报 实盘可用, 回测的时候初始仓位都是空, 所以不需要 @@ -261,86 +263,86 @@ def on_position(self, context:HftContext, stdCode:str, isLong:bool, prevol:float @preavail 可用昨仓 @newvol 今仓 @newavail 可用今仓 - ''' + """ return + class BaseSelStrategy: - ''' + """ 选股策略基础类,所有的多因子策略都从该类派生 包含了策略的基本开发框架 - ''' - def __init__(self, name:str): + """ + + def __init__(self, name: str): self.__name__ = name - - + def name(self) -> str: return self.__name__ - - def on_init(self, context:SelContext): - ''' + def on_init(self, context: SelContext): + """ 策略初始化,启动的时候调用 用于加载自定义数据 @context 策略运行上下文 - ''' + """ return - - def on_session_begin(self, context:SelContext, curTDate:int): - ''' + + def on_session_begin(self, context: SelContext, curTDate: int): + """ 交易日开始事件 @curTDate 交易日,格式为20210220 - ''' + """ return - def on_session_end(self, context:SelContext, curTDate:int): - ''' + def on_session_end(self, context: SelContext, curTDate: int): + """ 交易日结束事件 @curTDate 交易日,格式为20210220 - ''' + """ return - - def on_calculate(self, context:SelContext): - ''' + + def on_calculate(self, context: SelContext): + """ K线闭合时调用,一般作为策略的核心计算模块 @context 策略运行上下文 - ''' + """ return - def on_calculate_done(self, context:SelContext): - ''' + def on_calculate_done(self, context: SelContext): + """ K线闭合时调用,一般作为策略的核心计算模块 @context 策略运行上下文 - ''' + """ return - def on_backtest_end(self, context:CtaContext): - ''' + def on_backtest_end(self, context: CtaContext): + """ 回测结束时回调,只在回测框架下会触发 @context 策略上下文 - ''' + """ return - def on_tick(self, context:SelContext, stdCode:str, newTick:dict): - ''' + def on_tick(self, context: SelContext, stdCode: str, newTick: dict): + """ 逐笔数据进来时调用 生产环境中,每笔行情进来就直接调用 回测环境中,是模拟的逐笔数据 @context 策略运行上下文 @stdCode 合约代码 @newTick 最新逐笔 - ''' + """ return - def on_bar(self, context:SelContext, stdCode:str, period:str, newBar:dict): - ''' + def on_bar(self, context: SelContext, stdCode: str, period: str, newBar: dict): + """ K线闭合时回调 @context 策略上下文 @stdCode 合约代码 @period K线周期 @newBar 最新闭合的K线 - ''' - return \ No newline at end of file + """ + return diff --git a/wtpy/WtBtEngine.py b/wtpy/WtBtEngine.py index a421c794..bacd34f5 100644 --- a/wtpy/WtBtEngine.py +++ b/wtpy/WtBtEngine.py @@ -19,53 +19,55 @@ import chardet import os + @singleton class WtBtEngine: - def __init__(self, eType:EngineType = EngineType.ET_CTA, logCfg:str = "logcfgbt.yaml", isFile:bool = True, bDumpCfg:bool = False, outDir:str = "./outputs_bt"): - ''' + def __init__(self, eType: EngineType = EngineType.ET_CTA, logCfg: str = "logcfgbt.yaml", isFile: bool = True, + bDumpCfg: bool = False, outDir: str = "./outputs_bt"): + """ 构造函数 @eType 引擎类型 @logCfg 日志模块配置文件,也可以直接是配置内容字符串 @isFile 是否文件,如果是文件,则将logCfg当做文件路径处理,如果不是文件,则直接当成json格式的字符串进行解析 @bDumpCfg 回测的实际配置文件是否落地 @outDir 回测数据输出目录 - ''' + """ self.is_backtest = True - self.__wrapper__ = WtBtWrapper(self) #api接口转换器 - self.__context__ = None #策略ctx映射表 - self.__config__ = dict() #框架配置项 - self.__cfg_commited__ = False #配置是否已提交 + self.__wrapper__ = WtBtWrapper(self) # api接口转换器 + self.__context__ = None # 策略ctx映射表 + self.__config__ = dict() # 框架配置项 + self.__cfg_commited__ = False # 配置是否已提交 - self.__idx_writer__ = None #指标输出模块 + self.__idx_writer__ = None # 指标输出模块 - self.__dump_config__ = bDumpCfg #是否保存最终配置 + self.__dump_config__ = bDumpCfg # 是否保存最终配置 self.__is_cfg_yaml__ = False - - self.trading_day = 0 #当前交易日 - self.__ext_data_loader__:BaseExtDataLoader = None #扩展历史数据加载器 + self.trading_day = 0 # 当前交易日 + + self.__ext_data_loader__: BaseExtDataLoader = None # 扩展历史数据加载器 if eType == eType.ET_CTA: - self.__wrapper__.initialize_cta(logCfg, isFile, outDir) #初始化CTA环境 + self.__wrapper__.initialize_cta(logCfg, isFile, outDir) # 初始化CTA环境 elif eType == eType.ET_HFT: - self.__wrapper__.initialize_hft(logCfg, isFile, outDir) #初始化HFT环境 + self.__wrapper__.initialize_hft(logCfg, isFile, outDir) # 初始化HFT环境 elif eType == eType.ET_SEL: - self.__wrapper__.initialize_sel(logCfg, isFile, outDir) #初始化SEL环境 + self.__wrapper__.initialize_sel(logCfg, isFile, outDir) # 初始化SEL环境 def __check_config__(self): - ''' + """ 检查设置项 主要会补充一些默认设置项 - ''' + """ if "replayer" not in self.__config__: self.__config__["replayer"] = dict() self.__config__["replayer"]["basefiles"] = dict() self.__config__["replayer"]["mode"] = "csv" self.__config__["replayer"]["store"] = { - "path":"./storage/" + "path": "./storage/" } if "basefiles" not in self.__config__["replayer"]: @@ -81,38 +83,38 @@ def __check_config__(self): self.__config__["env"] = dict() self.__config__["env"]["mocker"] = "cta" - def set_writer(self, writer:BaseIndexWriter): - ''' + def set_writer(self, writer: BaseIndexWriter): + """ 设置指标输出模块 - ''' + """ self.__idx_writer__ = writer - def write_indicator(self, id:str, tag:str, time:int, data:dict): - ''' + def write_indicator(self, id: str, tag: str, time: int, data: dict): + """ 写入指标数据 @id 指标id @tag 标签,主要用于区分指标对应的周期,如m5/d @time 时间,如yyyymmddHHMM @data 指标值 - ''' + """ if self.__idx_writer__ is not None: self.__idx_writer__.write_indicator(id, tag, time, data) - def init_with_config(self, folder:str, - config:dict, - commfile:str = None, - contractfile:str = None, - sessionfile:str = None, - holidayfile:str= None, - hotfile:str = None, - secondfile:str = None): + def init_with_config(self, folder: str, + config: dict, + commfile: str = None, + contractfile: str = None, + sessionfile: str = None, + holidayfile: str = None, + hotfile: str = None, + secondfile: str = None): self.__config__ = config.copy() self.__check_config__() if contractfile is not None: self.__config__["replayer"]["basefiles"]["contract"] = os.path.join(folder, contractfile) - + if sessionfile is not None: self.__config__["replayer"]["basefiles"]["session"] = os.path.join(folder, sessionfile) @@ -130,37 +132,37 @@ def init_with_config(self, folder:str, self.productMgr = ProductMgr() if self.__config__["replayer"]["basefiles"]["commodity"] is not None: - if type(self.__config__["replayer"]["basefiles"]["commodity"]) == str: + if isinstance(self.__config__["replayer"]["basefiles"]["commodity"], str): self.productMgr.load(self.__config__["replayer"]["basefiles"]["commodity"]) - elif type(self.__config__["replayer"]["basefiles"]["commodity"]) == list: + elif isinstance(self.__config__["replayer"]["basefiles"]["commodity"], list): for fname in self.__config__["replayer"]["basefiles"]["commodity"]: self.productMgr.load(fname) self.contractMgr = ContractMgr(self.productMgr) - if type(self.__config__["replayer"]["basefiles"]["contract"]) == str: + if isinstance(self.__config__["replayer"]["basefiles"]["contract"], str): self.contractMgr.load(self.__config__["replayer"]["basefiles"]["contract"]) - elif type(self.__config__["replayer"]["basefiles"]["contract"]) == list: + elif isinstance(self.__config__["replayer"]["basefiles"]["contract"], list): for fname in self.__config__["replayer"]["basefiles"]["contract"]: self.contractMgr.load(fname) self.sessionMgr = SessionMgr() self.sessionMgr.load(self.__config__["replayer"]["basefiles"]["session"]) - def init(self, folder:str, - cfgfile:str = "configbt.yaml", - commfile:str = None, - contractfile:str = None, - sessionfile:str = None, - holidayfile:str= None, - hotfile:str = None, - secondfile:str = None): - ''' + def init(self, folder: str, + cfgfile: str = "configbt.yaml", + commfile: str = None, + contractfile: str = None, + sessionfile: str = None, + holidayfile: str = None, + hotfile: str = None, + secondfile: str = None): + """ 初始化 @folder 基础数据文件目录,\\结尾 @cfgfile 配置文件,json/yaml格式 @commfile 品种定义文件,json/yaml格式 @contractfile 合约定义文件,json/yaml格式 - ''' + """ f = open(cfgfile, "rb") content = f.read() f.close() @@ -168,66 +170,68 @@ def init(self, folder:str, content = content.decode(encoding) if cfgfile.lower().endswith(".json"): - self.init_with_config(folder, json.loads(content), commfile, contractfile, sessionfile, holidayfile, hotfile, secondfile) + self.init_with_config(folder, json.loads(content), commfile, contractfile, sessionfile, holidayfile, + hotfile, secondfile) self.__is_cfg_yaml__ = False else: - self.init_with_config(folder, yaml.full_load(content), commfile, contractfile, sessionfile, holidayfile, hotfile, secondfile) - self.__is_cfg_yaml__ = True + self.init_with_config(folder, yaml.full_load(content), commfile, contractfile, sessionfile, holidayfile, + hotfile, secondfile) + self.__is_cfg_yaml__ = True - def configMocker(self, name:str): - ''' + def configMocker(self, name: str): + """ 设置模拟器 - ''' + """ self.__config__["env"]["mocker"] = name - def configBacktest(self, stime:int, etime:int): - ''' + def configBacktest(self, stime: int, etime: int): + """ 配置回测设置项 @stime 开始时间 @etime 结束时间 - ''' + """ self.__config__["replayer"]["stime"] = int(stime) self.__config__["replayer"]["etime"] = int(etime) - def configBTStorage(self, mode:str, path:str = None, storage:dict = None): - ''' + def configBTStorage(self, mode: str, path: str = None, storage: dict = None): + """ 配置数据存储 @mode 存储模式,csv-表示从csv直接读取,一般回测使用,wtp-表示使用wt框架自带数据存储 - ''' + """ self.__config__["replayer"]["mode"] = mode if path is not None: self.__config__["replayer"]["store"] = { - "path":path + "path": path } if storage is not None: self.__config__["replayer"]["store"] = storage - def configIncrementalBt(self, incrementBtBase:str): - ''' + def configIncrementalBt(self, incrementBtBase: str): + """ 设置增量 - ''' + """ self.__config__["env"]["incremental_backtest_base"] = incrementBtBase - - def registerCustomRule(self, ruleTag:str, filename:str): - ''' + + def registerCustomRule(self, ruleTag: str, filename: str): + """ 注册自定义连续合约规则 @ruleTag 规则标签,如ruleTag为THIS,对应的连续合约代码为CFFEX.IF.THIS @filename 规则定义文件名,和hots.json格式一样 - ''' + """ if "rules" not in self.__config__["replayer"]["basefiles"]: self.__config__["replayer"]["basefiles"]["rules"] = dict() self.__config__["replayer"]["basefiles"]["rules"][ruleTag] = filename - def setExternalCtaStrategy(self, id:str, module:str, typeName:str, params:dict): - ''' + def setExternalCtaStrategy(self, id: str, module: str, typeName: str, params: dict): + """ 添加C++的CTA策略 @id 策略ID @module 策略模块文件名,包含后缀,如:WzCtaFact.dll @typeName 模块内的策略类名 @params 策略参数 - ''' + """ if "cta" not in self.__config__: self.__config__["cta"] = dict() @@ -239,16 +243,15 @@ def setExternalCtaStrategy(self, id:str, module:str, typeName:str, params:dict): self.__config__["cta"]["strategy"]["id"] = id self.__config__["cta"]["strategy"]["name"] = typeName self.__config__["cta"]["strategy"]["params"] = params - - def setExternalHftStrategy(self, id:str, module:str, typeName:str, params:dict): - ''' + def setExternalHftStrategy(self, id: str, module: str, typeName: str, params: dict): + """ 添加C++的HFT策略 @id 策略ID @module 策略模块文件名,包含后缀,如:WzHftFact.dll @typeName 模块内的策略类名 @params 策略参数 - ''' + """ if "hft" not in self.__config__: self.__config__["hft"] = dict() @@ -261,27 +264,27 @@ def setExternalHftStrategy(self, id:str, module:str, typeName:str, params:dict): self.__config__["hft"]["strategy"]["name"] = typeName self.__config__["hft"]["strategy"]["params"] = params - def set_extended_data_loader(self, loader:BaseExtDataLoader, bAutoTrans:bool = True): - ''' + def set_extended_data_loader(self, loader: BaseExtDataLoader, bAutoTrans: bool = True): + """ 设置扩展数据加载器 @loader 数据加载器模块 @bAutoTrans 是否自动转储,如果是的话底层就转成dsb文件 - ''' + """ self.__ext_data_loader__ = loader self.__wrapper__.register_extended_data_loader(bAutoTrans) def get_extended_data_loader(self) -> BaseExtDataLoader: - ''' + """ 获取扩展的数据加载器 - ''' + """ return self.__ext_data_loader__ def commitBTConfig(self): - ''' + """ 提交配置 只有第一次调用会生效,不可重复调用 如果执行run之前没有调用,run会自动调用该方法 - ''' + """ if self.__cfg_commited__: return @@ -299,11 +302,11 @@ def commitBTConfig(self): f.write(cfgfile) f.close() - def getSessionByCode(self, code:str) -> SessionInfo: - ''' + def getSessionByCode(self, code: str) -> SessionInfo: + """ 通过合约代码获取交易时间模板 @code 合约代码,格式如SHFE.rb.HOT - ''' + """ pid = CodeHelper.stdCodeToStdCommID(code) pInfo = self.productMgr.getProductInfo(pid) @@ -312,64 +315,65 @@ def getSessionByCode(self, code:str) -> SessionInfo: return self.sessionMgr.getSession(pInfo.session) - def getSessionByName(self, sname:str) -> SessionInfo: - ''' + def getSessionByName(self, sname: str) -> SessionInfo: + """ 通过模板名获取交易时间模板 @sname 模板名 - ''' + """ return self.sessionMgr.getSession(sname) - def getProductInfo(self, code:str) -> ProductInfo: - ''' + def getProductInfo(self, code: str) -> ProductInfo: + """ 获取品种信息 @code 合约代码,格式如SHFE.rb.HOT - ''' + """ return self.productMgr.getProductInfo(code) - def getContractInfo(self, code:str) -> ContractInfo: - ''' + def getContractInfo(self, code: str) -> ContractInfo: + """ 获取品种信息 @code 合约代码,格式如SHFE.rb.HOT - ''' + """ return self.contractMgr.getContractInfo(code, self.trading_day) def getAllCodes(self) -> list: - ''' + """ 获取全部合约代码 - ''' + """ return self.contractMgr.getTotalCodes(self.trading_day) - def getRawStdCode(self, stdCode:str): - ''' + def getRawStdCode(self, stdCode: str): + """ 根据连续合约代码获取原始合约代码 - ''' + """ return self.__wrapper__.get_raw_stdcode(stdCode) - - def getCodesByProduct(self, stdPID:str) -> list: - ''' + + def getCodesByProduct(self, stdPID: str) -> list: + """ 根据品种id获取对应合约代码 @stdPID 品种代码, 格式如SHFE.rb - ''' + """ return self.contractMgr.getCodesByProduct(stdPID, self.trading_day) - - def getCodesByUnderlying(self, underlying:str) -> list: - ''' + + def getCodesByUnderlying(self, underlying: str) -> list: + """ 根据underlying获取对应合约代码 @underlying 品种代码, 格式如SHFE.rb2305 - ''' + """ return self.contractMgr.getCodesByUnderlying(underlying, self.trading_day) - def set_time_range(self, beginTime:int, endTime:int): - ''' + def set_time_range(self, beginTime: int, endTime: int): + """ 设置回测时间 一般用于一个进程中多次回测的时候启动下一轮回测之前重设之间范围 @beginTime 开始时间,格式如yyyymmddHHMM @endTime 结束时间,格式如yyyymmddHHMM - ''' + """ self.__wrapper__.set_time_range(beginTime, endTime) - def set_cta_strategy(self, strategy:BaseCtaStrategy, slippage:int = 0, hook:bool = False, persistData:bool = True, incremental:bool = False, isRatioSlp:bool = False): - ''' + def set_cta_strategy(self, strategy: BaseCtaStrategy, slippage: int = 0, hook: bool = False, + persistData: bool = True, incremental: bool = False, isRatioSlp: bool = False): + """ 添加CTA策略 @strategy 策略对象 @slippage 滑点大小 @@ -377,85 +381,87 @@ def set_cta_strategy(self, strategy:BaseCtaStrategy, slippage:int = 0, hook:bool @persistData 回测生成的数据是否落地, 默认为True @incremental 是否增量回测, 默认为False, 如果为True, 则会自动根据策略ID到output_bt目录下加载对应的数据 @isRatioSlp 滑点是否是比例, 默认为False, 如果为True, 则slippage为万分比 - ''' + """ ctxid = self.__wrapper__.init_cta_mocker(strategy.name(), slippage, hook, persistData, incremental, isRatioSlp) self.__context__ = CtaContext(ctxid, strategy, self.__wrapper__, self) - def set_hft_strategy(self, strategy:BaseHftStrategy, hook:bool = False): - ''' + def set_hft_strategy(self, strategy: BaseHftStrategy, hook: bool = False): + """ 添加HFT策略 @strategy 策略对象 @hook 是否安装钩子,主要用于单步控制重算 - ''' + """ ctxid = self.__wrapper__.init_hft_mocker(strategy.name(), hook) self.__context__ = HftContext(ctxid, strategy, self.__wrapper__, self) - def set_sel_strategy(self, strategy:BaseSelStrategy, date:int=0, time:int=0, period:str="d", trdtpl:str="CHINA", session:str="TRADING", slippage:int = 0, isRatioSlp:bool = False): - ''' + def set_sel_strategy(self, strategy: BaseSelStrategy, date: int = 0, time: int = 0, period: str = "d", + trdtpl: str = "CHINA", session: str = "TRADING", slippage: int = 0, isRatioSlp: bool = False): + """ 添加SEL策略 @strategy 策略对象 @date 日期,根据周期变化,每日为0,每周为0~6,对应周日到周六,每月为1~31,每年为0101~1231 - @time 时间,精确到分钟 - @period 时间周期,可以是分钟min、天d、周w、月m、年y + @time 时间,精确到分钟 + @period 时间周期,可以是分钟min、天d、周w、月m、年y @trdtpl 交易日历模板,默认为CHINA @session 交易时间模板,默认为TRADING @slippage 滑点大小 @isRatioSlp 滑点是否是比例, 默认为False, 如果为True, 则slippage为万分比 - ''' - ctxid = self.__wrapper__.init_sel_mocker(strategy.name(), date, time, period, trdtpl, session, slippage, isRatioSlp) + """ + ctxid = self.__wrapper__.init_sel_mocker(strategy.name(), date, time, period, trdtpl, session, slippage, + isRatioSlp) self.__context__ = SelContext(ctxid, strategy, self.__wrapper__, self) - def get_context(self, id:int): + def get_context(self, id: int): return self.__context__ - def run_backtest(self, bAsync:bool = False, bNeedDump:bool = True): - ''' + def run_backtest(self, bAsync: bool = False, bNeedDump: bool = True): + """ 运行框架 @bAsync 是否异步运行,默认为false。如果不启动异步模式,则强化学习的训练环境也不能生效,即使策略下了钩子 - ''' - if not self.__cfg_commited__: #如果配置没有提交,则自动提交一下 + """ + if not self.__cfg_commited__: # 如果配置没有提交,则自动提交一下 self.commitBTConfig() - self.__wrapper__.run_backtest(bNeedDump = bNeedDump, bAsync = bAsync) + self.__wrapper__.run_backtest(bNeedDump=bNeedDump, bAsync=bAsync) - def cta_step(self, remark:str = "") -> bool: - ''' - CTA策略单步执行 + def cta_step(self, remark: str = "") -> bool: + """ + CTA策略单步 执行 - @remark 单步备注信息,没有实际作用,主要用于外部调用区分步骤 - ''' + @remark 单步 备注信息,没有实际作用,主要用于外部调用区分步骤 + """ return self.__wrapper__.cta_step(self.__context__.id) def hft_step(self): - ''' + """ HFT策略单步执行 - ''' + """ self.__wrapper__.hft_step(self.__context__.id) def stop_backtest(self): - ''' + """ 手动停止回测 - ''' + """ self.__wrapper__.stop_backtest() def release_backtest(self): - ''' + """ 释放框架 - ''' + """ self.__wrapper__.release_backtest() def on_init(self): return - def on_schedule(self, date:int, time:int, taskid:int = 0): + def on_schedule(self, date: int, time: int, taskid: int = 0): return - def on_session_begin(self, date:int): + def on_session_begin(self, date: int): self.trading_day = date return - def on_session_end(self, date:int): + def on_session_end(self, date: int): return def on_backtest_end(self): @@ -465,7 +471,7 @@ def on_backtest_end(self): self.__context__.on_backtest_end() def clear_cache(self): - ''' + """ 清除缓存的数据,即加已经加载到内存中的数据全部清除 - ''' + """ self.__wrapper__.clear_cache() diff --git a/wtpy/WtCoreDefs.py b/wtpy/WtCoreDefs.py index 3b7be56a..66b57dce 100644 --- a/wtpy/WtCoreDefs.py +++ b/wtpy/WtCoreDefs.py @@ -1,12 +1,14 @@ from ctypes import c_void_p, CFUNCTYPE, POINTER, c_char_p, c_bool, c_ulong, c_double -from ctypes import Structure, c_char, c_int32, c_uint32,c_uint64,c_int64 +from ctypes import Structure, c_char, c_int32, c_uint32, c_uint64, c_int64 from copy import copy import numpy as np +from enum import Enum + +MAX_INSTRUMENT_LENGTH = c_char * 32 +MAX_EXCHANGE_LENGTH = c_char * 16 +PriceQueueType = c_double * 10 +VolumeQueueType = c_double * 10 -MAX_INSTRUMENT_LENGTH = c_char*32 -MAX_EXCHANGE_LENGTH = c_char*16 -PriceQueueType = c_double*10 -VolumeQueueType = c_double*10 class WTSStruct(Structure): @property @@ -16,18 +18,19 @@ def fields(self) -> list: @property def values(self) -> tuple: return tuple(getattr(self, i[0]) for i in self._fields_) - + @property def dict(self) -> dict: - return {i[0]:getattr(self, i[0]) for i in self._fields_} - + return {i[0]: getattr(self, i[0]) for i in self._fields_} + def to_dict(self) -> dict: - return {i[0]:getattr(self, i[0]) for i in self._fields_} + return {i[0]: getattr(self, i[0]) for i in self._fields_} + class WTSTickStruct(WTSStruct): - ''' + """ C接口传递的tick数据结构 - ''' + """ _fields_ = [("exchg", MAX_EXCHANGE_LENGTH), ("code", MAX_INSTRUMENT_LENGTH), ("price", c_double), @@ -65,7 +68,7 @@ class WTSTickStruct(WTSStruct): ("bid_price_7", c_double), ("bid_price_8", c_double), ("bid_price_9", c_double), - + ("ask_price_0", c_double), ("ask_price_1", c_double), ("ask_price_2", c_double), @@ -76,7 +79,7 @@ class WTSTickStruct(WTSStruct): ("ask_price_7", c_double), ("ask_price_8", c_double), ("ask_price_9", c_double), - + ("bid_qty_0", c_double), ("bid_qty_1", c_double), ("bid_qty_2", c_double), @@ -87,7 +90,7 @@ class WTSTickStruct(WTSStruct): ("bid_qty_7", c_double), ("bid_qty_8", c_double), ("bid_qty_9", c_double), - + ("ask_qty_0", c_double), ("ask_qty_1", c_double), ("ask_qty_2", c_double), @@ -109,82 +112,7 @@ def fields(self) -> list: @property def bid_prices(self) -> tuple: - return (self.bid_price_0, - self.bid_price_1, - self.bid_price_2, - self.bid_price_3, - self.bid_price_4, - self.bid_price_5, - self.bid_price_6, - self.bid_price_7, - self.bid_price_8, - self.bid_price_9) - - @property - def bid_qty(self) -> tuple: - return (self.bid_qty_0, - self.bid_qty_1, - self.bid_qty_2, - self.bid_qty_3, - self.bid_qty_4, - self.bid_qty_5, - self.bid_qty_6, - self.bid_qty_7, - self.bid_qty_8, - self.bid_qty_9) - - @property - def ask_prices(self) -> tuple: - return (self.ask_price_0, - self.ask_price_1, - self.ask_price_2, - self.ask_price_3, - self.ask_price_4, - self.ask_price_5, - self.ask_price_6, - self.ask_price_7, - self.ask_price_8, - self.ask_price_9) - - @property - def ask_qty(self) -> tuple: - return (self.ask_qty_0, - self.ask_qty_1, - self.ask_qty_2, - self.ask_qty_3, - self.ask_qty_4, - self.ask_qty_5, - self.ask_qty_6, - self.ask_qty_7, - self.ask_qty_8, - self.ask_qty_9) - - def to_tuple(self) -> tuple: - return ( - np.uint64(self.action_date)*1000000000+self.action_time, - self.exchg, - self.code, - self.price, - self.open, - self.high, - self.low, - self.settle_price, - self.upper_limit, - self.lower_limit, - self.total_volume, - self.volume, - self.total_turnover, - self.turn_over, - self.open_interest, - self.diff_interest, - self.trading_date, - self.action_date, - self.action_time, - self.pre_close, - self.pre_settle, - self.pre_interest, - - self.bid_price_0, + return (self.bid_price_0, self.bid_price_1, self.bid_price_2, self.bid_price_3, @@ -193,20 +121,11 @@ def to_tuple(self) -> tuple: self.bid_price_6, self.bid_price_7, self.bid_price_8, - self.bid_price_9, - - self.ask_price_0, - self.ask_price_1, - self.ask_price_2, - self.ask_price_3, - self.ask_price_4, - self.ask_price_5, - self.ask_price_6, - self.ask_price_7, - self.ask_price_8, - self.ask_price_9, - - self.bid_qty_0, + self.bid_price_9) + + @property + def bid_qty(self) -> tuple: + return (self.bid_qty_0, self.bid_qty_1, self.bid_qty_2, self.bid_qty_3, @@ -215,9 +134,24 @@ def to_tuple(self) -> tuple: self.bid_qty_6, self.bid_qty_7, self.bid_qty_8, - self.bid_qty_9, - - self.ask_qty_0, + self.bid_qty_9) + + @property + def ask_prices(self) -> tuple: + return (self.ask_price_0, + self.ask_price_1, + self.ask_price_2, + self.ask_price_3, + self.ask_price_4, + self.ask_price_5, + self.ask_price_6, + self.ask_price_7, + self.ask_price_8, + self.ask_price_9) + + @property + def ask_qty(self) -> tuple: + return (self.ask_qty_0, self.ask_qty_1, self.ask_qty_2, self.ask_qty_3, @@ -226,13 +160,83 @@ def to_tuple(self) -> tuple: self.ask_qty_6, self.ask_qty_7, self.ask_qty_8, - self.ask_qty_9 - ) + self.ask_qty_9) + + def to_tuple(self) -> tuple: + return ( + np.uint64(self.action_date) * 1000000000 + self.action_time, + self.exchg, + self.code, + self.price, + self.open, + self.high, + self.low, + self.settle_price, + self.upper_limit, + self.lower_limit, + self.total_volume, + self.volume, + self.total_turnover, + self.turn_over, + self.open_interest, + self.diff_interest, + self.trading_date, + self.action_date, + self.action_time, + self.pre_close, + self.pre_settle, + self.pre_interest, + + self.bid_price_0, + self.bid_price_1, + self.bid_price_2, + self.bid_price_3, + self.bid_price_4, + self.bid_price_5, + self.bid_price_6, + self.bid_price_7, + self.bid_price_8, + self.bid_price_9, + + self.ask_price_0, + self.ask_price_1, + self.ask_price_2, + self.ask_price_3, + self.ask_price_4, + self.ask_price_5, + self.ask_price_6, + self.ask_price_7, + self.ask_price_8, + self.ask_price_9, + + self.bid_qty_0, + self.bid_qty_1, + self.bid_qty_2, + self.bid_qty_3, + self.bid_qty_4, + self.bid_qty_5, + self.bid_qty_6, + self.bid_qty_7, + self.bid_qty_8, + self.bid_qty_9, + + self.ask_qty_0, + self.ask_qty_1, + self.ask_qty_2, + self.ask_qty_3, + self.ask_qty_4, + self.ask_qty_5, + self.ask_qty_6, + self.ask_qty_7, + self.ask_qty_8, + self.ask_qty_9 + ) + class WTSBarStruct(WTSStruct): - ''' + """ C接口传递的bar数据结构 - ''' + """ # @2IQ9d _fields_ = [("date", c_uint32), ("reserve", c_uint32), @@ -248,35 +252,36 @@ class WTSBarStruct(WTSStruct): ("diff", c_double)] _pack_ = 8 - def to_tuple(self, flag:int=0) -> tuple: - ''' + def to_tuple(self, flag: int = 0) -> tuple: + """ WTSBarStruct转成tuple @flag 转换标记,0-分钟线,1-日线,2-秒线 - ''' + """ if flag == 0: time = self.time + 199000000000 elif flag == 1: time = self.date elif flag == 2: time = self.time - + return ( - self.date, - time, - self.open, - self.high, - self.low, - self.close, - self.settle, - self.turnover, - self.volume, - self.open_interest, - self.diff) + self.date, + time, + self.open, + self.high, + self.low, + self.close, + self.settle, + self.turnover, + self.volume, + self.open_interest, + self.diff) + class WTSTransStruct(WTSStruct): - ''' + """ C接口传递的逐笔成交数据结构 - ''' + """ _fields_ = [("exchg", MAX_EXCHANGE_LENGTH), ("code", MAX_INSTRUMENT_LENGTH), @@ -296,25 +301,26 @@ class WTSTransStruct(WTSStruct): def to_tuple(self) -> tuple: return ( - np.uint64(self.action_date)*1000000000+self.action_time, - self.exchg, - self.code, - self.trading_date, - self.action_date, - self.action_time, - self.index, - self.ttype, - self.side, - self.price, - self.volume, - self.askorder, - self.bidorder - ) + np.uint64(self.action_date) * 1000000000 + self.action_time, + self.exchg, + self.code, + self.trading_date, + self.action_date, + self.action_time, + self.index, + self.ttype, + self.side, + self.price, + self.volume, + self.askorder, + self.bidorder + ) + class WTSOrdQueStruct(WTSStruct): - ''' + """ C接口传递的委托队列数据结构 - ''' + """ _fields_ = [("exchg", MAX_EXCHANGE_LENGTH), ("code", MAX_INSTRUMENT_LENGTH), @@ -326,27 +332,28 @@ class WTSOrdQueStruct(WTSStruct): ("price", c_double), ("order_items", c_uint32), ("qsize", c_uint32), - ("volumes", c_uint32*50)] + ("volumes", c_uint32 * 50)] _pack_ = 8 def to_tuple(self) -> tuple: return ( - np.uint64(self.action_date)*1000000000+self.action_time, - self.exchg, - self.code, - self.trading_date, - self.action_date, - self.action_time, - self.side, - self.price, - self.order_items, - self.qsize - ) + tuple(self.volumes) + np.uint64(self.action_date) * 1000000000 + self.action_time, + self.exchg, + self.code, + self.trading_date, + self.action_date, + self.action_time, + self.side, + self.price, + self.order_items, + self.qsize + ) + tuple(self.volumes) + class WTSOrdDtlStruct(WTSStruct): - ''' + """ C接口传递的委托明细数据结构 - ''' + """ _fields_ = [("exchg", MAX_EXCHANGE_LENGTH), ("code", MAX_INSTRUMENT_LENGTH), @@ -363,119 +370,118 @@ class WTSOrdDtlStruct(WTSStruct): def to_tuple(self) -> tuple: return ( - np.uint64(self.action_date)*1000000000+self.action_time, - self.exchg, - self.code, - self.trading_date, - self.action_date, - self.action_time, - self.index, - self.side, - self.price, - self.volume, - self.otype - ) + np.uint64(self.action_date) * 1000000000 + self.action_time, + self.exchg, + self.code, + self.trading_date, + self.action_date, + self.action_time, + self.index, + self.side, + self.price, + self.volume, + self.otype + ) + # 回调函数定义 -#策略初始化回调 -CB_STRATEGY_INIT = CFUNCTYPE(c_void_p, c_ulong) -#策略tick数据推送回调 +# 策略初始化回调 +CB_STRATEGY_INIT = CFUNCTYPE(c_void_p, c_ulong) +# 策略tick数据推送回调 CB_STRATEGY_TICK = CFUNCTYPE(c_void_p, c_ulong, c_char_p, POINTER(WTSTickStruct)) -#策略获取tick数据的单条tick同步回调 +# 策略获取tick数据的单条tick同步回调 CB_STRATEGY_GET_TICK = CFUNCTYPE(c_void_p, c_ulong, c_char_p, POINTER(WTSTickStruct), c_uint32, c_bool) -#策略重算回调(CTA/SEL策略) +# 策略重算回调(CTA/SEL策略) CB_STRATEGY_CALC = CFUNCTYPE(c_void_p, c_ulong, c_ulong, c_ulong) -#策略订阅的K线闭合事件回调 +# 策略订阅的K线闭合事件回调 CB_STRATEGY_BAR = CFUNCTYPE(c_void_p, c_ulong, c_char_p, c_char_p, POINTER(WTSBarStruct)) -#策略获取K线数据的单条K线同步回调 +# 策略获取K线数据的单条K线同步回调 CB_STRATEGY_GET_BAR = CFUNCTYPE(c_void_p, c_ulong, c_char_p, c_char_p, POINTER(WTSBarStruct), c_uint32, c_bool) -#策略获取全部持仓的同步回调 +# 策略获取全部持仓的同步回调 CB_STRATEGY_GET_POSITION = CFUNCTYPE(c_void_p, c_ulong, c_char_p, c_double, c_bool) -#交易日开始结束事件回调 -CB_SESSION_EVENT = CFUNCTYPE(c_void_p, c_ulong, c_ulong, c_bool) -#条件单触发回调 +# 交易日开始结束事件回调 +CB_SESSION_EVENT = CFUNCTYPE(c_void_p, c_ulong, c_ulong, c_bool) +# 条件单触发回调 CB_STRATEGY_COND_TRIGGER = CFUNCTYPE(c_void_p, c_ulong, c_char_p, c_double, c_double, c_char_p) -#引擎事件回调(交易日开启结束等) +# 引擎事件回调(交易日开启结束等) CB_ENGINE_EVENT = CFUNCTYPE(c_void_p, c_ulong, c_ulong, c_ulong) -#HFT策略交易通道事件回调 +# HFT策略交易通道事件回调 CB_HFTSTRA_CHNL_EVT = CFUNCTYPE(c_void_p, c_ulong, c_char_p, c_ulong) -#HFT策略订单推送回报 +# HFT策略订单推送回报 CB_HFTSTRA_ORD = CFUNCTYPE(c_void_p, c_ulong, c_ulong, c_char_p, c_bool, c_double, c_double, c_double, c_bool, c_char_p) -#HFT策略成交推送回报 +# HFT策略成交推送回报 CB_HFTSTRA_TRD = CFUNCTYPE(c_void_p, c_ulong, c_ulong, c_char_p, c_bool, c_double, c_double, c_char_p) -#HFT策略下单结果回报 +# HFT策略下单结果回报 CB_HFTSTRA_ENTRUST = CFUNCTYPE(c_void_p, c_ulong, c_ulong, c_char_p, c_bool, c_char_p, c_char_p) -#HFT策略持仓推送回报(实盘有效) +# HFT策略持仓推送回报(实盘有效) CB_HFTSTRA_POSITION = CFUNCTYPE(c_void_p, c_ulong, c_char_p, c_bool, c_double, c_double, c_double, c_double) -#策略委托队列推送回调 +# 策略委托队列推送回调 CB_HFTSTRA_ORDQUE = CFUNCTYPE(c_void_p, c_ulong, c_char_p, POINTER(WTSOrdQueStruct)) -#策略获取委托队列数据的单条数据同步回调 +# 策略获取委托队列数据的单条数据同步回调 CB_HFTSTRA_GET_ORDQUE = CFUNCTYPE(c_void_p, c_ulong, c_char_p, POINTER(WTSOrdQueStruct), c_uint32, c_bool) -#策略委托明细推送回调 +# 策略委托明细推送回调 CB_HFTSTRA_ORDDTL = CFUNCTYPE(c_void_p, c_ulong, c_char_p, POINTER(WTSOrdDtlStruct)) -#策略获取委托明细数据的单条数据同步回调 +# 策略获取委托明细数据的单条数据同步回调 CB_HFTSTRA_GET_ORDDTL = CFUNCTYPE(c_void_p, c_ulong, c_char_p, POINTER(WTSOrdDtlStruct), c_uint32, c_bool) -#策略成交明细推送回调 +# 策略成交明细推送回调 CB_HFTSTRA_TRANS = CFUNCTYPE(c_void_p, c_ulong, c_char_p, POINTER(WTSTransStruct)) -#策略获取成交明细数据的单条数据同步回调 +# 策略获取成交明细数据的单条数据同步回调 CB_HFTSTRA_GET_TRANS = CFUNCTYPE(c_void_p, c_ulong, c_char_p, POINTER(WTSTransStruct), c_uint32, c_bool) +EVENT_ENGINE_INIT = 1 # 框架初始化 +EVENT_SESSION_BEGIN = 2 # 交易日开始 +EVENT_SESSION_END = 3 # 交易日结束 +EVENT_ENGINE_SCHDL = 4 # 框架调度 +EVENT_BACKTEST_END = 5 # 回测结束 -EVENT_ENGINE_INIT = 1 #框架初始化 -EVENT_SESSION_BEGIN = 2 #交易日开始 -EVENT_SESSION_END = 3 #交易日结束 -EVENT_ENGINE_SCHDL = 4 #框架调度 -EVENT_BACKTEST_END = 5 #回测结束 +CHNL_EVENT_READY = 1000 # 通道就绪事件 +CHNL_EVENT_LOST = 1001 # 通道断开事件 -CHNL_EVENT_READY = 1000 #通道就绪事件 -CHNL_EVENT_LOST = 1001 #通道断开事件 +# 日志级别 +LOG_LEVEL_DEBUG = 0 +LOG_LEVEL_INFO = 1 +LOG_LEVEL_WARN = 2 +LOG_LEVEL_ERROR = 3 -#日志级别 -LOG_LEVEL_DEBUG = 0 -LOG_LEVEL_INFO = 1 -LOG_LEVEL_WARN = 2 -LOG_LEVEL_ERROR = 3 -from enum import Enum class EngineType(Enum): - ''' + """ 引擎类型 枚举变量 - ''' + """ ET_CTA = 999 ET_HFT = 1000 ET_SEL = 1001 - -''' + +""" Parser外接实现 -''' -EVENT_PARSER_INIT = 1; #Parser初始化 -EVENT_PARSER_CONNECT = 2; #Parser连接 -EVENT_PARSER_DISCONNECT = 3; #Parser断开连接 -EVENT_PARSER_RELEASE = 4; #Parser释放 +""" +EVENT_PARSER_INIT = 1 # Parser初始化 +EVENT_PARSER_CONNECT = 2 # Parser连接 +EVENT_PARSER_DISCONNECT = 3 # Parser断开连接 +EVENT_PARSER_RELEASE = 4 # Parser释放 CB_PARSER_EVENT = CFUNCTYPE(c_void_p, c_ulong, c_char_p) CB_PARSER_SUBCMD = CFUNCTYPE(c_void_p, c_char_p, c_char_p, c_bool) -''' +""" Executer外接实现 -''' +""" CB_EXECUTER_INIT = CFUNCTYPE(c_void_p, c_char_p) CB_EXECUTER_CMD = CFUNCTYPE(c_void_p, c_char_p, c_char_p, c_double) - -''' +""" DataLoader外接实现 -''' -FUNC_LOAD_HISBARS = CFUNCTYPE(c_bool, c_char_p, c_char_p) #加载K线 -FUNC_LOAD_ADJFACTS = CFUNCTYPE(c_bool, c_char_p) #加载复权因子 -FUNC_LOAD_HISTICKS = CFUNCTYPE(c_bool, c_char_p, c_ulong) #加载Tick +""" +FUNC_LOAD_HISBARS = CFUNCTYPE(c_bool, c_char_p, c_char_p) # 加载K线 +FUNC_LOAD_ADJFACTS = CFUNCTYPE(c_bool, c_char_p) # 加载复权因子 +FUNC_LOAD_HISTICKS = CFUNCTYPE(c_bool, c_char_p, c_ulong) # 加载Tick -''' +""" DataDumper外接实现 -''' +""" FUNC_DUMP_HISBARS = CFUNCTYPE(c_bool, c_char_p, c_char_p, c_char_p, POINTER(WTSBarStruct), c_uint32) -FUNC_DUMP_HISTICKS = CFUNCTYPE(c_bool, c_char_p, c_char_p, c_ulong, POINTER(WTSTickStruct), c_uint32) \ No newline at end of file +FUNC_DUMP_HISTICKS = CFUNCTYPE(c_bool, c_char_p, c_char_p, c_ulong, POINTER(WTSTickStruct), c_uint32) diff --git a/wtpy/WtDataDefs.py b/wtpy/WtDataDefs.py index 91afdbcd..32118945 100644 --- a/wtpy/WtDataDefs.py +++ b/wtpy/WtDataDefs.py @@ -7,72 +7,90 @@ from wtpy.WtCoreDefs import WTSBarStruct, WTSTickStruct -NpTypeBar = np.dtype([('date','u4'),('reserve','u4'),('time','u8'),('open','d'),\ - ('high','d'),('low','d'),('close','d'),('settle','d'),\ - ('turnover','d'),('volume','d'),('open_interest','d'),('diff','d')]) - -NpTypeTick = np.dtype([('exchg','S16'),('code','S32'),('price','d'),('open','d'),('high','d'),('low','d'),('settle_price','d'),\ - ('upper_limit','d'),('lower_limit','d'),('total_volume','d'),('volume','d'),('total_turnover','d'),('turn_over','d'),\ - ('open_interest','d'),('diff_interest','d'),('trading_date','u4'),('action_date','u4'),('action_time','u4'),\ - ('reserve','u4'),('pre_close','d'),('pre_settle','d'),('pre_interest','d'),\ - ('bid_price_0','d'),('bid_price_1','d'),('bid_price_2','d'),('bid_price_3','d'),('bid_price_4','d'),\ - ('bid_price_5','d'),('bid_price_6','d'),('bid_price_7','d'),('bid_price_8','d'),('bid_price_9','d'),\ - ('ask_price_0','d'),('ask_price_1','d'),('ask_price_2','d'),('ask_price_3','d'),('ask_price_4','d'),\ - ('ask_price_5','d'),('ask_price_6','d'),('ask_price_7','d'),('ask_price_8','d'),('ask_price_9','d'),\ - ('bid_qty_0','d'),('bid_qty_1','d'),('bid_qty_2','d'),('bid_qty_3','d'),('bid_qty_4','d'),\ - ('bid_qty_5','d'),('bid_qty_6','d'),('bid_qty_7','d'),('bid_qty_8','d'),('bid_qty_9','d'),\ - ('ask_qty_0','d'),('ask_qty_1','d'),('ask_qty_2','d'),('ask_qty_3','d'),('ask_qty_4','d'),\ - ('ask_qty_5','d'),('ask_qty_6','d'),('ask_qty_7','d'),('ask_qty_8','d'),('ask_qty_9','d')]) - -NpTypeTrans = np.dtype([('exchg','S16'),('code','S32'),('trading_date','u4'),('action_date','u4'),('action_time','u4'),\ - ('reserve1','u4'),('index','u8'),('ttype','i4'),('side','i4'),('price','d'),('volume','u4'),('reserve2','u4'),\ - ('askorder', 'i8'),('bidorder', 'i8')]) - -NpTypeOrdQue = np.dtype([('exchg','S16'),('code','S32'),('trading_date','u4'),('action_date','u4'),('action_time','u4'),\ - ('side','u4'),('price','d'),('order_items','u4'),('qsize', 'i8'),('volumes','u4', 50)]) - -NpTypeOrdDtl = np.dtype([('exchg','S16'),('code','S32'),('trading_date','u4'),('action_date','u4'),('action_time','u4'),\ - ('reserve1','u4'),('index','u8'),('price','d'),('volume','u4'),('side','u4'),('otype','u4'),('reserve2','u4')]) +NpTypeBar = np.dtype([('date', 'u4'), ('reserve', 'u4'), ('time', 'u8'), ('open', 'd'), + ('high', 'd'), ('low', 'd'), ('close', 'd'), ('settle', 'd'), + ('turnover', 'd'), ('volume', 'd'), ('open_interest', 'd'), ('diff', 'd')]) + +NpTypeTick = np.dtype([('exchg', 'S16'), ('code', 'S32'), ('price', 'd'), ('open', 'd'), ('high', 'd'), ('low', 'd'), + ('settle_price', 'd'), + ('upper_limit', 'd'), ('lower_limit', 'd'), ('total_volume', 'd'), ('volume', 'd'), + ('total_turnover', 'd'), ('turn_over', 'd'), + ('open_interest', 'd'), ('diff_interest', 'd'), ('trading_date', 'u4'), ('action_date', 'u4'), + ('action_time', 'u4'), + ('reserve', 'u4'), ('pre_close', 'd'), ('pre_settle', 'd'), ('pre_interest', 'd'), + ('bid_price_0', 'd'), ('bid_price_1', 'd'), ('bid_price_2', 'd'), ('bid_price_3', 'd'), + ('bid_price_4', 'd'), + ('bid_price_5', 'd'), ('bid_price_6', 'd'), ('bid_price_7', 'd'), ('bid_price_8', 'd'), + ('bid_price_9', 'd'), + ('ask_price_0', 'd'), ('ask_price_1', 'd'), ('ask_price_2', 'd'), ('ask_price_3', 'd'), + ('ask_price_4', 'd'), + ('ask_price_5', 'd'), ('ask_price_6', 'd'), ('ask_price_7', 'd'), ('ask_price_8', 'd'), + ('ask_price_9', 'd'), + ('bid_qty_0', 'd'), ('bid_qty_1', 'd'), ('bid_qty_2', 'd'), ('bid_qty_3', 'd'), + ('bid_qty_4', 'd'), + ('bid_qty_5', 'd'), ('bid_qty_6', 'd'), ('bid_qty_7', 'd'), ('bid_qty_8', 'd'), + ('bid_qty_9', 'd'), + ('ask_qty_0', 'd'), ('ask_qty_1', 'd'), ('ask_qty_2', 'd'), ('ask_qty_3', 'd'), + ('ask_qty_4', 'd'), + ('ask_qty_5', 'd'), ('ask_qty_6', 'd'), ('ask_qty_7', 'd'), ('ask_qty_8', 'd'), + ('ask_qty_9', 'd')]) + +NpTypeTrans = np.dtype( + [('exchg', 'S16'), ('code', 'S32'), ('trading_date', 'u4'), ('action_date', 'u4'), ('action_time', 'u4'), + ('reserve1', 'u4'), ('index', 'u8'), ('ttype', 'i4'), ('side', 'i4'), ('price', 'd'), ('volume', 'u4'), + ('reserve2', 'u4'), + ('askorder', 'i8'), ('bidorder', 'i8')]) + +NpTypeOrdQue = np.dtype( + [('exchg', 'S16'), ('code', 'S32'), ('trading_date', 'u4'), ('action_date', 'u4'), ('action_time', 'u4'), + ('side', 'u4'), ('price', 'd'), ('order_items', 'u4'), ('qsize', 'i8'), ('volumes', 'u4', 50)]) + +NpTypeOrdDtl = np.dtype( + [('exchg', 'S16'), ('code', 'S32'), ('trading_date', 'u4'), ('action_date', 'u4'), ('action_time', 'u4'), + ('reserve1', 'u4'), ('index', 'u8'), ('price', 'd'), ('volume', 'u4'), ('side', 'u4'), ('otype', 'u4'), + ('reserve2', 'u4')]) + class WtNpKline: - ''' + """ 基于numpy.ndarray的K线数据容器 提供一些常用的属性和方法 - ''' - __type__:np.dtype = NpTypeBar - def __init__(self, isDay:bool = False, forceCopy:bool = False): - ''' + """ + __type__: np.dtype = NpTypeBar + + def __init__(self, isDay: bool = False, forceCopy: bool = False): + """ 基于numpy.ndarray的K线数据容器 @isDay 是否是日线数据, 主要用于控制bartimes的生成机制 @forceCopy 是否强制拷贝, 如果为True, 则会拷贝一份数据, 否则会直接引用内存中的数据 强制拷贝主要用于WtDtHelper的read_dsb_bars和read_dmb_bars接口, 因为这两个接口返回的数据是临时的, 调用结束就会释放 - ''' - self.__data__:np.ndarray = None - self.__isDay__:bool = isDay - self.__force_copy__:bool = forceCopy - self.__bartimes__:np.ndarray = None - self.__df__:pd.DataFrame = None + """ + self.__data__: np.ndarray = None + self.__isDay__: bool = isDay + self.__force_copy__: bool = forceCopy + self.__bartimes__: np.ndarray = None + self.__df__: pd.DataFrame = None def __len__(self): if self.__data__ is None: return 0 - + return len(self.__data__) - - def __getitem__(self, index:int): + + def __getitem__(self, index: int): if self.__data__ is None: raise IndexError("No data in WtNpKline") - + return self.__data__[index] - def set_day_flag(self, isDay:bool): + def set_day_flag(self, isDay: bool): if self.__isDay__ != isDay: self.__isDay__ = isDay self.__bartimes__ = None self.__df__ = None - def set_data(self, firstBar, count:int): - BarList = WTSBarStruct*count + def set_data(self, firstBar, count: int): + BarList = WTSBarStruct * count if self.__force_copy__: c_array = BarList.from_buffer_copy(BarList.from_address(addressof(firstBar.contents))) else: @@ -90,7 +108,7 @@ def set_data(self, firstBar, count:int): @property def ndarray(self) -> np.ndarray: return self.__data__ - + @property def opens(self) -> np.ndarray: return self.__data__["open"] @@ -113,61 +131,63 @@ def volumes(self) -> np.ndarray: @property def bartimes(self) -> np.ndarray: - ''' + """ 这里应该会构造一个副本, 可以暂存一个 - ''' + """ if self.__bartimes__ is None: if self.__isDay__: - self.__bartimes__ = self.__data__["date"] + self.__bartimes__ = self.__data__["date"] else: self.__bartimes__ = self.__data__["time"] + 199000000000 return self.__bartimes__ - - def get_bar(self, iLoc:int = -1) -> tuple: + + def get_bar(self, iLoc: int = -1) -> tuple: return self.__data__[iLoc] - + @property def is_day(self) -> bool: return self.__isDay__ - + def to_df(self) -> pd.DataFrame: if self.__df__ is None: self.__df__ = pd.DataFrame(self.__data__, index=self.bartimes) self.__df__.drop(columns=["time", "reserve"], inplace=True) self.__df__["bartime"] = self.__df__.index return self.__df__ - + + class WtNpTicks: - ''' + """ 基于numpy.ndarray的tick数据容器 提供一些常用的属性和方法 - ''' - __type__:np.dtype = NpTypeTick - def __init__(self, forceCopy:bool = False): - ''' + """ + __type__: np.dtype = NpTypeTick + + def __init__(self, forceCopy: bool = False): + """ 基于numpy.ndarray的tick数据容器 @forceCopy 是否强制拷贝, 如果为True, 则会拷贝一份数据, 否则会直接引用内存中的数据 强制拷贝主要用于WtDtHelper的read_dsb_ticks和read_dmb_ticks接口, 因为这两个接口返回的数据是临时的, 调用结束就会释放 - ''' - self.__data__:np.ndarray = None - self.__times__:np.ndarray = None - self.__force_copy__:bool = forceCopy - self.__df__:pd.DataFrame = None + """ + self.__data__: np.ndarray = None + self.__times__: np.ndarray = None + self.__force_copy__: bool = forceCopy + self.__df__: pd.DataFrame = None def __len__(self): if self.__data__ is None: return 0 - + return len(self.__data__) - - def __getitem__(self, index:int): + + def __getitem__(self, index: int): if self.__data__ is None: raise IndexError("No data in WtNpTicks") - + return self.__data__[index] - def set_data(self, firstTick, count:int): - BarList = WTSTickStruct*count + def set_data(self, firstTick, count: int): + BarList = WTSTickStruct * count if self.__force_copy__: c_array = BarList.from_buffer_copy(BarList.from_address(addressof(firstTick.contents))) else: @@ -185,14 +205,13 @@ def set_data(self, firstTick, count:int): @property def times(self) -> np.ndarray: - ''' + """ 这里应该会构造一个副本, 可以暂存一个 - ''' + """ if self.__times__ is None: - self.__times__ = np.uint64(self.__data__["action_date"])*1000000000 + self.__data__["action_time"] + self.__times__ = np.uint64(self.__data__["action_date"]) * 1000000000 + self.__data__["action_time"] return self.__times__ - def to_df(self) -> pd.DataFrame: if self.__df__ is None: self.__df__ = pd.DataFrame(self.__data__, index=self.times) @@ -203,41 +222,43 @@ def to_df(self) -> pd.DataFrame: @property def ndarray(self) -> np.ndarray: return self.__data__ - + + class WtNpTransactions: - ''' + """ 基于numpy.ndarray的逐笔成交数据容器 提供一些常用的属性和方法 - ''' - __type__:np.dtype = NpTypeTrans - def __init__(self, forceCopy:bool = False): - ''' + """ + __type__: np.dtype = NpTypeTrans + + def __init__(self, forceCopy: bool = False): + """ 基于numpy.ndarray的逐笔成交数据容器 @forceCopy 是否强制拷贝, 如果为True, 则会拷贝一份数据, 否则会直接引用内存中的数据 强制拷贝主要用于WtDtHelper的read_dsb_trans和read_dmb_trans接口, 因为这两个接口返回的数据是临时的, 调用结束就会释放 - ''' - self.__data__:np.ndarray = None - self.__force_copy__:bool = forceCopy + """ + self.__data__: np.ndarray = None + self.__force_copy__: bool = forceCopy def __len__(self): if self.__data__ is None: return 0 - + return len(self.__data__) - - def __getitem__(self, index:int): + + def __getitem__(self, index: int): if self.__data__ is None: raise IndexError("No data in WtNpTransactions") - + return self.__data__[index] - def set_data(self, firstItem, count:int): - DataList = WTSTransStruct*count + def set_data(self, firstItem, count: int): + DataList = WTSTransStruct * count if self.__force_copy__: c_array = DataList.from_buffer_copy(DataList.from_address(addressof(firstItem.contents))) else: c_array = DataList.from_buffer(DataList.from_address(addressof(firstItem.contents))) - + npAy = np.frombuffer(c_array, dtype=self.__type__, count=count) # 这里有点不高效,需要拼接的地方主要是WtDtServo的场景,这里慢点没关系 # 一旦触发拼接逻辑,都会拷贝一次 @@ -251,36 +272,38 @@ def set_data(self, firstItem, count:int): @property def ndarray(self) -> np.ndarray: return self.__data__ - + + class WtNpOrdDetails: - ''' + """ 基于numpy.ndarray的逐笔委托数据容器 提供一些常用的属性和方法 - ''' - __type__:np.dtype = NpTypeOrdDtl - def __init__(self, forceCopy:bool = False): - ''' + """ + __type__: np.dtype = NpTypeOrdDtl + + def __init__(self, forceCopy: bool = False): + """ 基于numpy.ndarray的逐笔委托数据容器 @forceCopy 是否强制拷贝, 如果为True, 则会拷贝一份数据, 否则会直接引用内存中的数据 强制拷贝主要用于WtDtHelper的read_dsb_trans和read_dmb_trans接口, 因为这两个接口返回的数据是临时的, 调用结束就会释放 - ''' - self.__data__:np.ndarray = None - self.__force_copy__:bool = forceCopy + """ + self.__data__: np.ndarray = None + self.__force_copy__: bool = forceCopy def __len__(self): if self.__data__ is None: return 0 - + return len(self.__data__) - - def __getitem__(self, index:int): + + def __getitem__(self, index: int): if self.__data__ is None: raise IndexError("No data in WtNpOrdDetails") - + return self.__data__[index] - def set_data(self, firstItem, count:int): - DataList = WTSOrdDtlStruct*count + def set_data(self, firstItem, count: int): + DataList = WTSOrdDtlStruct * count if self.__force_copy__: c_array = DataList.from_buffer_copy(DataList.from_address(addressof(firstItem.contents))) else: @@ -299,36 +322,38 @@ def set_data(self, firstItem, count:int): @property def ndarray(self) -> np.ndarray: return self.__data__ - + + class WtNpOrdQueues: - ''' + """ 基于numpy.ndarray的委托队列数据容器 提供一些常用的属性和方法 - ''' - __type__:np.dtype = NpTypeOrdQue - def __init__(self, forceCopy:bool = False): - ''' + """ + __type__: np.dtype = NpTypeOrdQue + + def __init__(self, forceCopy: bool = False): + """ 基于numpy.ndarray的委托队列数据容器 @forceCopy 是否强制拷贝, 如果为True, 则会拷贝一份数据, 否则会直接引用内存中的数据 强制拷贝主要用于WtDtHelper的read_dsb_trans和read_dmb_trans接口, 因为这两个接口返回的数据是临时的, 调用结束就会释放 - ''' - self.__data__:np.ndarray = None - self.__force_copy__:bool = forceCopy + """ + self.__data__: np.ndarray = None + self.__force_copy__: bool = forceCopy def __len__(self): if self.__data__ is None: return 0 - + return len(self.__data__) - - def __getitem__(self, index:int): + + def __getitem__(self, index: int): if self.__data__ is None: raise IndexError("No data in WtNpOrdQueues") - + return self.__data__[index] - def set_data(self, firstItem, count:int): - DataList = WTSOrdQueStruct*count + def set_data(self, firstItem, count: int): + DataList = WTSOrdQueStruct * count if self.__force_copy__: c_array = DataList.from_buffer_copy(DataList.from_address(addressof(firstItem.contents))) else: @@ -347,42 +372,44 @@ def set_data(self, firstItem, count:int): @property def ndarray(self) -> np.ndarray: return self.__data__ - + + class WtBarCache: - def __init__(self, isDay:bool = False, forceCopy:bool = False): - self.records:WtNpKline = None + def __init__(self, isDay: bool = False, forceCopy: bool = False): + self.records: WtNpKline = None self.__is_day__ = isDay self.__force_copy__ = forceCopy self.__total_count__ = 0 - def on_read_bar(self, firstItem:POINTER(WTSBarStruct), count:int, isLast:bool): + def on_read_bar(self, firstItem: POINTER(WTSBarStruct), count: int, isLast: bool): if self.records is None: self.records = WtNpKline(isDay=self.__is_day__, forceCopy=self.__force_copy__) # 多次set_data,会在内部自动concatenate self.records.set_data(firstItem, count) - def on_data_count(self, count:int): + def on_data_count(self, count: int): # 其实这里最好的处理方式是能够直接将底层的内存块拷贝,拼接成一块大的内存块 # 但是暂时没想好怎么处理,所以只能多次set_data了,会损失一些性能,但是比以前快 self.__total_count__ = count pass + class WtTickCache: - def __init__(self, forceCopy:bool = False): - self.records:WtNpTicks = None + def __init__(self, forceCopy: bool = False): + self.records: WtNpTicks = None self.__force_copy__ = forceCopy self.__total_count__ = 0 - def on_read_tick(self, firstItem:POINTER(WTSTickStruct), count:int, isLast:bool): + def on_read_tick(self, firstItem: POINTER(WTSTickStruct), count: int, isLast: bool): if self.records is None: self.records = WtNpTicks(forceCopy=self.__force_copy__) # 多次set_data,会在内部自动concatenate self.records.set_data(firstItem, count) - def on_data_count(self, count:int): + def on_data_count(self, count: int): # 其实这里最好的处理方式是能够直接将底层的内存块拷贝,拼接成一块大的内存块 # 但是暂时没想好怎么处理,所以只能多次set_data了,会损失一些性能,但是比以前快 self.__total_count__ = count - pass \ No newline at end of file + pass diff --git a/wtpy/WtDtEngine.py b/wtpy/WtDtEngine.py index 53fe153e..824e8421 100644 --- a/wtpy/WtDtEngine.py +++ b/wtpy/WtDtEngine.py @@ -3,80 +3,82 @@ from wtpy.WtUtilDefs import singleton import json + @singleton class WtDtEngine: def __init__(self): - self.__wrapper__ = WtDtWrapper(self) #api接口转换器 - self.__ext_parsers__ = dict() #外接的行情接入模块 - self.__ext_dumpers__ = dict() #扩展数据Dumper + self.__wrapper__ = WtDtWrapper(self) # api接口转换器 + self.__ext_parsers__ = dict() # 外接的行情接入模块 + self.__ext_dumpers__ = dict() # 扩展数据Dumper - def initialize(self, cfgfile:str = "dtcfg.yaml", logprofile:str = "logcfgdt.yaml", bCfgFile:bool = True, bLogCfgFile:bool = True): - ''' + def initialize(self, cfgfile: str = "dtcfg.yaml", logprofile: str = "logcfgdt.yaml", bCfgFile: bool = True, + bLogCfgFile: bool = True): + """ 数据引擎初始化 @cfgfile 配置文件 @logprofile 日志模块配置文件 - ''' + """ self.__wrapper__.initialize(cfgfile, logprofile, bCfgFile, bLogCfgFile) - def init_with_config(self, cfgfile:dict, logprofile:dict): - ''' + def init_with_config(self, cfgfile: dict, logprofile: dict): + """ 数据引擎初始化 @cfgfile 配置 @logprofile 日志模块配置 - ''' + """ self.__wrapper__.initialize(json.dumps(cfgfile), json.dumps(logprofile), False, False) - - def run(self, bAsync:bool = False): - ''' + + def run(self, bAsync: bool = False): + """ 运行数据引擎 @bAsync 是否异步,异步则立即返回,默认False - ''' + """ self.__wrapper__.run_datakit(bAsync) - def add_exetended_parser(self, parser:BaseExtParser): - ''' + def add_exetended_parser(self, parser: BaseExtParser): + """ 添加扩展parser - ''' + """ id = parser.id() if id not in self.__ext_parsers__: self.__ext_parsers__[id] = parser if not self.__wrapper__.create_extended_parser(id): self.__ext_parsers__.pop(id) - def get_extended_parser(self, id:str)->BaseExtParser: - ''' + def get_extended_parser(self, id: str) -> BaseExtParser: + """ 根据id获取扩展parser - ''' + """ if id not in self.__ext_parsers__: return None return self.__ext_parsers__[id] - def push_quote_from_extended_parser(self, id:str, newTick, uProcFlag:int): - ''' + def push_quote_from_extended_parser(self, id: str, newTick, uProcFlag: int): + """ 向底层推送tick数据 @id parserid @newTick POINTER(WTSTickStruct) @uProcFlag 预处理标记,0-不处理,1-切片,2-累加 - ''' + """ self.__wrapper__.push_quote_from_exetended_parser(id, newTick, uProcFlag) - def add_extended_data_dumper(self, dumper:BaseExtDataDumper): - ''' + def add_extended_data_dumper(self, dumper: BaseExtDataDumper): + """ 添加扩展dumper - ''' + """ id = dumper.id() if id not in self.__ext_dumpers__: self.__ext_dumpers__[id] = dumper if not self.__wrapper__.create_extended_dumper(id): self.__ext_dumpers__.pop(id) self.__wrapper__.register_extended_data_dumper() - - def get_extended_data_dumper(self, id:str) -> BaseExtDataDumper: - ''' + + def get_extended_data_dumper(self, id: str) -> BaseExtDataDumper: + """ 根据id获取扩展dumper - ''' + """ if id not in self.__ext_dumpers__: return None - return self.__ext_dumpers__[id] \ No newline at end of file + return self.__ext_dumpers__[id] diff --git a/wtpy/WtDtServo.py b/wtpy/WtDtServo.py index 13f7f5f3..cc6681f8 100644 --- a/wtpy/WtDtServo.py +++ b/wtpy/WtDtServo.py @@ -4,21 +4,22 @@ import json import os + @singleton class WtDtServo: # 构造函数, 传入动态库名 - def __init__(self, logcfg:str="logcfg.yaml"): + def __init__(self, logcfg: str = "logcfg.yaml"): self.__config__ = None self.__cfg_commited__ = False self.local_api = None self.logCfg = logcfg def __check_config__(self): - ''' + """ 检查设置项 主要会补充一些默认设置项 - ''' + """ if self.local_api is None: self.local_api = WtDtServoApi() @@ -30,14 +31,16 @@ def __check_config__(self): if "data" not in self.__config__: self.__config__["data"] = { - "store":{ - "path":"./storage/" + "store": { + "path": "./storage/" } } - def setBasefiles(self, folder:str="./common/", commfile:str="commodities.json", contractfile:str="contracts.json", - holidayfile:str="holidays.json", sessionfile:str="sessions.json", hotfile:str="hots.json"): - ''' + def setBasefiles(self, folder: str = "./common/", commfile: str = "commodities.json", + contractfile: str = "contracts.json", + holidayfile: str = "holidays.json", sessionfile: str = "sessions.json", + hotfile: str = "hots.json"): + """ 设置基础文件 @folder 基础文件目录 @commfile 品种文件, str/list @@ -45,20 +48,20 @@ def setBasefiles(self, folder:str="./common/", commfile:str="commodities.json", @holidayfile 节假日文件 @sessionfile 交易时间模板文件 @hotfile 主力合约配置文件 - ''' + """ self.__check_config__() - if type(commfile) == str: + if isinstance(commfile, str): self.__config__["basefiles"]["commodity"] = os.path.join(folder, commfile) - elif type(commfile) == list: + elif isinstance(commfile, list): absList = [] for filename in commfile: absList.append(os.path.join(folder, filename)) self.__config__["basefiles"]["commodity"] = ','.join(absList) - if type(contractfile) == str: + if isinstance(contractfile, str): self.__config__["basefiles"]["contract"] = os.path.join(folder, contractfile) - elif type(contractfile) == list: + elif isinstance(contractfile, list): absList = [] for filename in contractfile: absList.append(os.path.join(folder, filename)) @@ -68,11 +71,11 @@ def setBasefiles(self, folder:str="./common/", commfile:str="commodities.json", self.__config__["basefiles"]["session"] = os.path.join(folder, sessionfile) self.__config__["basefiles"]["hot"] = os.path.join(folder, hotfile) - def setStorage(self, path:str = "./storage/", adjfactor:str = "adjfactors.json"): + def setStorage(self, path: str = "./storage/", adjfactor: str = "adjfactors.json"): self.__config__["data"]["store"]["path"] = path self.__config__["data"]["store"]["adjfactor"] = adjfactor - - def commitConfig(self): + + def commitConfig(self): if self.__cfg_commited__: return @@ -84,33 +87,35 @@ def commitConfig(self): print(oe) def clear_cache(self): - ''' + """ 清除缓存数据 - ''' + """ self.local_api.clear_cache() - def get_bars(self, stdCode:str, period:str, fromTime:int = None, dataCount:int = None, endTime:int = 0) -> WtNpKline: - ''' + def get_bars(self, stdCode: str, period: str, fromTime: int = None, dataCount: int = None, + endTime: int = 0) -> WtNpKline: + """ 获取K线数据 @stdCode 标准合约代码 @period 基础K线周期, m1/m5/d @fromTime 开始时间, 日线数据格式yyyymmdd, 分钟线数据为格式为yyyymmddHHMM @endTime 结束时间, 日线数据格式yyyymmdd, 分钟线数据为格式为yyyymmddHHMM, 为0则读取到最后一条 - ''' + """ self.commitConfig() if (fromTime is None and dataCount is None) or (fromTime is not None and dataCount is not None): raise Exception('Only one of fromTime and dataCount must be valid at the same time') - return self.local_api.get_bars(stdCode=stdCode, period=period, fromTime=fromTime, dataCount=dataCount, endTime=endTime) + return self.local_api.get_bars(stdCode=stdCode, period=period, fromTime=fromTime, dataCount=dataCount, + endTime=endTime) - def get_ticks(self, stdCode:str, fromTime:int = None, dataCount:int = None, endTime:int = 0) -> WtNpTicks: - ''' + def get_ticks(self, stdCode: str, fromTime: int = None, dataCount: int = None, endTime: int = 0) -> WtNpTicks: + """ 获取tick数据 @stdCode 标准合约代码 @fromTime 开始时间, 格式为yyyymmddHHMM @endTime 结束时间, 格式为yyyymmddHHMM, 为0则读取到最后一条 - ''' + """ self.commitConfig() if (fromTime is None and dataCount is None) or (fromTime is not None and dataCount is not None): @@ -118,34 +123,34 @@ def get_ticks(self, stdCode:str, fromTime:int = None, dataCount:int = None, endT return self.local_api.get_ticks(stdCode=stdCode, fromTime=fromTime, dataCount=dataCount, endTime=endTime) - def get_ticks_by_date(self, stdCode:str, iDate:int) -> WtNpTicks: - ''' + def get_ticks_by_date(self, stdCode: str, iDate: int) -> WtNpTicks: + """ 按日期获取tick数据 @stdCode 标准合约代码 @iDate 日期, 格式为yyyymmdd - ''' + """ self.commitConfig() return self.local_api.get_ticks_by_date(stdCode=stdCode, iDate=iDate) - def get_sbars_by_date(self, stdCode:str, iSec:int, iDate:int) -> WtNpKline: - ''' + def get_sbars_by_date(self, stdCode: str, iSec: int, iDate: int) -> WtNpKline: + """ 按日期获取秒线数据 @stdCode 标准合约代码 @iSec 周期, 单位s @iDate 日期, 格式为yyyymmdd - ''' + """ self.commitConfig() return self.local_api.get_sbars_by_date(stdCode=stdCode, iSec=iSec, iDate=iDate) - def get_bars_by_date(self, stdCode:str, period:str, iDate:int) -> WtNpKline: - ''' + def get_bars_by_date(self, stdCode: str, period: str, iDate: int) -> WtNpKline: + """ 按日期获取K线数据 @stdCode 标准合约代码 @period 周期,只支持分钟线 @iDate 日期, 格式为yyyymmdd - ''' + """ self.commitConfig() - return self.local_api.get_bars_by_date(stdCode=stdCode, period=period, iDate=iDate) \ No newline at end of file + return self.local_api.get_bars_by_date(stdCode=stdCode, period=period, iDate=iDate) diff --git a/wtpy/WtEngine.py b/wtpy/WtEngine.py index 9c4c3425..11e5c883 100644 --- a/wtpy/WtEngine.py +++ b/wtpy/WtEngine.py @@ -18,43 +18,45 @@ import chardet import os + @singleton class WtEngine: - ''' + """ 实盘交易引擎 - ''' + """ - def __init__(self, eType:EngineType, logCfg:str = "logcfg.yaml", genDir:str = "generated", bDumpCfg:bool = False): - ''' + def __init__(self, eType: EngineType, logCfg: str = "logcfg.yaml", genDir: str = "generated", + bDumpCfg: bool = False): + """ WtEngine构造函数 @eType 引擎类型: EngineType.ET_CTA、EngineType.ET_HFT、EngineType.ET_SEL @logCfg 日志配置文件 @genDir 数据输出目录 @bDumpCfg 是否保存最终配置文件 - ''' + """ self.is_backtest = False - self.__wrapper__:WtWrapper = WtWrapper(self) #api接口转换器 - self.__cta_ctxs__ = dict() #CTA策略ctx映射表 - self.__sel_ctxs__ = dict() #SEL策略ctx映射表 - self.__hft_ctxs__ = dict() #HFT策略ctx映射表 - self.__config__ = dict() #框架配置项 - self.__cfg_commited__ = False #配置是否已提交 + self.__wrapper__: WtWrapper = WtWrapper(self) # api接口转换器 + self.__cta_ctxs__ = dict() # CTA策略ctx映射表 + self.__sel_ctxs__ = dict() # SEL策略ctx映射表 + self.__hft_ctxs__ = dict() # HFT策略ctx映射表 + self.__config__ = dict() # 框架配置项 + self.__cfg_commited__ = False # 配置是否已提交 - self.__writer__:BaseIndexWriter = None #指标输出模块 - self.__reporter__:BaseDataReporter = None #数据提交模块 + self.__writer__: BaseIndexWriter = None # 指标输出模块 + self.__reporter__: BaseDataReporter = None # 数据提交模块 - self.__ext_data_loader__:BaseExtDataLoader = None #扩展历史数据加载器 + self.__ext_data_loader__: BaseExtDataLoader = None # 扩展历史数据加载器 - self.__ext_parsers__ = dict() #外接的行情接入模块 - self.__ext_executers__ = dict() #外接的执行器 + self.__ext_parsers__ = dict() # 外接的行情接入模块 + self.__ext_executers__ = dict() # 外接的执行器 - self.__dump_config__ = bDumpCfg #是否保存最终配置 + self.__dump_config__ = bDumpCfg # 是否保存最终配置 self.__is_cfg_yaml__ = True - self.trading_day = 0 #当前交易日 + self.trading_day = 0 # 当前交易日 - self.__engine_type:EngineType = eType + self.__engine_type: EngineType = eType if eType == EngineType.ET_CTA: self.__wrapper__.initialize_cta(logCfg=logCfg, isFile=True, genDir=genDir) elif eType == EngineType.ET_HFT: @@ -63,10 +65,10 @@ def __init__(self, eType:EngineType, logCfg:str = "logcfg.yaml", genDir:str = "g self.__wrapper__.initialize_sel(logCfg=logCfg, isFile=True, genDir=genDir) def __check_config__(self): - ''' + """ 检查设置项 主要会补充一些默认设置项 - ''' + """ if "basefiles" not in self.__config__: self.__config__["basefiles"] = dict() @@ -75,83 +77,83 @@ def __check_config__(self): self.__config__["env"]["name"] = "cta" self.__config__["env"]["mode"] = "product" self.__config__["env"]["product"] = { - "session":"TRADING" + "session": "TRADING" } - + def get_engine_type(self) -> EngineType: return self.__engine_type - def set_extended_data_loader(self, loader:BaseExtDataLoader): + def set_extended_data_loader(self, loader: BaseExtDataLoader): self.__ext_data_loader__ = loader self.__wrapper__.register_extended_data_loader() def get_extended_data_loader(self) -> BaseExtDataLoader: return self.__ext_data_loader__ - def add_exetended_parser(self, parser:BaseExtParser): + def add_exetended_parser(self, parser: BaseExtParser): id = parser.id() if id not in self.__ext_parsers__: if self.__wrapper__.create_extended_parser(id): self.__ext_parsers__[id] = parser - def add_exetended_executer(self, executer:BaseExtExecuter): + def add_exetended_executer(self, executer: BaseExtExecuter): id = executer.id() if id not in self.__ext_executers__: if self.__wrapper__.create_extended_executer(id): self.__ext_executers__[id] = executer - def get_extended_parser(self, id:str)->BaseExtParser: + def get_extended_parser(self, id: str) -> BaseExtParser: if id not in self.__ext_parsers__: return None return self.__ext_parsers__[id] - def get_extended_executer(self, id:str)->BaseExtExecuter: + def get_extended_executer(self, id: str) -> BaseExtExecuter: if id not in self.__ext_executers__: return None return self.__ext_executers__[id] - def push_quote_from_extended_parser(self, id:str, newTick, uProcFlag:int): - ''' + def push_quote_from_extended_parser(self, id: str, newTick, uProcFlag: int): + """ 向底层推送tick数据 @id parserid @newTick POINTER(WTSTickStruct) @uProcFlag 预处理标记,0-不处理,1-切片,2-累加 - ''' + """ self.__wrapper__.push_quote_from_exetended_parser(id, newTick, uProcFlag) - def set_writer(self, writer:BaseIndexWriter): - ''' + def set_writer(self, writer: BaseIndexWriter): + """ 设置指标输出模块 - ''' + """ self.__writer__ = writer - def write_indicator(self, id:str, tag:str, time:int, data:dict): - ''' + def write_indicator(self, id: str, tag: str, time: int, data: dict): + """ 写入指标数据 - ''' + """ if self.__writer__ is not None: self.__writer__.write_indicator(id, tag, time, data) - def set_data_reporter(self, reporter:BaseDataReporter): - ''' + def set_data_reporter(self, reporter: BaseDataReporter): + """ 设置数据报告器 - ''' + """ self.__reporter__ = reporter - def init(self, folder:str, - cfgfile:str = "config.yaml", - contractfile:str = None, - sessionfile:str = None, - commfile:str = None, - holidayfile:str = None, - hotfile:str = None, - secondfile:str = None): - ''' + def init(self, folder: str, + cfgfile: str = "config.yaml", + contractfile: str = None, + sessionfile: str = None, + commfile: str = None, + holidayfile: str = None, + hotfile: str = None, + secondfile: str = None): + """ 初始化 @folder 基础数据文件目录,\\结尾 @cfgfile 配置文件,json格式 - ''' + """ f = open(cfgfile, "rb") content = f.read() f.close() @@ -167,9 +169,9 @@ def init(self, folder:str, self.__check_config__() - if contractfile is not None: + if contractfile is not None: self.__config__["basefiles"]["contract"] = os.path.join(folder, contractfile) - + if sessionfile is not None: self.__config__["basefiles"]["session"] = os.path.join(folder, sessionfile) @@ -203,17 +205,17 @@ def init(self, folder:str, self.sessionMgr = SessionMgr() self.sessionMgr.load(self.__config__["basefiles"]["session"]) - def configEngine(self, name:str, mode:str = "product"): - ''' + def configEngine(self, name: str, mode: str = "product"): + """ 设置引擎和运行模式 - ''' + """ self.__config__["env"]["name"] = name self.__config__["env"]["mode"] = mode - def addExternalCtaStrategy(self, id:str, params:dict): - ''' + def addExternalCtaStrategy(self, id: str, params: dict): + """ 添加外部的CTA策略 - ''' + """ if "strategies" not in self.__config__: self.__config__["strategies"] = dict() @@ -223,10 +225,10 @@ def addExternalCtaStrategy(self, id:str, params:dict): params["id"] = id self.__config__["strategies"]["cta"].append(params) - def addExternalHftStrategy(self, id:str, params:dict): - ''' + def addExternalHftStrategy(self, id: str, params: dict): + """ 添加外部的HFT策略 - ''' + """ if "strategies" not in self.__config__: self.__config__["strategies"] = dict() @@ -236,31 +238,31 @@ def addExternalHftStrategy(self, id:str, params:dict): params["id"] = id self.__config__["strategies"]["hft"].append(params) - def configStorage(self, path:str, module:str=""): - ''' + def configStorage(self, path: str, module: str = ""): + """ 配置数据存储 @mode 存储模式,csv-表示从csv直接读取,一般回测使用,wtp-表示使用wt框架自带数据存储 - ''' + """ self.__config__["data"]["store"]["module"] = module self.__config__["data"]["store"]["path"] = path - def registerCustomRule(self, ruleTag:str, filename:str): - ''' + def registerCustomRule(self, ruleTag: str, filename: str): + """ 注册自定义连续合约规则 @ruleTag 规则标签,如ruleTag为THIS,对应的连续合约代码为CFFEX.IF.THIS @filename 规则定义文件名,和hots.json格式一样 - ''' + """ if "rules" not in self.__config__["basefiles"]: self.__config__["basefiles"]["rules"] = dict() self.__config__["basefiles"]["rules"][ruleTag] = filename def commitConfig(self): - ''' + """ 提交配置 只有第一次调用会生效,不可重复调用 如果执行run之前没有调用,run会自动调用该方法 - ''' + """ if self.__cfg_commited__: return @@ -278,45 +280,45 @@ def commitConfig(self): f.write(cfgfile) f.close() - def regCtaStraFactories(self, factFolder:str): - ''' + def regCtaStraFactories(self, factFolder: str): + """ 向底层模块注册CTA工厂模块目录 !!!CTA策略只会被CTA引擎加载!!! @factFolder 工厂模块所在的目录 - ''' + """ return self.__wrapper__.reg_cta_factories(factFolder) - def regHftStraFactories(self, factFolder:str): - ''' + def regHftStraFactories(self, factFolder: str): + """ 向底层模块注册HFT工厂模块目录 !!!HFT策略只会被HFT引擎加载!!! @factFolder 工厂模块所在的目录 - ''' + """ return self.__wrapper__.reg_hft_factories(factFolder) - def regExecuterFactories(self, factFolder:str): - ''' + def regExecuterFactories(self, factFolder: str): + """ 向底层模块注册执行器模块目录 !!!执行器只在CTA引擎有效!!! @factFolder 工厂模块所在的目录 - ''' + """ return self.__wrapper__.reg_exe_factories(factFolder) - def addExecuter(self, id:str, trader:str, policies:dict, scale:int = 1): + def addExecuter(self, id: str, trader: str, policies: dict, scale: int = 1): if "executers" not in self.__config__: self.__config__["executers"] = list() exeItem = { - "active":True, + "active": True, "id": id, "scale": scale, "policy": policies, - "trader":trader + "trader": trader } self.__config__["executers"].append(exeItem) - def addTrader(self, id:str, params:dict): + def addTrader(self, id: str, params: dict): if "traders" not in self.__config__: self.__config__["traders"] = list() @@ -326,11 +328,11 @@ def addTrader(self, id:str, params:dict): self.__config__["traders"].append(tItem) - def getSessionByCode(self, stdCode:str) -> SessionInfo: - ''' + def getSessionByCode(self, stdCode: str) -> SessionInfo: + """ 通过合约代码获取交易时间模板 @stdCode 合约代码,格式如SHFE.rb.HOT - ''' + """ pid = CodeHelper.stdCodeToStdCommID(stdCode) pInfo = self.productMgr.getProductInfo(pid) if pInfo is None: @@ -338,86 +340,88 @@ def getSessionByCode(self, stdCode:str) -> SessionInfo: return self.sessionMgr.getSession(pInfo.session) - def getSessionByName(self, sname:str) -> SessionInfo: - ''' + def getSessionByName(self, sname: str) -> SessionInfo: + """ 通过模板名获取交易时间模板 @sname 模板名 - ''' + """ return self.sessionMgr.getSession(sname) - def getProductInfo(self, stdCode:str) -> ProductInfo: - ''' + def getProductInfo(self, stdCode: str) -> ProductInfo: + """ 获取品种信息 @stdCode 合约代码,格式如SHFE.rb.HOT - ''' + """ return self.productMgr.getProductInfo(stdCode) - def getContractInfo(self, stdCode:str) -> ContractInfo: - ''' + def getContractInfo(self, stdCode: str) -> ContractInfo: + """ 获取品种信息 @stdCode 合约代码,格式如SHFE.rb.HOT - ''' + """ return self.contractMgr.getContractInfo(stdCode, self.trading_day) def getAllCodes(self) -> list: - ''' + """ 获取全部合约代码 - ''' + """ return self.contractMgr.getTotalCodes(self.trading_day) - - def getCodesByProduct(self, stdPID:str) -> list: - ''' + + def getCodesByProduct(self, stdPID: str) -> list: + """ 根据品种id获取对应合约代码 @stdPID 品种代码, 格式如SHFE.rb - ''' + """ return self.contractMgr.getCodesByProduct(stdPID, self.trading_day) - - def getCodesByUnderlying(self, underlying:str) -> list: - ''' + + def getCodesByUnderlying(self, underlying: str) -> list: + """ 根据underlying获取对应合约代码 @underlying 品种代码, 格式如SHFE.rb2305 - ''' + """ return self.contractMgr.getCodesByUnderlying(underlying, self.trading_day) - def getRawStdCode(self, stdCode:str): - ''' + def getRawStdCode(self, stdCode: str): + """ 根据连续合约代码获取原始合约代码 - ''' + """ return self.__wrapper__.get_raw_stdcode(stdCode) - def add_cta_strategy(self, strategy:BaseCtaStrategy, slippage:int = 0): - ''' + def add_cta_strategy(self, strategy: BaseCtaStrategy, slippage: int = 0): + """ 添加CTA策略 @strategy 策略对象 - ''' + """ id = self.__wrapper__.create_cta_context(strategy.name(), slippage) self.__cta_ctxs__[id] = CtaContext(id, strategy, self.__wrapper__, self) - def add_hft_strategy(self, strategy:BaseHftStrategy, trader:str, agent:bool = True, slippage:int = 0): - ''' + def add_hft_strategy(self, strategy: BaseHftStrategy, trader: str, agent: bool = True, slippage: int = 0): + """ 添加HFT策略 @strategy 策略对象 - ''' + """ id = self.__wrapper__.create_hft_context(strategy.name(), trader, agent, slippage) self.__hft_ctxs__[id] = HftContext(id, strategy, self.__wrapper__, self) - def add_sel_strategy(self, strategy:BaseSelStrategy, date:int, time:int, period:str, trdtpl:str="CHINA", session:str="TRADING", slippage:int = 0): - ''' + def add_sel_strategy(self, strategy: BaseSelStrategy, date: int, time: int, period: str, trdtpl: str = "CHINA", + session: str = "TRADING", slippage: int = 0): + """ 添加SEL策略 @ strategy SEL策略对象 @date 日期,根据周期变化,每日为0,每周为0~6,对应周日到周六,每月为1~31,每年为0101~1231 - @time 时间,精确到分钟 - @period 时间周期,可以是分钟min、天d、周w、月m、年y + @time 时间,精确到分钟 + @period 时间周期,可以是分钟min、天d、周w、月m、年y @slippage 滑点大小 - ''' - id = self.__wrapper__.create_sel_context(name=strategy.name(), date=date, time=time, period=period, trdtpl=trdtpl, session=session, slippage=slippage) + """ + id = self.__wrapper__.create_sel_context(name=strategy.name(), date=date, time=time, period=period, + trdtpl=trdtpl, session=session, slippage=slippage) self.__sel_ctxs__[id] = SelContext(id, strategy, self.__wrapper__, self) - def get_context(self, id:int): - ''' + def get_context(self, id: int): + """ 根据ID获取策略上下文 @id 上下文id,一般添加策略的时候会自动生成一个唯一的上下文id - ''' + """ if self.__engine_type == EngineType.ET_CTA: if id not in self.__cta_ctxs__: return None @@ -434,19 +438,19 @@ def get_context(self, id:int): return self.__sel_ctxs__[id] - def run(self, bAsync:bool = True): - ''' + def run(self, bAsync: bool = True): + """ 运行框架 - ''' - if not self.__cfg_commited__: #如果配置没有提交,则自动提交一下 + """ + if not self.__cfg_commited__: # 如果配置没有提交,则自动提交一下 self.commitConfig() self.__wrapper__.run(bAsync) def release(self): - ''' + """ 释放框架 - ''' + """ self.__wrapper__.release() def on_init(self): @@ -454,17 +458,17 @@ def on_init(self): self.__reporter__.report_init_data() return - def on_schedule(self, date:int, time:int, taskid:int = 0): + def on_schedule(self, date: int, time: int, taskid: int = 0): # print("engine scheduled") if self.__reporter__ is not None: self.__reporter__.report_rt_data() - def on_session_begin(self, date:int): + def on_session_begin(self, date: int): # print("session begin") self.trading_day = date return - def on_session_end(self, date:int): + def on_session_end(self, date: int): if self.__reporter__ is not None: self.__reporter__.report_settle_data() return diff --git a/wtpy/WtMsgQue.py b/wtpy/WtMsgQue.py index 42065937..abd363e6 100644 --- a/wtpy/WtMsgQue.py +++ b/wtpy/WtMsgQue.py @@ -2,27 +2,29 @@ from wtpy.WtUtilDefs import singleton from ctypes import c_char, POINTER + class WtMQServer: def __init__(self): self.id = None - def init(self, wrapper:WtMQWrapper, id:int): + def init(self, wrapper: WtMQWrapper, id: int): self.id = id self.wrapper = wrapper - def publish_message(self, topic:str, message:str): + def publish_message(self, topic: str, message: str): if self.id is None: raise Exception("MQServer not initialzied") self.wrapper.publish_message(self.id, topic, message) + class WtMQClient: def __init__(self): return - def init(self, wrapper:WtMQWrapper, id:int): + def init(self, wrapper: WtMQWrapper, id: int): self.id = id self.wrapper = wrapper @@ -32,18 +34,19 @@ def start(self): self.wrapper.start_client(self.id) - def subscribe(self, topic:str): + def subscribe(self, topic: str): if self.id is None: raise Exception("MQClient not initialzied") self.wrapper.subcribe_topic(self.id, topic) - def on_mq_message(self, topic:bytes, message, dataLen:int): + def on_mq_message(self, topic: bytes, message, dataLen: int): pass + @singleton class WtMsgQue: - def __init__(self, logger = None) -> None: + def __init__(self, logger=None) -> None: self._servers = dict() self._clients = dict() self._logger = logger @@ -51,20 +54,20 @@ def __init__(self, logger = None) -> None: self._cb_msg = CB_ON_MSG(self.on_mq_message) - def get_client(self, client_id:int) -> WtMQClient: + def get_client(self, client_id: int) -> WtMQClient: if client_id not in self._clients: return None - + return self._clients[client_id] - def on_mq_message(self, client_id:int, topic:bytes, message:POINTER(c_char), dataLen:int): + def on_mq_message(self, client_id: int, topic: bytes, message: POINTER(c_char), dataLen: int): client = self.get_client(client_id) if client is None: print(f"WtMsgQue: client {client_id} not found") return client.on_mq_message(topic, message, dataLen) - def add_mq_server(self, url:str, server:WtMQServer = None) -> WtMQServer: + def add_mq_server(self, url: str, server: WtMQServer = None) -> WtMQServer: id = self._wrapper.create_server(url) if server is None: server = WtMQServer() @@ -73,15 +76,15 @@ def add_mq_server(self, url:str, server:WtMQServer = None) -> WtMQServer: self._servers[id] = server return server - def destroy_mq_server(self, server:WtMQServer): + def destroy_mq_server(self, server: WtMQServer): id = server.id if id not in self._servers: return - + self._wrapper.destroy_server(id) self._servers.pop(id) - def add_mq_client(self, url:str, client:WtMQClient = None) -> WtMQClient: + def add_mq_client(self, url: str, client: WtMQClient = None) -> WtMQClient: id = self._wrapper.create_client(url, self._cb_msg) if client is None: client = WtMQClient() @@ -89,12 +92,10 @@ def add_mq_client(self, url:str, client:WtMQClient = None) -> WtMQClient: self._clients[id] = client return client - def destroy_mq_client(self, client:WtMQClient): + def destroy_mq_client(self, client: WtMQClient): id = client.id if id not in self._clients: return - + self._wrapper.destroy_client(id) self._clients.pop(id) - - \ No newline at end of file diff --git a/wtpy/WtUtilDefs.py b/wtpy/WtUtilDefs.py index 863fedb2..1904aaad 100644 --- a/wtpy/WtUtilDefs.py +++ b/wtpy/WtUtilDefs.py @@ -1,9 +1,11 @@ def singleton(cls): instances = {} - def getinstance(*args,**kwargs): + + def getinstance(*args, **kwargs): if cls not in instances: - instances[cls] = cls(*args,**kwargs) + instances[cls] = cls(*args, **kwargs) return instances[cls] + return getinstance @@ -12,4 +14,5 @@ def wrapper(*args, **kwargs): msg = f"Warning: {func.__name__} is deprecated." print(msg) return func(*args, **kwargs) + return wrapper diff --git a/wtpy/__init__.py b/wtpy/__init__.py index c9d1b39b..c2becf62 100644 --- a/wtpy/__init__.py +++ b/wtpy/__init__.py @@ -5,23 +5,23 @@ from .WtEngine import WtEngine from .WtBtEngine import WtBtEngine from .WtDtEngine import WtDtEngine -from .WtCoreDefs import WTSTickStruct,WTSBarStruct,EngineType +from .WtCoreDefs import WTSTickStruct, WTSBarStruct, EngineType from .ExtToolDefs import BaseDataReporter, BaseIndexWriter from .ExtModuleDefs import BaseExtExecuter, BaseExtParser from .WtMsgQue import WtMsgQue, WtMQClient, WtMQServer from .WtDtServo import WtDtServo from wtpy.wrapper.WtExecApi import WtExecApi -from wtpy.wrapper.ContractLoader import ContractLoader,LoaderType +from wtpy.wrapper.ContractLoader import ContractLoader, LoaderType from wtpy.wrapper.TraderDumper import TraderDumper, DumperSink -__all__ = ["BaseCtaStrategy", "BaseSelStrategy", "BaseHftStrategy", - "CtaContext", "SelContext", "HftContext", - "WtEngine", "WtBtEngine", "WtDtEngine", "EngineType", - "WtExecApi", "WtDtServo", - "WTSTickStruct","WTSBarStruct", - "BaseIndexWriter", "BaseDataReporter", - "ContractLoader", "LoaderType", - "BaseExtParser", "BaseExtExecuter", - "WtMsgQue", "WtMQClient", "WtMQServer", - "TraderDumper", "DumperSink"] \ No newline at end of file +__all__ = ["BaseCtaStrategy", "BaseSelStrategy", "BaseHftStrategy", + "CtaContext", "SelContext", "HftContext", + "WtEngine", "WtBtEngine", "WtDtEngine", "EngineType", + "WtExecApi", "WtDtServo", + "WTSTickStruct", "WTSBarStruct", + "BaseIndexWriter", "BaseDataReporter", + "ContractLoader", "LoaderType", + "BaseExtParser", "BaseExtExecuter", + "WtMsgQue", "WtMQClient", "WtMQServer", + "TraderDumper", "DumperSink"] diff --git a/wtpy/apps/WtBtAnalyst.py b/wtpy/apps/WtBtAnalyst.py index a9290320..e7ff61f1 100644 --- a/wtpy/apps/WtBtAnalyst.py +++ b/wtpy/apps/WtBtAnalyst.py @@ -10,11 +10,12 @@ from xlsxwriter import Workbook -class Calculate(): - ''' +class Calculate: + """ 绩效比率计算 - ''' - def __init__(self, ret, mar, rf, period, trade,capital,ret_day=[],trade_day=0,profit=0): + """ + + def __init__(self, ret, mar, rf, period, trade, capital, ret_day=[], trade_day=0, profit=0): """ :param ret: 收益率序列(单笔) :param mar: 最低可接受回报 @@ -40,8 +41,8 @@ def __init__(self, ret, mar, rf, period, trade,capital,ret_day=[],trade_day=0,pr def calculate_upside_ratio(self): upside = self.ret_day - self.daily_rf acess_return = upside[upside > 0].sum() / self.trade_day - downside_std = math.sqrt((upside[upside < 0] ** 2).sum()/self.trade_day) - if len(upside[upside < 0]) ==0: + downside_std = math.sqrt((upside[upside < 0] ** 2).sum() / self.trade_day) + if len(upside[upside < 0]) == 0: return 9999 upside_ratio = acess_return / downside_std return upside_ratio @@ -56,7 +57,7 @@ def sharp_ratio(self): # 索提诺比率 def sortion_ratio(self): expect_return = self.ret_day.mean() - downside = self.ret_day-self.daily_rf + downside = self.ret_day - self.daily_rf downside_std = downside[downside < 0].std() # downside_std = self.ret[self.ret.apply(lambda x: x < 0)].std() sortion_ratio = (expect_return - self.daily_rf) / downside_std * np.sqrt(self.period) @@ -99,7 +100,7 @@ def calmar_ratio(self): # 斯特林比率 def sterling_a_ratio(self): annual_return = Calculate.get_annual_return(self) - sterling_a_ratio = annual_return / abs(Calculate.maxDrawdown_ratio(self)- 0.1) + sterling_a_ratio = annual_return / abs(Calculate.maxDrawdown_ratio(self) - 0.1) return sterling_a_ratio # 单笔最大回撤 @@ -124,13 +125,14 @@ def single_maxdrawdown_time(self): # 年化收益率 def get_annual_return(self): - annual_return = 0 if self.trade_day==0 else (1+self.ret_day).cumprod()[len(self.ret_day)-1] ** (self.period / self.trade_day) - 1 + annual_return = 0 if self.trade_day == 0 else (1 + self.ret_day).cumprod()[len(self.ret_day) - 1] ** ( + self.period / self.trade_day) - 1 return annual_return # 月化收益率 def monthly_return(self): ann = Calculate.get_annual_return(self) - monthly_return = (ann + 1) ** (1/12) - 1 + monthly_return = (ann + 1) ** (1 / 12) - 1 return monthly_return # 月平均收益 @@ -140,7 +142,7 @@ def monthly_average_return(self): # 衰落时间 def decay_time(self): - netvalue = (self.ret+1).cumprod() + netvalue = (self.ret + 1).cumprod() ser = [] temp = netvalue[0] @@ -157,16 +159,18 @@ def decay_time(self): ss = max(pd.Series(ser)) return ss -def fmtNAN(val, defVal = 0): + +def fmtNAN(val, defVal=0): if math.isnan(val): return defVal return val + def continue_trading_analysis(data, x_value) -> dict: - ''' + """ 连续交易分析 - ''' + """ mean = data['profit'].mean() std = data['profit'].std() z_score = (x_value - mean) / std @@ -176,7 +180,7 @@ def continue_trading_analysis(data, x_value) -> dict: loss_time = 0 con_win_profit = [] con_lose_loss = [] - for i in range(len(data)-1): + for i in range(len(data) - 1): sss = data['profit'][i] if sss > 0: times += 1 @@ -227,16 +231,18 @@ def continue_trading_analysis(data, x_value) -> dict: return result + def nomalize_val(val): if math.isnan(val): return 0 else: return val + def extreme_trading(data, time_of_std=1): - ''' + """ 极端交易分析 - ''' + """ std = data['profit'].std() df_wins = data[data["profit"] > 0] df_wins_std = df_wins['profit'].std() @@ -263,26 +269,26 @@ def extreme_trading(data, time_of_std=1): extreme_num_win = len(extreme_result[extreme_result['profit'] > 0]) extreme_num_lose = len(extreme_result[extreme_result['profit'] < 0]) # 极端交易盈亏 1 Std. Deviation of Avg. Trade - extreme_profit = 0 if extreme_num==0 else extreme_result['profit'].sum() - extreme_profit_win = 0 if extreme_num_win ==0 else extreme_result[extreme_result['profit'] > 0]['profit'].sum() + extreme_profit = 0 if extreme_num == 0 else extreme_result['profit'].sum() + extreme_profit_win = 0 if extreme_num_win == 0 else extreme_result[extreme_result['profit'] > 0]['profit'].sum() extreme_profit_lose = 0 if extreme_num_lose == 0 else extreme_result[extreme_result['profit'] < 0]['profit'].sum() # 极端盈利交易计算 - result = {'总计':{ - '1 Std. Deviation of Avg. Trade': nomalize_val(std), - '单笔净利 +1倍标准差': nomalize_val(sin_profit_plstd), - '单笔净利 -1倍标准差': nomalize_val(sin_profit_mistd), - '极端交易数量': extreme_num, - '极端交易盈亏': extreme_profit - }, - '极端盈利':{ + result = {'总计': { + '1 Std. Deviation of Avg. Trade': nomalize_val(std), + '单笔净利 +1倍标准差': nomalize_val(sin_profit_plstd), + '单笔净利 -1倍标准差': nomalize_val(sin_profit_mistd), + '极端交易数量': extreme_num, + '极端交易盈亏': extreme_profit + }, + '极端盈利': { '1 Std. Deviation of Avg. Trade': nomalize_val(df_wins_std), '单笔净利 +1倍标准差': nomalize_val(sin_profit_plstd_win), '单笔净利 -1倍标准差': nomalize_val(sin_profit_mistd_win), '极端交易数量': extreme_num_win, '极端交易盈亏': extreme_profit_win }, - '极端亏损':{ + '极端亏损': { '1 Std. Deviation of Avg. Trade': nomalize_val(df_loses_std), '单笔净利 +1倍标准差': nomalize_val(sin_profit_plstd_lose), '单笔净利 -1倍标准差': nomalize_val(sin_profit_mistd_lose), @@ -297,15 +303,15 @@ def extreme_trading(data, time_of_std=1): def average_profit(data): - ''' + """ 连续交易分析之平均收益 - ''' + """ data = data['profit'] win = 0 li = [] lose = 0 li_2 = [] - dic= [] + dic = [] dicc = [] for i in range(1, len(data) - 1): if (data[i] > 0) == (data[i - 1] > 0): @@ -319,13 +325,13 @@ def average_profit(data): pass else: if data[i] > 0: - dis = {str(win): data[i-win:i+1].sum()} + dis = {str(win): data[i - win:i + 1].sum()} dic.append(dis) li.append(win) win = 0 else: - dis = {str(lose): data[i-lose:i+1].sum()} + dis = {str(lose): data[i - lose:i + 1].sum()} dicc.append(dis) li_2.append(lose) lose = 0 @@ -353,10 +359,11 @@ def average_profit(data): '每个序列平均亏损': lose_ss} return result -def stat_closes_by_day(df_closes:df, capital) -> df: - ''' + +def stat_closes_by_day(df_closes: df, capital) -> df: + """ 按天统计平仓数据 - ''' + """ df_closes['day'] = df_closes['opentime'] df_closes['win'] = df_closes['profit'].apply(lambda x: 1 if x > 0 else 0) df_closes['times'] = 1 @@ -364,14 +371,15 @@ def stat_closes_by_day(df_closes:df, capital) -> df: df_closes['gross_loss'] = df_closes['profit'].apply(lambda x: x if x < 0 else 0) profit = df_closes.groupby(df_closes['day'])[['win', 'times', 'profit', 'gross_profit', 'gross_loss']].sum() profit['win_rate'] = profit['win'] / profit['times'] - profit['profit_ratio'] = profit['profit']*100.0/capital + profit['profit_ratio'] = profit['profit'] * 100.0 / capital res = profit[['profit', 'gross_profit', 'gross_loss', 'times', 'win_rate', 'profit_ratio']] return res.iloc[::-1] -def stat_closes_by_month(df_closes:df, capital) -> df: - ''' + +def stat_closes_by_month(df_closes: df, capital) -> df: + """ 按月统计平仓数据 - ''' + """ df_closes['month'] = df_closes['opentime'].apply(lambda x: x.strftime("%Y/%m")) df_closes['win'] = df_closes['profit'].apply(lambda x: 1 if x > 0 else 0) df_closes['times'] = 1 @@ -379,14 +387,15 @@ def stat_closes_by_month(df_closes:df, capital) -> df: df_closes['gross_loss'] = df_closes['profit'].apply(lambda x: x if x < 0 else 0) profit = df_closes.groupby(df_closes['month'])[['win', 'times', 'profit', 'gross_profit', 'gross_loss']].sum() profit['win_rate'] = profit['win'] / profit['times'] - profit['profit_ratio'] = profit['profit']*100.0/capital + profit['profit_ratio'] = profit['profit'] * 100.0 / capital res = profit[['profit', 'gross_profit', 'gross_loss', 'times', 'win_rate', 'profit_ratio']] return res.iloc[::-1] -def stat_closes_by_year(df_closes:df, capital) -> df: - ''' + +def stat_closes_by_year(df_closes: df, capital) -> df: + """ 按年统计平仓数据 - ''' + """ df_closes['year'] = df_closes['opentime'].apply(lambda x: x.strftime("%Y")) df_closes['win'] = df_closes['profit'].apply(lambda x: 1 if x > 0 else 0) df_closes['times'] = 1 @@ -394,14 +403,15 @@ def stat_closes_by_year(df_closes:df, capital) -> df: df_closes['gross_loss'] = df_closes['profit'].apply(lambda x: x if x < 0 else 0) profit = df_closes.groupby(df_closes['year'])[['win', 'times', 'profit', 'gross_profit', 'gross_loss']].sum() profit['win_rate'] = profit['win'] / profit['times'] - profit['profit_ratio'] = profit['profit']*100.0/capital + profit['profit_ratio'] = profit['profit'] * 100.0 / capital res = profit[['profit', 'gross_profit', 'gross_loss', 'times', 'win_rate', 'profit_ratio']] return res.iloc[::-1] -def time_analysis(df_closes:df,df_funds:df) -> dict: - ''' + +def time_analysis(df_closes: df, df_funds: df) -> dict: + """ 时间分析 - ''' + """ trading_time = df_closes['closebarno'][len(df_closes) - 1] # 策略运行时间 @@ -421,7 +431,7 @@ def time_analysis(df_closes:df,df_funds:df) -> dict: rf = 0.02 period = 240 trade = input_data['closebarno'][len(input_data) - 1] / 47 - factors = Calculate(ret, mar, rf, period, trade,capital) + factors = Calculate(ret, mar, rf, period, trade, capital) single_drawdown_date = factors.single_maxdrawdown_time() signe_drawdown_date = parse(str(input_data['opentime'][single_drawdown_date])) @@ -443,7 +453,7 @@ def time_analysis(df_closes:df,df_funds:df) -> dict: result = {'交易周期': str(trading_time) + '根K线', '策略运行时间': str(str_time) + '根K线', - '策略运行时间%': str(round(porition,2)) + '%', + '策略运行时间%': str(round(porition, 2)) + '%', '最长空仓时间': str(empty_time) + '根K线', '策略最大回撤开始时间': start_time.strftime("%Y/%m/%d %H:%M"), '策略最大回撤结束时间': end_time.strftime("%Y/%m/%d %H:%M"), @@ -452,7 +462,8 @@ def time_analysis(df_closes:df,df_funds:df) -> dict: return result -def ratio_calculate(data, data2,after_merge, capital = 500000, rf = 0, period = 240) -> dict: + +def ratio_calculate(data, data2, after_merge, capital=500000, rf=0, period=240) -> dict: data['principal'] = data['totalprofit'] + capital data['principal'] = data['principal'].shift(1) profit = data['profit'] @@ -461,11 +472,11 @@ def ratio_calculate(data, data2,after_merge, capital = 500000, rf = 0, period = input_data = data.fillna(value=capital) input_data2 = data2.fillna(value=capital) ret = input_data['profit'] / input_data['principal'] - ret_day =input_data2['principal']/input_data2['principal2'] -1 + ret_day = input_data2['principal'] / input_data2['principal2'] - 1 trade_day = data2.shape[0] mar = 0 trade = input_data['closebarno'][len(input_data) - 1] / 47 - factors = Calculate(ret, mar, rf, period, trade,capital,ret_day,trade_day,profit) + factors = Calculate(ret, mar, rf, period, trade, capital, ret_day, trade_day, profit) # 潜在上涨比率 potential_upside_ratio = factors.calculate_upside_ratio() # 夏普比率 @@ -476,13 +487,13 @@ def ratio_calculate(data, data2,after_merge, capital = 500000, rf = 0, period = calmar_ratio = factors.calmar_ratio() # 斯特林比率 sterling_ratio = factors.sterling_a_ratio() - result1 = performance_summary(data, after_merge,data2=data2) + result1 = performance_summary(data, after_merge, data2=data2) # 净利/单笔最大亏损 net_s_loss = result1.get('净利') / result1.get('单笔最大亏损') # 净利/单笔最大回撤 net_s_drawdown = result1.get('净利') / factors.single_largest_maxdrawdown_value() # 净利/ 策略最大回撤 - net_strategy_drawdown = result1.get('净利') / factors.maxDrawdown() + net_strategy_drawdown = result1.get('净利') / factors.maxDrawdown() # 调整净利/单笔最大亏损 adjust_s_loss = result1.get('调整净利') / result1.get('单笔最大亏损') # 调整净利/单笔最大回撤 @@ -490,24 +501,25 @@ def ratio_calculate(data, data2,after_merge, capital = 500000, rf = 0, period = # 调整净利/ 策略最大回撤 adjust_strategy_drawdown = result1.get('调整净利') / factors.maxDrawdown() - result = {'潜在上涨比率':potential_upside_ratio, - '夏普比率':sharpe_ratio, - '索提诺比率':sortino_ratio, - '卡尔马比率':calmar_ratio, - '斯特林比率':sterling_ratio, - '净利/单笔最大亏损':net_s_loss, - '净利/单笔最大回撤':net_s_drawdown, - '净利/ 策略最大回撤':net_strategy_drawdown, - '调整净利/单笔最大亏损':adjust_s_loss, - '调整净利/单笔最大回撤':adjust_s_drawdown, - '调整净利/ 策略最大回撤':adjust_strategy_drawdown} + result = {'潜在上涨比率': potential_upside_ratio, + '夏普比率': sharpe_ratio, + '索提诺比率': sortino_ratio, + '卡尔马比率': calmar_ratio, + '斯特林比率': sterling_ratio, + '净利/单笔最大亏损': net_s_loss, + '净利/单笔最大回撤': net_s_drawdown, + '净利/ 策略最大回撤': net_strategy_drawdown, + '调整净利/单笔最大亏损': adjust_s_loss, + '调整净利/单笔最大回撤': adjust_s_drawdown, + '调整净利/ 策略最大回撤': adjust_strategy_drawdown} return result -def performance_summary(input_data, input_data1, capital = 500000, rf = 0.00, period = 240,data2 = []): - ''' + +def performance_summary(input_data, input_data1, capital=500000, rf=0.00, period=240, data2=[]): + """ 绩效统计 - ''' + """ # 指标计算准备 input_data['principal'] = input_data['totalprofit'] + capital input_data['principal'] = input_data['principal'].shift(1) @@ -515,7 +527,7 @@ def performance_summary(input_data, input_data1, capital = 500000, rf = 0.00, pe ret = input_data['profit'] / input_data['principal'] mar = 0 trade = len(input_data) - #trade = input_data['closebarno'][len(input_data)-1] / 47 + # trade = input_data['closebarno'][len(input_data)-1] / 47 data2['principal'] = data2['dynbalance'] + capital data2['principal2'] = data2['principal'].shift(1) @@ -523,35 +535,36 @@ def performance_summary(input_data, input_data1, capital = 500000, rf = 0.00, pe ret_day = input_data2['principal'] / input_data2['principal2'] - 1 trade_day = data2.shape[0] # 指标class - factors = Calculate(ret, mar, rf, period, trade,capital,ret_day,trade_day) + factors = Calculate(ret, mar, rf, period, trade, capital, ret_day, trade_day) # 毛利 profit = input_data[input_data['profit'].apply(lambda x: x >= 0)] - total_profit = 0 if len(profit)==0 else profit['profit'].sum() + total_profit = 0 if len(profit) == 0 else profit['profit'].sum() # 毛损 loss = input_data[input_data['profit'].apply(lambda x: x < 0)] - total_loss = 0 if len(loss)==0 else loss['profit'].sum() + total_loss = 0 if len(loss) == 0 else loss['profit'].sum() # 净利 net_profit = total_profit + total_loss - input_data1['adjust_profit'] = (input_data1['profit'] - input_data1['transaction_fee']) if len(input_data1)>0 else 0 + input_data1['adjust_profit'] = (input_data1['profit'] - input_data1['transaction_fee']) if len( + input_data1) > 0 else 0 # 调整毛利 adjust_profit = input_data1[input_data1['adjust_profit'].apply(lambda x: x >= 0)] - total_adjust_profit = 0 if len(adjust_profit)==0 else adjust_profit['adjust_profit'].sum() + total_adjust_profit = 0 if len(adjust_profit) == 0 else adjust_profit['adjust_profit'].sum() # 调整毛损 adjust_loss = input_data1[input_data1['adjust_profit'].apply((lambda x: x < 0))] - total_adjust_loss = 0 if len(adjust_loss)==0 else adjust_loss['adjust_profit'].sum() + total_adjust_loss = 0 if len(adjust_loss) == 0 else adjust_loss['adjust_profit'].sum() # 调整净利 adjust_net_profit = total_adjust_profit + total_adjust_loss # 盈利因子 - profit_factor = 0 if total_loss==0 else np.abs(total_profit / total_loss) + profit_factor = 0 if total_loss == 0 else np.abs(total_profit / total_loss) # 调整盈利因子 adjust_profit_factor = 0 if total_adjust_loss == 0 else np.abs(total_adjust_profit / total_adjust_loss) # 最大持有合约数量 max_holding_number = 1 # 已付手续费 - paid_trading_fee = input_data1['transaction_fee'].sum() if len(input_data1)>0 else 0 + paid_trading_fee = input_data1['transaction_fee'].sum() if len(input_data1) > 0 else 0 # 单笔最大亏损 single_loss = input_data[input_data['profit'].apply(lambda x: x < 0)] - single_largest_loss = 0 if len(single_loss)==0 else abs(single_loss['profit'].min()) + single_largest_loss = 0 if len(single_loss) == 0 else abs(single_loss['profit'].min()) # 平仓交易最大亏损 trading_loss = single_largest_loss # 平仓交易最大亏损比 @@ -582,8 +595,8 @@ def performance_summary(input_data, input_data1, capital = 500000, rf = 0.00, pe '月平均收益': monthly_average_return} return result + def do_trading_analyze(df_closes, df_funds): - df_wins = df_closes[df_closes["profit"] > 0] df_loses = df_closes[df_closes["profit"] <= 0] @@ -613,17 +626,18 @@ def do_trading_analyze(df_closes, df_funds): # 单笔最大亏损交易 largest_loss = df_loses['profit'].min() # 交易的平均持仓K线根数 - avgtrd_hold_bar = 0 if totaltimes==0 else ((df_closes['closebarno'] - df_closes['openbarno']).sum()) / totaltimes + avgtrd_hold_bar = 0 if totaltimes == 0 else ((df_closes['closebarno'] - df_closes['openbarno']).sum()) / totaltimes # 平均空仓K线根数 avb = (df_closes['openbarno'] - df_closes['closebarno'].shift(1).fillna(value=0)) - avgemphold_bar = 0 if len(df_closes)==0 else avb.sum() / len(df_closes) + avgemphold_bar = 0 if len(df_closes) == 0 else avb.sum() / len(df_closes) # 两笔盈利交易之间的平均空仓K线根数 win_holdbar_situ = (df_wins['openbarno'].shift(-1) - df_wins['closebarno']).dropna() - winempty_avgholdbar = 0 if len(df_wins)== 0 or len(df_wins) == 1 else win_holdbar_situ.sum() / (len(df_wins)-1) + winempty_avgholdbar = 0 if len(df_wins) == 0 or len(df_wins) == 1 else win_holdbar_situ.sum() / (len(df_wins) - 1) # 两笔亏损交易之间的平均空仓K线根数 loss_holdbar_situ = (df_loses['openbarno'].shift(-1) - df_loses['closebarno']).dropna() - lossempty_avgholdbar = 0 if len(df_loses)== 0 or len(df_loses) == 1 else loss_holdbar_situ.sum() / (len(df_loses)-1) + lossempty_avgholdbar = 0 if len(df_loses) == 0 or len(df_loses) == 1 else loss_holdbar_situ.sum() / ( + len(df_loses) - 1) max_consecutive_wins = 0 # 最大连续盈利次数 max_consecutive_loses = 0 # 最大连续亏损次数 @@ -663,7 +677,7 @@ def do_trading_analyze(df_closes, df_funds): summary["最大连续亏损次数"] = max_consecutive_loses summary["盈利交易的平均持仓K线根数"] = avg_bars_in_winner summary["亏损交易的平均持仓K线根数"] = avg_bars_in_loser - summary["账户净盈亏"] = 0 if totaltimes==0 else accnetprofit + summary["账户净盈亏"] = 0 if totaltimes == 0 else accnetprofit summary['单笔最大盈利交易'] = largest_profit summary['单笔最大亏损交易'] = largest_loss summary['交易的平均持仓K线根数'] = avgtrd_hold_bar @@ -674,10 +688,11 @@ def do_trading_analyze(df_closes, df_funds): summary = summary.reset_index() return summary -def trading_analyze(workbook:Workbook, df_closes, df_funds, capital = 500000): - ''' + +def trading_analyze(workbook: Workbook, df_closes, df_funds, capital=500000): + """ 交易分析 - ''' + """ res = average_profit(df_closes) rr = res.get('连续盈利次数') df = pd.DataFrame([rr]).T @@ -715,52 +730,53 @@ def trading_analyze(workbook:Workbook, df_closes, df_funds, capital = 500000): sss = extreme_trading(df_closes) title_format = workbook.add_format({ - 'font_size': 16, - 'bold': True, - 'align': 'left', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'font_size': 16, + 'bold': True, + 'align': 'left', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) index_format = workbook.add_format({ - 'font_size': 12, - 'bold': True, - 'align': 'left', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'font_size': 12, + 'bold': True, + 'align': 'left', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) value_format = workbook.add_format({ - 'align': 'right', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'align': 'right', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) date_format = workbook.add_format({ - 'num_format': 'yyyy/mm/dd', - 'bold': True, - 'align': 'left', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'num_format': 'yyyy/mm/dd', + 'bold': True, + 'align': 'left', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) worksheet = workbook.add_worksheet('交易分析') - df_closes['fee'] = df_closes['profit'] - df_closes['totalprofit'] + df_closes['totalprofit'].shift(1).fillna(value=0) + df_closes['fee'] = df_closes['profit'] - df_closes['totalprofit'] + df_closes['totalprofit'].shift(1).fillna( + value=0) trade_s = do_trading_analyze(df_closes, df_funds) data_1 = df_closes[df_closes['direct'].apply(lambda x: 'LONG' in x)] trade_s_long = do_trading_analyze(data_1, df_funds) data_2 = df_closes[df_closes['direct'].apply(lambda x: 'SHORT' in x)] trade_s_short = do_trading_analyze(data_2, df_funds) trade_s = trade_s.merge(trade_s_long, how='inner', on='index') - trade_s = trade_s.merge(trade_s_short,how='inner', on='index') - trade_s.columns =['类别', '所有交易', '多头', '空头'] + trade_s = trade_s.merge(trade_s_short, how='inner', on='index') + trade_s.columns = ['类别', '所有交易', '多头', '空头'] trade_s.fillna(value=0, inplace=True) worksheet.write_row('A1', ['总体交易分析'], title_format) - worksheet.write_row('B3', ['所有交易','多头交易','空头交易'], index_format) + worksheet.write_row('B3', ['所有交易', '多头交易', '空头交易'], index_format) worksheet.write_column('A4', trade_s['类别'], index_format) worksheet.write_column('B4', trade_s['所有交易'], value_format) worksheet.write_column('C4', trade_s['多头'], value_format) worksheet.write_column('D4', trade_s['空头'], value_format) worksheet.write_row('A28', ['极端交易'], title_format) - worksheet.write_row('B30', ['总计','极端盈利','极端亏损'], index_format) + worksheet.write_row('B30', ['总计', '极端盈利', '极端亏损'], index_format) worksheet.write_column('A31', sss.index, index_format) worksheet.write_column('B31', sss['总计'], value_format) worksheet.write_column('C31', sss['极端盈利'], value_format) @@ -771,128 +787,129 @@ def trading_analyze(workbook:Workbook, df_closes, df_funds, capital = 500000): worksheet.write_column('B40', s.values(), value_format) worksheet.write_row('A49', ['连续交易系列统计'], title_format) - worksheet.write_row('A51', ['连续盈利次数','出现次数','每个序列的平均收益'], index_format) + worksheet.write_row('A51', ['连续盈利次数', '出现次数', '每个序列的平均收益'], index_format) worksheet.write_column('A52', f_result['连续次数'], value_format) worksheet.write_column('B52', f_result['出现次数'], value_format) worksheet.write_column('C52', f_result['每个序列平均收益'], value_format) win_cnt = len(f_result) - next_row = win_cnt+52 - worksheet.write_row('A%d'%next_row, ['连续亏损次数','出现次数','每个序列的平均亏损'], index_format) - worksheet.write_column('A%d'%(next_row+1), f_2_result['连续次数'], value_format) - worksheet.write_column('B%d'%(next_row+1), f_2_result['出现次数'], value_format) - worksheet.write_column('C%d'%(next_row+1), f_2_result['每个序列平均亏损'], value_format) - + next_row = win_cnt + 52 + worksheet.write_row('A%d' % next_row, ['连续亏损次数', '出现次数', '每个序列的平均亏损'], index_format) + worksheet.write_column('A%d' % (next_row + 1), f_2_result['连续次数'], value_format) + worksheet.write_column('B%d' % (next_row + 1), f_2_result['出现次数'], value_format) + worksheet.write_column('C%d' % (next_row + 1), f_2_result['每个序列平均亏损'], value_format) # 这里开始画图 next_row += len(f_2_result) + 3 - worksheet.write_row('A%d'%next_row, ['全部交易'], title_format) + worksheet.write_row('A%d' % next_row, ['全部交易'], title_format) chart_col = workbook.add_chart({'type': 'scatter'}) length = len(df_closes) sheetName = '交易列表' chart_col.add_series( { 'name': '收益分布', - 'categories': '=%s!$A$4:$A$%s' % (sheetName, length+3), - 'values': '=%s!$J$4:$J$%s' % (sheetName, length+3), + 'categories': '=%s!$A$4:$A$%s' % (sheetName, length + 3), + 'values': '=%s!$J$4:$J$%s' % (sheetName, length + 3), 'marker': { - 'type':"circle", - 'size':3 + 'type': "circle", + 'size': 3 } } ) chart_col.set_title({'name': '收益分布'}) chart_col.set_x_axis({'label_position': 'low'}) - worksheet.insert_chart('A%d' % (next_row+2), chart_col,{'x_scale': 1.8, 'y_scale': 1.8}) + worksheet.insert_chart('A%d' % (next_row + 2), chart_col, {'x_scale': 1.8, 'y_scale': 1.8}) next_row += 30 - worksheet.write_row('A%d'%next_row, ['潜在盈利'], title_format) + worksheet.write_row('A%d' % next_row, ['潜在盈利'], title_format) chart_col = workbook.add_chart({'type': 'scatter'}) length = len(df_closes) sheetName = '交易列表' chart_col.add_series( { 'name': '潜在盈利', - 'categories': '=%s!$A$4:$A$%s' % (sheetName, length+3), - 'values': '=%s!$N$4:$N$%s' % (sheetName, length+3), + 'categories': '=%s!$A$4:$A$%s' % (sheetName, length + 3), + 'values': '=%s!$N$4:$N$%s' % (sheetName, length + 3), 'marker': { - 'type':"diamond", - 'size':3, + 'type': "diamond", + 'size': 3, 'border': {'color': 'red'}, - 'fill': {'color': 'red'} + 'fill': {'color': 'red'} } } ) chart_col.set_title({'name': '潜在盈利'}) chart_col.set_x_axis({'label_position': 'low'}) - worksheet.insert_chart('A%d' % (next_row+2), chart_col,{'x_scale': 1.8, 'y_scale': 1.8}) + worksheet.insert_chart('A%d' % (next_row + 2), chart_col, {'x_scale': 1.8, 'y_scale': 1.8}) next_row += 30 - worksheet.write_row('A%d'%next_row, ['潜在亏损'], title_format) + worksheet.write_row('A%d' % next_row, ['潜在亏损'], title_format) chart_col = workbook.add_chart({'type': 'scatter'}) length = len(df_closes) sheetName = '交易列表' chart_col.add_series( { 'name': '潜在亏损', - 'categories': '=%s!$A$4:$A$%s' % (sheetName, length+3), - 'values': '=%s!$P$4:$P$%s' % (sheetName, length+3), + 'categories': '=%s!$A$4:$A$%s' % (sheetName, length + 3), + 'values': '=%s!$P$4:$P$%s' % (sheetName, length + 3), 'marker': { - 'type':"triangle", - 'size':3, + 'type': "triangle", + 'size': 3, 'border': {'color': 'green'}, - 'fill': {'color': 'green'} + 'fill': {'color': 'green'} } } ) chart_col.set_title({'name': '潜在亏损'}) chart_col.set_x_axis({'label_position': 'low'}) - worksheet.insert_chart('A%d' % (next_row+2), chart_col,{'x_scale': 1.8, 'y_scale': 1.8}) + worksheet.insert_chart('A%d' % (next_row + 2), chart_col, {'x_scale': 1.8, 'y_scale': 1.8}) - # 周期分析 worksheet = workbook.add_worksheet('周期分析') df_closes['opentime'] = df_closes['opentime'].apply(lambda x: parse(str(int(x / 10000)))) - res = stat_closes_by_day(df_closes.copy(), capital) + res = stat_closes_by_day(df_closes.copy(), capital) worksheet.write_row('A1', ['日度绩效分析'], title_format) - worksheet.write_row('A3', ['期间','盈利(¤)','盈利(%)','毛利','毛损','交易次数','胜率(%)'], index_format) + worksheet.write_row('A3', ['期间', '盈利(¤)', '盈利(%)', '毛利', '毛损', '交易次数', '胜率(%)'], index_format) worksheet.write_column('A4', res.index, date_format) worksheet.write_column('B4', res["profit"], value_format) worksheet.write_column('C4', res["profit_ratio"], value_format) worksheet.write_column('D4', res["gross_profit"], value_format) worksheet.write_column('E4', res["gross_loss"], value_format) worksheet.write_column('F4', res["times"], value_format) - worksheet.write_column('G4', res["win_rate"]*100, value_format) - + worksheet.write_column('G4', res["win_rate"] * 100, value_format) + next_row = 5 + len(res) res = stat_closes_by_month(df_closes.copy(), capital) - worksheet.write_row('A%d'%(next_row+1), ['月度绩效分析'], title_format) - worksheet.write_row('A%d'%(next_row+3), ['期间','盈利(¤)','盈利(%)','毛利','毛损','交易次数','胜率(%)'], index_format) - worksheet.write_column('A%d'%(next_row+4), res.index, index_format) - worksheet.write_column('B%d'%(next_row+4), res["profit"], value_format) - worksheet.write_column('C%d'%(next_row+4), res["profit_ratio"], value_format) - worksheet.write_column('D%d'%(next_row+4), res["gross_profit"], value_format) - worksheet.write_column('E%d'%(next_row+4), res["gross_loss"], value_format) - worksheet.write_column('F%d'%(next_row+4), res["times"], value_format) - worksheet.write_column('G%d'%(next_row+4), res["win_rate"]*100, value_format) + worksheet.write_row('A%d' % (next_row + 1), ['月度绩效分析'], title_format) + worksheet.write_row('A%d' % (next_row + 3), ['期间', '盈利(¤)', '盈利(%)', '毛利', '毛损', '交易次数', '胜率(%)'], + index_format) + worksheet.write_column('A%d' % (next_row + 4), res.index, index_format) + worksheet.write_column('B%d' % (next_row + 4), res["profit"], value_format) + worksheet.write_column('C%d' % (next_row + 4), res["profit_ratio"], value_format) + worksheet.write_column('D%d' % (next_row + 4), res["gross_profit"], value_format) + worksheet.write_column('E%d' % (next_row + 4), res["gross_loss"], value_format) + worksheet.write_column('F%d' % (next_row + 4), res["times"], value_format) + worksheet.write_column('G%d' % (next_row + 4), res["win_rate"] * 100, value_format) next_row = next_row + 4 + len(res) - res = stat_closes_by_year(df_closes.copy(), capital) - worksheet.write_row('A%d'%(next_row+1), ['年度绩效分析'], title_format) - worksheet.write_row('A%d'%(next_row+3), ['期间','盈利(¤)','盈利(%)','毛利','毛损','交易次数','胜率(%)'], index_format) - worksheet.write_column('A%d'%(next_row+4), res.index, index_format) - worksheet.write_column('B%d'%(next_row+4), res["profit"], value_format) - worksheet.write_column('C%d'%(next_row+4), res["profit_ratio"], value_format) - worksheet.write_column('D%d'%(next_row+4), res["gross_profit"], value_format) - worksheet.write_column('E%d'%(next_row+4), res["gross_loss"], value_format) - worksheet.write_column('F%d'%(next_row+4), res["times"], value_format) - worksheet.write_column('G%d'%(next_row+4), res["win_rate"]*100, value_format) - -def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, rf = 0.0, period = 240): - ''' + res = stat_closes_by_year(df_closes.copy(), capital) + worksheet.write_row('A%d' % (next_row + 1), ['年度绩效分析'], title_format) + worksheet.write_row('A%d' % (next_row + 3), ['期间', '盈利(¤)', '盈利(%)', '毛利', '毛损', '交易次数', '胜率(%)'], + index_format) + worksheet.write_column('A%d' % (next_row + 4), res.index, index_format) + worksheet.write_column('B%d' % (next_row + 4), res["profit"], value_format) + worksheet.write_column('C%d' % (next_row + 4), res["profit_ratio"], value_format) + worksheet.write_column('D%d' % (next_row + 4), res["gross_profit"], value_format) + worksheet.write_column('E%d' % (next_row + 4), res["gross_loss"], value_format) + worksheet.write_column('F%d' % (next_row + 4), res["times"], value_format) + worksheet.write_column('G%d' % (next_row + 4), res["win_rate"] * 100, value_format) + + +def strategy_analyze(workbook: Workbook, df_closes, df_trades, df_funds, capital, rf=0.0, period=240): + """ 策略分析 - ''' + """ # 截取开仓明细 data1_open = df_trades[df_trades['action'].apply(lambda x: 'OPEN' in x)].reset_index() @@ -918,21 +935,22 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, # 合并数据 after_merge = pd.merge(df_closes, clean_data, how='inner', on='opentime') - data_long = df_closes[df_closes['direct'].apply(lambda x:'LONG' in x )].reset_index() + data_long = df_closes[df_closes['direct'].apply(lambda x: 'LONG' in x)].reset_index() after_merge_long = after_merge[after_merge['direct'].apply(lambda x: 'LONG' in x)].reset_index() data_short = df_closes[df_closes['direct'].apply(lambda x: 'SHORT' in x)].reset_index() after_merge_short = after_merge[after_merge['direct'].apply(lambda x: 'SHORT' in x)].reset_index() # 全部平仓明细进行绩效分析 - result1 = performance_summary(df_closes, after_merge, capital=capital, rf=rf, period=period,data2=df_funds) + result1 = performance_summary(df_closes, after_merge, capital=capital, rf=rf, period=period, data2=df_funds) # 做多平仓明细进行绩效分析 - result1_2 = performance_summary(data_long, after_merge_long, capital=capital, rf=rf, period=period,data2= df_funds) + result1_2 = performance_summary(data_long, after_merge_long, capital=capital, rf=rf, period=period, data2=df_funds) # 做空平仓明细进行绩效分析 - result1_3 = performance_summary(data_short,after_merge_short, capital=capital, rf=rf, period=period,data2= df_funds) + result1_3 = performance_summary(data_short, after_merge_short, capital=capital, rf=rf, period=period, + data2=df_funds) # 绩效比率计算 - result2 = ratio_calculate(df_closes,df_funds, after_merge, capital=capital, rf=rf, period=period) + result2 = ratio_calculate(df_closes, df_funds, after_merge, capital=capital, rf=rf, period=period) # 时间分析 - result3 = time_analysis(df_closes,df_funds) + result3 = time_analysis(df_closes, df_funds) result1 = pd.DataFrame(pd.Series(result1), columns=['所有交易']) result1 = result1.reset_index().rename(columns={'index': '策略绩效概要'}) @@ -943,64 +961,63 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, result1_3 = pd.DataFrame(pd.Series(result1_3), columns=['空头交易']) result1_3 = result1_3.reset_index().rename(columns={'index': '策略绩效概要'}) - result1 = result1.merge(result1_2,how='inner',on='策略绩效概要') - result1 = result1.merge(result1_3,how='inner',on='策略绩效概要') + result1 = result1.merge(result1_2, how='inner', on='策略绩效概要') + result1 = result1.merge(result1_3, how='inner', on='策略绩效概要') sheetName = '策略分析' worksheet = workbook.add_worksheet(sheetName) title_format = workbook.add_format({ - 'font_size': 16, - 'bold': True, - 'align': 'left', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'font_size': 16, + 'bold': True, + 'align': 'left', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) index_format = workbook.add_format({ - 'font_size': 12, - 'bold': True, - 'align': 'left', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'font_size': 12, + 'bold': True, + 'align': 'left', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) value_format = workbook.add_format({ - 'align': 'right', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'align': 'right', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) result1.fillna(value=0, inplace=True) - worksheet.write_row('A1', ['策略绩效概要'], title_format) - worksheet.write_row('B3', ['所有交易','多头交易','空头交易'], index_format) + worksheet.write_row('A1', ['策略绩效概要'], title_format) + worksheet.write_row('B3', ['所有交易', '多头交易', '空头交易'], index_format) worksheet.write_column('A4', result1['策略绩效概要'], index_format) worksheet.write_column('B4', result1['所有交易'], value_format) worksheet.write_column('C4', result1['多头交易'], value_format) worksheet.write_column('D4', result1['空头交易'], value_format) - worksheet.write_row('A22', ['绩效比率'], title_format) + worksheet.write_row('A22', ['绩效比率'], title_format) worksheet.write_column('A24', result2.keys(), index_format) worksheet.write_column('B24', result2.values(), value_format) - worksheet.write_row('A37', ['时间分析'], title_format) + worksheet.write_row('A37', ['时间分析'], title_format) worksheet.write_column('A39', result3.keys(), index_format) worksheet.write_column('B39', result3.values(), value_format) - #修正:重算多头、空头交易的年化月化月平均收益率 - net_profit_long = result1_2.loc[5,'多头交易'] - net_profit_short = result1_3.loc[5,'空头交易'] + # 修正:重算多头、空头交易的年化月化月平均收益率 + net_profit_long = result1_2.loc[5, '多头交易'] + net_profit_short = result1_3.loc[5, '空头交易'] trade_day = df_funds.shape[0] - long_annualp = ((net_profit_long + capital) / capital) ** (period / trade_day) -1 + long_annualp = ((net_profit_long + capital) / capital) ** (period / trade_day) - 1 short_annualp = ((net_profit_short + capital) / capital) ** (period / trade_day) - 1 - long_monthlyp = (long_annualp + 1) ** (1/12) -1 - short_monthlyp = (short_annualp + 1 ) ** (1/12) -1 - long_month_average = ((net_profit_long + capital) / capital -1) / trade_day * period / 12 - short_month_average = ((net_profit_short + capital) / capital -1) / trade_day * period / 12 - worksheet.write('C17',long_annualp,value_format) - worksheet.write('D17',short_annualp,value_format) + long_monthlyp = (long_annualp + 1) ** (1 / 12) - 1 + short_monthlyp = (short_annualp + 1) ** (1 / 12) - 1 + long_month_average = ((net_profit_long + capital) / capital - 1) / trade_day * period / 12 + short_month_average = ((net_profit_short + capital) / capital - 1) / trade_day * period / 12 + worksheet.write('C17', long_annualp, value_format) + worksheet.write('D17', short_annualp, value_format) worksheet.write('C18', long_monthlyp, value_format) worksheet.write('D18', short_monthlyp, value_format) worksheet.write('C19', long_month_average, value_format) worksheet.write('D19', short_month_average, value_format) - # 这里开始画图 worksheet.write_row('A49', ['详细权益曲线'], title_format) @@ -1011,8 +1028,8 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, chart_col.add_series( { 'name': '详细权益曲线', - 'categories': '=%s!$S$4:$S$%s' % (sheetName, length+3), - 'values': '=%s!$R$4:$R$%s' % (sheetName, length+3), + 'categories': '=%s!$S$4:$S$%s' % (sheetName, length + 3), + 'values': '=%s!$R$4:$R$%s' % (sheetName, length + 3), 'line': {'color': 'red', 'width': 1} } ) @@ -1020,14 +1037,13 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, chart_col.set_x_axis({'name': '平仓K线编号'}) worksheet.insert_chart('A51', chart_col, {'x_scale': 1.8, 'y_scale': 1.8}) - worksheet.write_row('A79', ['每笔收益'], title_format) chart_col = workbook.add_chart({'type': 'column'}) chart_col.add_series( { 'name': '每笔收益', - 'categories': '=%s!$A$4:$A$%s' % (sheetName, length+3), - 'values': '=%s!$J$4:$J$%s' % (sheetName, length+3), + 'categories': '=%s!$A$4:$A$%s' % (sheetName, length + 3), + 'values': '=%s!$J$4:$J$%s' % (sheetName, length + 3), 'line': {'color': 'black', 'width': 1} } ) @@ -1040,17 +1056,17 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, chart_col.add_series( { 'name': '潜在盈利', - 'categories': '=%s!$S$4:$S$%s' % (sheetName, length+3), - 'values': '=%s!$N$4:$N$%s' % (sheetName, length+3), - 'line': {'color': 'red','width': 1} + 'categories': '=%s!$S$4:$S$%s' % (sheetName, length + 3), + 'values': '=%s!$N$4:$N$%s' % (sheetName, length + 3), + 'line': {'color': 'red', 'width': 1} } ) chart_col.add_series( { 'name': '潜在亏损', - 'categories': '=%s!$S$4:$S$%s' % (sheetName, length+3), - 'values': '=%s!$P$4:$P$%s' % (sheetName, length+3), - 'line': {'color': 'green','width': 1} + 'categories': '=%s!$S$4:$S$%s' % (sheetName, length + 3), + 'values': '=%s!$P$4:$P$%s' % (sheetName, length + 3), + 'line': {'color': 'green', 'width': 1} } ) chart_col.set_x_axis({'label_position': 'low', 'name': '平仓K线编号'}) @@ -1061,16 +1077,17 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, # df_closes['exittime'] = df_closes['closetime'].apply(lambda x: datetime.strptime(str(x), '%Y%m%d%H%M')) # df_closes['exittime'] = pd.to_datetime(df_closes['exittime']) - # #用matplotlib画图 # plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签 # worksheet.write_row('A139', ['详细多头权益曲线'], title_format) - df_closes['fee'] = df_closes['profit'] - df_closes['totalprofit'] + df_closes['totalprofit'].shift(1).fillna(value=0) + df_closes['fee'] = df_closes['profit'] - df_closes['totalprofit'] + df_closes['totalprofit'].shift(1).fillna( + value=0) df_temp = pd.DataFrame() - df_temp['profit'] = df_closes[df_closes['direct'] == 'LONG']['profit'] - df_closes[df_closes['direct'] == 'LONG']['fee'] + df_temp['profit'] = df_closes[df_closes['direct'] == 'LONG']['profit'] - df_closes[df_closes['direct'] == 'LONG'][ + 'fee'] df_temp['equity'] = df_temp['profit'].expanding().sum() + capital - np_temp = np.arange(1, len(df_temp)+1, 1) + np_temp = np.arange(1, len(df_temp) + 1, 1) df_temp['index'] = np_temp # plt.plot(df_temp['index'], df_temp['equity']) @@ -1085,7 +1102,8 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, # worksheet.write_row('A169', ['详细空头权益曲线'], title_format) # plt.clf() df_temp2 = pd.DataFrame() - df_temp2['profit'] = df_closes[df_closes['direct'] == 'SHORT']['profit'] - df_closes[df_closes['direct'] == 'SHORT']['fee'] + df_temp2['profit'] = df_closes[df_closes['direct'] == 'SHORT']['profit'] - \ + df_closes[df_closes['direct'] == 'SHORT']['fee'] df_temp2['equity'] = df_temp2['profit'].expanding().sum() + capital np_temp2 = np.arange(1, len(df_temp2) + 1, 1) df_temp2['index'] = np_temp2 @@ -1103,11 +1121,11 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, # ) worksheet = workbook.add_worksheet('交易列表') length0 = len(df_closes) - worksheet.write_row('A'+str(length0+98), ['作图数据'], index_format) - worksheet.write_column('A'+str(length0+100), df_temp['index'], value_format) - worksheet.write_column('B'+str(length0+100), df_temp['equity'], value_format) - worksheet.write_column('C'+str(length0+100), df_temp2['index'], value_format) - worksheet.write_column('D'+str(length0+100), df_temp2['equity'], value_format) + worksheet.write_row('A' + str(length0 + 98), ['作图数据'], index_format) + worksheet.write_column('A' + str(length0 + 100), df_temp['index'], value_format) + worksheet.write_column('B' + str(length0 + 100), df_temp['equity'], value_format) + worksheet.write_column('C' + str(length0 + 100), df_temp2['index'], value_format) + worksheet.write_column('D' + str(length0 + 100), df_temp2['equity'], value_format) worksheet = workbook.get_worksheet_by_name('策略分析') worksheet.write_row('A139', ['详细多头权益曲线'], title_format) @@ -1118,8 +1136,8 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, chart_col.add_series( { 'name': '详细权益曲线', - 'categories': '=%s!$A$%s:$A$%s' % (sheetName, length0+100, length0+100+length), - 'values': '=%s!$B$%s:$B$%s' % (sheetName, length0+100, length0+100+length), + 'categories': '=%s!$A$%s:$A$%s' % (sheetName, length0 + 100, length0 + 100 + length), + 'values': '=%s!$B$%s:$B$%s' % (sheetName, length0 + 100, length0 + 100 + length), 'line': {'color': 'red', 'width': 1} } ) @@ -1134,8 +1152,8 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, chart_col.add_series( { 'name': '详细权益曲线', - 'categories': '=%s!$C$%s:$C$%s' % (sheetName, length0+100, length0+100+length), - 'values': '=%s!$D$%s:$D$%s' % (sheetName, length0+100, length0+100+length), + 'categories': '=%s!$C$%s:$C$%s' % (sheetName, length0 + 100, length0 + 100 + length), + 'values': '=%s!$D$%s:$D$%s' % (sheetName, length0 + 100, length0 + 100 + length), 'line': {'color': 'red', 'width': 1} } ) @@ -1143,56 +1161,58 @@ def strategy_analyze(workbook:Workbook, df_closes, df_trades,df_funds, capital, chart_col.set_x_axis({'name': '交易列表'}) worksheet.insert_chart('A171', chart_col, {'x_scale': 1.8, 'y_scale': 1.8}) -def output_closes(workbook:Workbook, df_closes:df, capital = 500000): + +def output_closes(workbook: Workbook, df_closes: df, capital=500000): worksheet = workbook.get_worksheet_by_name('交易列表') title_format = workbook.add_format({ - 'font_size': 16, - 'bold': True, - 'align': 'left', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'font_size': 16, + 'bold': True, + 'align': 'left', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) index_format = workbook.add_format({ - 'font_size': 12, - 'bold': True, - 'align': 'left', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'font_size': 12, + 'bold': True, + 'align': 'left', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) value_format = workbook.add_format({ - 'align': 'right', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'align': 'right', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) time_format = workbook.add_format({ - 'num_format': 'yyyy/mm/dd HH:MM', - 'align': 'right', # 水平居中 - 'valign': 'vcenter' # 垂直居中 + 'num_format': 'yyyy/mm/dd HH:MM', + 'align': 'right', # 水平居中 + 'valign': 'vcenter' # 垂直居中 }) - df_closes['entrytime'] = df_closes['opentime'].apply(lambda x: datetime.strptime(str(x), '%Y%m%d%H%M')) df_closes['exittime'] = df_closes['closetime'].apply(lambda x: datetime.strptime(str(x), '%Y%m%d%H%M')) - worksheet.write_row('A1', ['交易列表'], title_format) - worksheet.write_row('A3', ['编号', '代码','方向','进场时间','进场价格','进场标记','出场时间','出场价格','出场标记', - '盈利¤','盈利%','累计盈利¤','累计盈利%','潜在盈利¤','潜在盈利%','潜在亏损¤','潜在亏损%','累计权益','平仓K线编号'], index_format) - df_closes["profit_ratio"] = df_closes["profit"]*100/capital - df_closes["total_profit_ratio"] = df_closes["totalprofit"]*100/capital - df_closes["max_profit_ratio"] = df_closes["maxprofit"]*100/capital - df_closes["max_loss_ratio"] = df_closes["maxloss"]*100/capital - df_closes['direct'] = df_closes['direct'].apply(lambda x : '多' if x=='LONG' else '空' if x=='SHORT' else x) - - worksheet.write_column('A4', df_closes.index+1, value_format) + worksheet.write_row('A1', ['交易列表'], title_format) + worksheet.write_row('A3', + ['编号', '代码', '方向', '进场时间', '进场价格', '进场标记', '出场时间', '出场价格', '出场标记', + '盈利¤', '盈利%', '累计盈利¤', '累计盈利%', '潜在盈利¤', '潜在盈利%', '潜在亏损¤', '潜在亏损%', + '累计权益', '平仓K线编号'], index_format) + df_closes["profit_ratio"] = df_closes["profit"] * 100 / capital + df_closes["total_profit_ratio"] = df_closes["totalprofit"] * 100 / capital + df_closes["max_profit_ratio"] = df_closes["maxprofit"] * 100 / capital + df_closes["max_loss_ratio"] = df_closes["maxloss"] * 100 / capital + df_closes['direct'] = df_closes['direct'].apply(lambda x: '多' if x == 'LONG' else '空' if x == 'SHORT' else x) + + worksheet.write_column('A4', df_closes.index + 1, value_format) worksheet.write_column('B4', df_closes['code'], value_format) worksheet.write_column('C4', df_closes['direct'], value_format) worksheet.write_column('D4', df_closes['entrytime'], time_format) worksheet.write_column('E4', df_closes['openprice'], value_format) - ay = df_closes['entertag'].apply(lambda x: x if type(x)==str else '' if math.isnan(x) else x) + ay = df_closes['entertag'].apply(lambda x: x if isinstance(x, str) else '' if math.isnan(x) else x) worksheet.write_column('F4', ay, value_format) worksheet.write_column('G4', df_closes['exittime'], time_format) worksheet.write_column('H4', df_closes['closeprice'], value_format) - ay = df_closes['exittag'].apply(lambda x: x if type(x)==str else '' if math.isnan(x) else x) + ay = df_closes['exittag'].apply(lambda x: x if isinstance(x, str) else '' if math.isnan(x) else x) worksheet.write_column('I4', ay, value_format) worksheet.write_column('J4', df_closes['profit'], value_format) @@ -1203,214 +1223,217 @@ def output_closes(workbook:Workbook, df_closes:df, capital = 500000): worksheet.write_column('O4', df_closes['max_profit_ratio'], value_format) worksheet.write_column('P4', df_closes['maxloss'], value_format) worksheet.write_column('Q4', df_closes['max_loss_ratio'], value_format) - worksheet.write_column('R4', df_closes['totalprofit']+capital, value_format) + worksheet.write_column('R4', df_closes['totalprofit'] + capital, value_format) worksheet.write_column('S4', df_closes['closebarno'], value_format) -def summary_analyze(df_funds:df, capital = 5000000, rf = 0, period = 240) -> dict: - ''' + +def summary_analyze(df_funds: df, capital=5000000, rf=0, period=240) -> dict: + """ 概要分析 - ''' + """ init_capital = capital annual_days = period days = len(df_funds) - #先做资金统计吧 + # 先做资金统计吧 # print("anayzing fund data……") df_funds["dynbalance"] += init_capital - ayBal = df_funds["dynbalance"] # 每日期末动态权益 + ayBal = df_funds["dynbalance"] # 每日期末动态权益 - #生成每日期初动态权益 - ayPreBal = np.array(ayBal.tolist()[:-1]) - ayPreBal = np.insert(ayPreBal, 0, init_capital) #每日期初权益 + # 生成每日期初动态权益 + ayPreBal = np.array(ayBal.tolist()[:-1]) + ayPreBal = np.insert(ayPreBal, 0, init_capital) # 每日期初权益 df_funds["prebalance"] = ayPreBal - #统计期末权益大于期初权益的天数,即盈利天数 - windays = len(df_funds[df_funds["dynbalance"]>df_funds["prebalance"]]) + # 统计期末权益大于期初权益的天数,即盈利天数 + windays = len(df_funds[df_funds["dynbalance"] > df_funds["prebalance"]]) + + # 每日净值 + ayNetVals = (ayBal / init_capital) - #每日净值 - ayNetVals = (ayBal/init_capital) - if ayNetVals.iloc[-1] >= 0: - ar = math.pow(ayNetVals.iloc[-1], annual_days/days) - 1 #年化收益率=总收益率^(年交易日天数/统计天数) + ar = math.pow(ayNetVals.iloc[-1], annual_days / days) - 1 # 年化收益率=总收益率^(年交易日天数/统计天数) else: ar = -9999 - ayDailyReturn = ayBal/ayPreBal-1 #每日收益率 - delta = fmtNAN(ayDailyReturn.std(axis=0)*math.pow(annual_days,0.5),0) #年化标准差=每日收益率标准差*根号下(年交易日天数) - down_delta = fmtNAN(ayDailyReturn[ayDailyReturn<0].std(axis=0)*math.pow(annual_days,0.5), 0) #下行标准差=每日亏损收益率标准差*根号下(年交易日天数) + ayDailyReturn = ayBal / ayPreBal - 1 # 每日收益率 + delta = fmtNAN(ayDailyReturn.std(axis=0) * math.pow(annual_days, 0.5), 0) # 年化标准差=每日收益率标准差*根号下(年交易日天数) + down_delta = fmtNAN(ayDailyReturn[ayDailyReturn < 0].std(axis=0) * math.pow(annual_days, 0.5), + 0) # 下行标准差=每日亏损收益率标准差*根号下(年交易日天数) - #sharpe率 + # sharpe率 if delta != 0.0: - sr = (ar-rf)/delta + sr = (ar - rf) / delta else: sr = 9999.0 - #计算最大回撤和最大上涨 + # 计算最大回撤和最大上涨 maxub = ayNetVals[0] minub = maxub mdd = 0.0 midd = 0.0 mup = 0.0 - for idx in range(1,len(ayNetVals)): + for idx in range(1, len(ayNetVals)): maxub = max(maxub, ayNetVals[idx]) minub = min(minub, ayNetVals[idx]) - profit = (ayNetVals[idx] - ayNetVals[idx-1])/ayNetVals[idx-1] - falldown = (ayNetVals[idx] - maxub)/maxub - riseup = (ayNetVals[idx] - minub)/minub + profit = (ayNetVals[idx] - ayNetVals[idx - 1]) / ayNetVals[idx - 1] + falldown = (ayNetVals[idx] - maxub) / maxub + riseup = (ayNetVals[idx] - minub) / minub if profit <= 0: midd = max(midd, abs(profit)) mdd = max(mdd, abs(falldown)) else: mup = max(mup, abs(riseup)) - #索提诺比率 + # 索提诺比率 if down_delta != 0.0: - sortino = (ar-rf)/down_delta + sortino = (ar - rf) / down_delta else: sortino = 0.0 if mdd != 0.0: - calmar = ar/mdd + calmar = ar / mdd else: calmar = 999999.0 - # key_indicator = ['交易天数', '累积收益(%)', '年化收益率(%)', '胜率(%)', '最大回撤(%)', '最大上涨(%)', '标准差(%)', # '下行波动率(%)', 'Sharpe比率', 'Sortino比率', 'Calmar比率'] return { 'capital': capital, "days": days, - "total_return":(ayNetVals.iloc[-1]-1)*100, - "annual_return":ar*100, - "win_rate":(windays/days)*100, - "max_falldown":mdd*100, - "max_profratio":mup*100, - "std":delta*100, - "down_std":down_delta*100, - "sharpe_ratio":sr, - "sortino_ratio":sortino, - "calmar_ratio":calmar + "total_return": (ayNetVals.iloc[-1] - 1) * 100, + "annual_return": ar * 100, + "win_rate": (windays / days) * 100, + "max_falldown": mdd * 100, + "max_profratio": mup * 100, + "std": delta * 100, + "down_std": down_delta * 100, + "sharpe_ratio": sr, + "sortino_ratio": sortino, + "calmar_ratio": calmar } -def funds_analyze(workbook:Workbook, df_funds:df, capital = 5000000, rf = 0, period = 240): - ''' + +def funds_analyze(workbook: Workbook, df_funds: df, capital=5000000, rf=0, period=240): + """ 逐日资金分析 - ''' + """ init_capital = capital annual_days = period days = len(df_funds) - #先做资金统计吧 + # 先做资金统计吧 print("anayzing fund data……") df_funds["dynbalance"] += init_capital - ayBal = df_funds["dynbalance"] # 每日期末动态权益 + ayBal = df_funds["dynbalance"] # 每日期末动态权益 - #生成每日期初动态权益 - ayPreBal = np.array(ayBal.tolist()[:-1]) - ayPreBal = np.insert(ayPreBal, 0, init_capital) #每日期初权益 + # 生成每日期初动态权益 + ayPreBal = np.array(ayBal.tolist()[:-1]) + ayPreBal = np.insert(ayPreBal, 0, init_capital) # 每日期初权益 df_funds["prebalance"] = ayPreBal - #统计期末权益大于期初权益的天数,即盈利天数 - windays = len(df_funds[df_funds["dynbalance"]>df_funds["prebalance"]]) + # 统计期末权益大于期初权益的天数,即盈利天数 + windays = len(df_funds[df_funds["dynbalance"] > df_funds["prebalance"]]) + + # 每日净值 + ayNetVals = (ayBal / init_capital) - #每日净值 - ayNetVals = (ayBal/init_capital) - - ar = math.pow(ayNetVals.iloc[-1], annual_days/days) - 1 #年化收益率=总收益率^(年交易日天数/统计天数) - ayDailyReturn = ayBal/ayPreBal-1 #每日收益率 - delta = fmtNAN(ayDailyReturn.std(axis=0)*math.pow(annual_days,0.5),0) #年化标准差=每日收益率标准差*根号下(年交易日天数) - down_delta = fmtNAN(ayDailyReturn[ayDailyReturn<0].std(axis=0)*math.pow(annual_days,0.5), 0) #下行标准差=每日亏损收益率标准差*根号下(年交易日天数) + ar = math.pow(ayNetVals.iloc[-1], annual_days / days) - 1 # 年化收益率=总收益率^(年交易日天数/统计天数) + ayDailyReturn = ayBal / ayPreBal - 1 # 每日收益率 + delta = fmtNAN(ayDailyReturn.std(axis=0) * math.pow(annual_days, 0.5), 0) # 年化标准差=每日收益率标准差*根号下(年交易日天数) + down_delta = fmtNAN(ayDailyReturn[ayDailyReturn < 0].std(axis=0) * math.pow(annual_days, 0.5), + 0) # 下行标准差=每日亏损收益率标准差*根号下(年交易日天数) - #sharpe率 + # sharpe率 if delta != 0.0: - sr = (ar-rf)/delta + sr = (ar - rf) / delta else: sr = 9999.0 - #计算最大回撤和最大上涨 + # 计算最大回撤和最大上涨 maxub = ayNetVals[0] minub = maxub mdd = 0.0 midd = 0.0 mup = 0.0 - for idx in range(1,len(ayNetVals)): + for idx in range(1, len(ayNetVals)): maxub = max(maxub, ayNetVals[idx]) minub = min(minub, ayNetVals[idx]) - profit = (ayNetVals[idx] - ayNetVals[idx-1])/ayNetVals[idx-1] - falldown = (ayNetVals[idx] - maxub)/maxub - riseup = (ayNetVals[idx] - minub)/minub + profit = (ayNetVals[idx] - ayNetVals[idx - 1]) / ayNetVals[idx - 1] + falldown = (ayNetVals[idx] - maxub) / maxub + riseup = (ayNetVals[idx] - minub) / minub if profit <= 0: midd = max(midd, abs(profit)) mdd = max(mdd, abs(falldown)) else: mup = max(mup, abs(riseup)) - #索提诺比率 + # 索提诺比率 if down_delta != 0.0: - sortino = (ar-rf)/down_delta + sortino = (ar - rf) / down_delta else: sortino = 0.0 if mdd != 0.0: - calmar = ar/mdd + calmar = ar / mdd else: calmar = 999999.0 - #输出到excel + # 输出到excel sheetName = '逐日绩效概览' worksheet = workbook.add_worksheet(sheetName) # 设置合并单元格及格式 # # ~~~~~~ 写入数据 ~~~~~~ # title_format = workbook.add_format({ - 'bold': True, - 'border': 1, - 'align': 'center', # 水平居中 - 'valign': 'vcenter', # 垂直居中 + 'bold': True, + 'border': 1, + 'align': 'center', # 水平居中 + 'valign': 'vcenter', # 垂直居中 'fg_color': '#bcbcbc' }) fund_data_format = workbook.add_format({ 'border': 1, - 'align': 'right', # 右对齐 - 'valign': 'vcenter', # 垂直居中 + 'align': 'right', # 右对齐 + 'valign': 'vcenter', # 垂直居中 }) - + fund_data_format_2 = workbook.add_format({ 'border': 1, - 'align': 'right', # 右对齐 - 'valign': 'vcenter', # 垂直居中 + 'align': 'right', # 右对齐 + 'valign': 'vcenter', # 垂直居中 'num_format': '0.00' }) fund_data_format_3 = workbook.add_format({ 'border': 1, - 'align': 'right', # 右对齐 - 'valign': 'vcenter', # 垂直居中 + 'align': 'right', # 右对齐 + 'valign': 'vcenter', # 垂直居中 'num_format': '0.000' }) fund_data_format_4 = workbook.add_format({ 'border': 1, - 'align': 'right', # 右对齐 - 'valign': 'vcenter', # 垂直居中 + 'align': 'right', # 右对齐 + 'valign': 'vcenter', # 垂直居中 'num_format': '0.0000' }) - merge_format = workbook.add_format({ 'font_size': 16, - 'bold': True, - 'align': 'center', # 水平居中 - 'valign': 'vcenter', # 垂直居中 + 'bold': True, + 'align': 'center', # 水平居中 + 'valign': 'vcenter', # 垂直居中 }) indicator_format = workbook.add_format({ 'font_size': 12, - 'bold': True, - 'align': 'center', # 水平居中 - 'valign': 'vcenter', # 垂直居中 + 'bold': True, + 'align': 'center', # 水平居中 + 'valign': 'vcenter', # 垂直居中 }) worksheet.merge_range('A1:D1', '收益率统计指标', merge_format) worksheet.merge_range('E1:H1', '风险统计指标', merge_format) worksheet.merge_range('I1:K1', '综合指标', merge_format) key_indicator = ['交易天数', '累积收益(%)', '年化收益率(%)', '胜率(%)', '最大回撤(%)', '最大上涨(%)', '标准差(%)', - '下行波动率(%)', 'Sharpe比率', 'Sortino比率', 'Calmar比率'] - key_data = [(ayNetVals.iloc[-1]-1)*100, ar*100, (windays/days)*100, mdd*100, mup*100, delta*100, down_delta*100, sr, sortino, calmar] + '下行波动率(%)', 'Sharpe比率', 'Sortino比率', 'Calmar比率'] + key_data = [(ayNetVals.iloc[-1] - 1) * 100, ar * 100, (windays / days) * 100, mdd * 100, mup * 100, delta * 100, + down_delta * 100, sr, sortino, calmar] worksheet.write_row('A2', key_indicator, indicator_format) worksheet.write_column('A3', [days], fund_data_format) worksheet.write_row('B3', key_data, fund_data_format_3) @@ -1418,12 +1441,12 @@ def funds_analyze(workbook:Workbook, df_funds:df, capital = 5000000, rf = 0, per # 画图 # chart_col = workbook.add_chart({'type': 'line'}) length = days - chart_col.add_series( # 给图表设置格式,填充内容 + chart_col.add_series( # 给图表设置格式,填充内容 { 'name': '累计净值', - 'categories': '=逐日绩效分析!$A$3:$A$%d' % (length+2), - 'values': '=逐日绩效分析!$G$3:$G$%d' % (length+2), - 'line': {'color': 'blue', 'width':1}, + 'categories': '=逐日绩效分析!$A$3:$A$%d' % (length + 2), + 'values': '=逐日绩效分析!$G$3:$G$%d' % (length + 2), + 'line': {'color': 'blue', 'width': 1}, } ) chart_col.set_title({'name': '累计净值'}) @@ -1434,8 +1457,8 @@ def funds_analyze(workbook:Workbook, df_funds:df, capital = 5000000, rf = 0, per worksheet = workbook.add_worksheet(sheetName) title_format2 = workbook.add_format({ 'border': 1, - 'align': 'center', # 水平居中 - 'valign': 'vcenter', # 垂直居中 + 'align': 'center', # 水平居中 + 'valign': 'vcenter', # 垂直居中 'fg_color': '#D3D3D3', 'text_wrap': 1 }) @@ -1460,55 +1483,56 @@ def funds_analyze(workbook:Workbook, df_funds:df, capital = 5000000, rf = 0, per # 写入内容 # profit_format = workbook.add_format({ 'border': 1, - 'align': 'right', # 靠右 - 'valign': 'vcenter', # 垂直居中 + 'align': 'right', # 靠右 + 'valign': 'vcenter', # 垂直居中 'fg_color': '#FAFAD2', 'num_format': '0.00' }) percent_format = workbook.add_format({ 'border': 1, - 'align': 'right', # 右对齐 - 'valign': 'vcenter', # 垂直居中 + 'align': 'right', # 右对齐 + 'valign': 'vcenter', # 垂直居中 'num_format': '0.00%' }) date_format = workbook.add_format({ - 'num_format': 'yyyy/mm/dd', + 'num_format': 'yyyy/mm/dd', 'border': 1, - 'align': 'right', # 右对齐 - 'valign': 'vcenter', # 垂直居中 + 'align': 'right', # 右对齐 + 'valign': 'vcenter', # 垂直居中 }) - ayDates = df_funds['date'].apply(lambda x: str(x)[:4]+'/'+str(x)[4:6]+'/'+str(x)[6:8]) + ayDates = df_funds['date'].apply(lambda x: str(x)[:4] + '/' + str(x)[4:6] + '/' + str(x)[6:8]) worksheet.write_column('A3', ayDates, date_format) worksheet.write_column('B3', range(len(df_funds)), fund_data_format) - initial = [init_capital]*len(df_funds) + initial = [init_capital] * len(df_funds) worksheet.write_column('C3', initial, fund_data_format) worksheet.write_column('D3', '/', fund_data_format) worksheet.write_column('E3', ayBal, fund_data_format) - worksheet.write_column('F3', ayBal-init_capital, fund_data_format_2) + worksheet.write_column('F3', ayBal - init_capital, fund_data_format_2) worksheet.write_column('G3', ayNetVals, fund_data_format_4) - worksheet.write_column('H3', ayBal-ayPreBal, profit_format) + worksheet.write_column('H3', ayBal - ayPreBal, profit_format) worksheet.write_column('I3', ayDailyReturn, percent_format) # 计算峰值 upper = np.maximum.accumulate(ayNetVals) worksheet.write_column('J3', upper, fund_data_format_4) # 回撤指标 - temp = 1-(ayNetVals)/(np.maximum.accumulate(ayNetVals)) + temp = 1 - (ayNetVals) / (np.maximum.accumulate(ayNetVals)) worksheet.write_column('K3', temp, percent_format) worksheet.write_column('L3', np.maximum.accumulate(temp), percent_format) worksheet.write_column('M3', np.minimum.accumulate(ayDailyReturn), percent_format) # 计算衰落时间 down_time = [0] for i in range(1, len(upper)): - if upper[i] > upper[i-1]: + if upper[i] > upper[i - 1]: down_time.append(0) else: - l = down_time[i-1] - down_time.append(l+1) + li = down_time[i - 1] + down_time.append(li + 1) worksheet.write_column('N3', down_time, fund_data_format) + def do_trading_analyze2(df_closes, df_funds): df_wins = df_closes[df_closes["profit"] > 0] df_loses = df_closes[df_closes["profit"] <= 0] @@ -1538,17 +1562,18 @@ def do_trading_analyze2(df_closes, df_funds): # 单笔最大亏损交易 largest_loss = float(df_loses['profit'].min()) # 交易的平均持仓K线根数 - avgtrd_hold_bar = 0 if totaltimes==0 else ((df_closes['closebarno'] - df_closes['openbarno']).sum()) / totaltimes + avgtrd_hold_bar = 0 if totaltimes == 0 else ((df_closes['closebarno'] - df_closes['openbarno']).sum()) / totaltimes # 平均空仓K线根数 avb = (df_closes['openbarno'] - df_closes['closebarno'].shift(1).fillna(value=0)) - avgemphold_bar = 0 if len(df_closes)==0 else avb.sum() / len(df_closes) + avgemphold_bar = 0 if len(df_closes) == 0 else avb.sum() / len(df_closes) # 两笔盈利交易之间的平均空仓K线根数 win_holdbar_situ = (df_wins['openbarno'].shift(-1) - df_wins['closebarno']).dropna() - winempty_avgholdbar = 0 if len(df_wins)== 0 or len(df_wins) == 1 else win_holdbar_situ.sum() / (len(df_wins)-1) + winempty_avgholdbar = 0 if len(df_wins) == 0 or len(df_wins) == 1 else win_holdbar_situ.sum() / (len(df_wins) - 1) # 两笔亏损交易之间的平均空仓K线根数 loss_holdbar_situ = (df_loses['openbarno'].shift(-1) - df_loses['closebarno']).dropna() - lossempty_avgholdbar = 0 if len(df_loses)== 0 or len(df_loses) == 1 else loss_holdbar_situ.sum() / (len(df_loses)-1) + lossempty_avgholdbar = 0 if len(df_loses) == 0 or len(df_loses) == 1 else loss_holdbar_situ.sum() / ( + len(df_loses) - 1) max_consecutive_wins = 0 # 最大连续盈利次数 max_consecutive_loses = 0 # 最大连续亏损次数 @@ -1594,62 +1619,64 @@ def do_trading_analyze2(df_closes, df_funds): summary["max_consecutive_wins"] = max_consecutive_wins summary["max_consecutive_loses"] = max_consecutive_loses - return summary + class WtBtAnalyst: def __init__(self): self.__strategies__ = dict() return - def add_strategy(self, sname:str, folder:str, init_capital:float, rf:float=0.02, annual_trading_days:int = 240): + def add_strategy(self, sname: str, folder: str, init_capital: float, rf: float = 0.02, + annual_trading_days: int = 240): self.__strategies__[sname] = { "folder": folder, - "cap":init_capital, - "rf":rf, - "atd":annual_trading_days + "cap": init_capital, + "rf": rf, + "atd": annual_trading_days } - def run_new(self, outFileName:str = ''): + def run_new(self, outFileName: str = ''): if len(self.__strategies__.keys()) == 0: raise Exception("strategies is empty") for sname in self.__strategies__: sInfo = self.__strategies__[sname] folder = os.path.join(sInfo["folder"], sname) - print("start PnL analyzing for strategy %s……" % (sname)) + print("start PnL analyzing for strategy %s……" % sname) - df_funds = pd.read_csv(os.path.join(folder,"funds.csv")) + df_funds = pd.read_csv(os.path.join(folder, "funds.csv")) df_closes = pd.read_csv(os.path.join(folder, "closes.csv")) df_trades = pd.read_csv(os.path.join(folder, "trades.csv")) if len(outFileName) == 0: - outFileName = 'Strategy[%s]_PnLAnalyzing_%s_%s.xlsx' % (sname, df_funds['date'][0], df_funds['date'].iloc[-1]) + outFileName = 'Strategy[%s]_PnLAnalyzing_%s_%s.xlsx' % ( + sname, df_funds['date'][0], df_funds['date'].iloc[-1]) workbook = Workbook(outFileName) init_capital = sInfo["cap"] annual_days = sInfo["atd"] rf = sInfo["rf"] - strategy_analyze(workbook, df_closes.copy(), df_trades.copy(),df_funds.copy(), capital=init_capital, rf=rf, period=annual_days) + strategy_analyze(workbook, df_closes.copy(), df_trades.copy(), df_funds.copy(), capital=init_capital, rf=rf, + period=annual_days) output_closes(workbook, df_closes.copy(), capital=init_capital) trading_analyze(workbook, df_closes.copy(), df_funds.copy(), capital=init_capital) funds_analyze(workbook, df_funds.copy(), capital=init_capital, rf=rf, period=annual_days) workbook.close() - filename = os.path.join(folder,"summary.json") + filename = os.path.join(folder, "summary.json") sumObj = summary_analyze(df_funds, capital=init_capital, rf=rf, period=annual_days) sumObj["name"] = sname - f = open(filename,"w") + f = open(filename, "w") f.write(json.dumps(sumObj, indent=4, ensure_ascii=True)) f.close() - print("PnL analyzing of strategy %s done" % (sname)) - + print("PnL analyzing of strategy %s done" % sname) - def run(self, outFileName:str = ''): + def run(self, outFileName: str = ''): if len(self.__strategies__.keys()) == 0: raise Exception("strategies is empty") @@ -1657,7 +1684,7 @@ def run(self, outFileName:str = ''): sInfo = self.__strategies__[sname] # folder = sInfo["folder"] folder = os.path.join(sInfo["folder"], sname) - print("start PnL analyzing for strategy %s……" % (sname)) + print("start PnL analyzing for strategy %s……" % sname) df_funds = pd.read_csv(os.path.join(folder, "funds.csv")) print("fund logs loaded……") @@ -1665,14 +1692,15 @@ def run(self, outFileName:str = ''): init_capital = sInfo["cap"] annual_days = sInfo["atd"] rf = sInfo["rf"] - + if len(outFileName) == 0: - outFileName = 'Strategy[%s]_PnLAnalyzing_%s_%s.xlsx' % (sname, df_funds['date'][0], df_funds['date'].iloc[-1]) + outFileName = 'Strategy[%s]_PnLAnalyzing_%s_%s.xlsx' % ( + sname, df_funds['date'][0], df_funds['date'].iloc[-1]) workbook = Workbook(outFileName) funds_analyze(workbook, df_funds, capital=init_capital, rf=rf, period=annual_days) workbook.close() - print("PnL analyzing of strategy %s done" % (sname)) + print("PnL analyzing of strategy %s done" % sname) def run_simple(self): if len(self.__strategies__.keys()) == 0: @@ -1680,18 +1708,18 @@ def run_simple(self): for sname in self.__strategies__: sInfo = self.__strategies__[sname] - folder = os.path.join(sInfo["folder"],sname) + folder = os.path.join(sInfo["folder"], sname) df_funds = pd.read_csv(os.path.join(folder, "funds.csv")) init_capital = sInfo["cap"] annual_days = sInfo["atd"] rf = sInfo["rf"] - + filename = folder + 'summary.json' sumObj = summary_analyze(df_funds, capital=init_capital, rf=rf, period=annual_days) sumObj["name"] = sname - f = open(filename,"w") + f = open(filename, "w") f.write(json.dumps(sumObj, indent=4, ensure_ascii=True)) f.close() @@ -1702,12 +1730,13 @@ def run_flat(self): annual_days = sInfo["atd"] rf = sInfo["rf"] - folder = os.path.join(sInfo["folder"],sname) + folder = os.path.join(sInfo["folder"], sname) df_funds = pd.read_csv(os.path.join(folder, "funds.csv")) df_closes = pd.read_csv(os.path.join(folder, "closes.csv")) - df_closes['fee'] = df_closes['profit'] - df_closes['totalprofit'] + df_closes['totalprofit'].shift(1).fillna(value=0) + df_closes['fee'] = df_closes['profit'] - df_closes['totalprofit'] + df_closes['totalprofit'].shift( + 1).fillna(value=0) df_long = df_closes[df_closes['direct'].apply(lambda x: 'LONG' in x)] df_short = df_closes[df_closes['direct'].apply(lambda x: 'SHORT' in x)] @@ -1716,7 +1745,7 @@ def run_flat(self): summary_long = do_trading_analyze2(df_long, df_funds) filename = os.path.join(folder, 'trdana.json') - f = open(filename,"w") + f = open(filename, "w") f.write(json.dumps({ "all": summary_all, "long": summary_long, @@ -1725,7 +1754,8 @@ def run_flat(self): f.close() df_closes = df_closes.copy() - df_closes['fee'] = df_closes['profit'] - df_closes['totalprofit'] + df_closes['totalprofit'].shift(1).fillna( + df_closes['fee'] = df_closes['profit'] - df_closes['totalprofit'] + df_closes['totalprofit'].shift( + 1).fillna( value=0) df_closes['profit'] = df_closes['profit'] - df_closes['fee'] df_closes['profit_sum'] = df_closes['profit'].expanding(1).sum() @@ -1740,29 +1770,28 @@ def run_flat(self): closes_all = list() for item in np_trade: litem = { - "opentime":int(item[2]), - "closetime":int(item[4]), - "profit":float(item[7]), - "direct":str(item[1]), - "openprice":float(item[3]), - "closeprice":float(item[5]), - "maxprofit":float(item[8]), - "maxloss":float(item[9]), - "qty":int(item[6]), + "opentime": int(item[2]), + "closetime": int(item[4]), + "profit": float(item[7]), + "direct": str(item[1]), + "openprice": float(item[3]), + "closeprice": float(item[5]), + "maxprofit": float(item[8]), + "maxloss": float(item[9]), + "qty": int(item[6]), "capital": capital, - 'profit_sum':float(item[16]), - 'Withdrawal':float(item[17]), - 'profit_ratio':float(item[18]), - 'Withdrawal_ratio':float(item[19]) + 'profit_sum': float(item[16]), + 'Withdrawal': float(item[17]), + 'profit_ratio': float(item[18]), + 'Withdrawal_ratio': float(item[19]) } closes_all.append(litem) df_closes['time'] = df_closes['closetime'].apply(lambda x: datetime.strptime(str(x), '%Y%m%d%H%M')) df_c_m = df_closes.resample(rule='M', on='time', label='right', - closed='right').agg({ - 'profit': 'sum', - 'maxprofit': 'sum', - 'maxloss': 'sum', - }) + closed='right').agg({'profit': 'sum', + 'maxprofit': 'sum', + 'maxloss': 'sum', + }) df_c_m = df_c_m.reset_index() df_c_m['equity'] = df_c_m['profit'].expanding(1).sum() + capital df_c_m['monthly_profit'] = 100 * (df_c_m['equity'] / df_c_m['equity'].shift(1).fillna(value=capital) - 1) @@ -1770,21 +1799,20 @@ def run_flat(self): np_m = np.array(df_c_m).tolist() for item in np_m: litem = { - "time":int(item[0].strftime('%Y%m')), - "profit":float(item[1]), - 'maxprofit':float(item[2]), - 'maxloss':float(item[3]), - 'equity':float(item[4]), - 'monthly_profit':float(item[5]) + "time": int(item[0].strftime('%Y%m')), + "profit": float(item[1]), + 'maxprofit': float(item[2]), + 'maxloss': float(item[3]), + 'equity': float(item[4]), + 'monthly_profit': float(item[5]) } closes_month.append(litem) df_c_y = df_closes.resample(rule='Y', on='time', label='right', - closed='right').agg({ - 'profit': 'sum', - 'maxprofit': 'sum', - 'maxloss': 'sum', - }) + closed='right').agg({'profit': 'sum', + 'maxprofit': 'sum', + 'maxloss': 'sum', + }) df_c_y = df_c_y.reset_index() df_c_y['equity'] = df_c_y['profit'].expanding(1).sum() + capital df_c_y['monthly_profit'] = 100 * (df_c_y['equity'] / df_c_y['equity'].shift(1).fillna(value=capital) - 1) @@ -1792,12 +1820,12 @@ def run_flat(self): np_y = np.array(df_c_y).tolist() for item in np_y: litem = { - "time":int(item[0].strftime('%Y%m')), - "profit":float(item[1]), - 'maxprofit':float(item[2]), - 'maxloss':float(item[3]), - 'equity':float(item[4]), - 'annual_profit':float(item[5]) + "time": int(item[0].strftime('%Y%m')), + "profit": float(item[1]), + 'maxprofit': float(item[2]), + 'maxloss': float(item[3]), + 'equity': float(item[4]), + 'annual_profit': float(item[5]) } closes_year.append(litem) @@ -1805,29 +1833,29 @@ def run_flat(self): df_short = df_closes[df_closes['direct'].apply(lambda x: 'SHORT' in x)] df_long = df_long.copy() df_short = df_short.copy() - df_long["long_profit"] = df_long["profit"].expanding(1).sum()-df_long["fee"].expanding(1).sum() + df_long["long_profit"] = df_long["profit"].expanding(1).sum() - df_long["fee"].expanding(1).sum() closes_long = list() closes_short = list() np_long = np.array(df_long).tolist() for item in np_long: litem = { - "date":int(item[4]), - "long_profit":float(item[-1]), - "capital":capital + "date": int(item[4]), + "long_profit": float(item[-1]), + "capital": capital } closes_long.append(litem) - df_short["short_profit"] = df_short["profit"].expanding(1).sum()-df_short["fee"].expanding(1).sum() + df_short["short_profit"] = df_short["profit"].expanding(1).sum() - df_short["fee"].expanding(1).sum() np_short = np.array(df_short).tolist() for item in np_short: litem = { - "date":int(item[4]), - "short_profit":float(item[-1]), - "capital":capital + "date": int(item[4]), + "short_profit": float(item[-1]), + "capital": capital } closes_short.append(litem) filename = os.path.join(folder, 'rndana.json') - f = open(filename,"w") + f = open(filename, "w") f.write(json.dumps({ "long": closes_long, "short": closes_short, @@ -1836,10 +1864,10 @@ def run_flat(self): "year": closes_year }, indent=4, ensure_ascii=True)) f.close() - - filename = os.path.join(folder,"summary.json") + + filename = os.path.join(folder, "summary.json") sumObj = summary_analyze(df_funds, capital=capital, rf=rf, period=annual_days) sumObj["name"] = sname - f = open(filename,"w") + f = open(filename, "w") f.write(json.dumps(sumObj, indent=4, ensure_ascii=True)) - f.close() \ No newline at end of file + f.close() diff --git a/wtpy/apps/WtCCLoader.py b/wtpy/apps/WtCCLoader.py index 4f0ce04e..be756952 100644 --- a/wtpy/apps/WtCCLoader.py +++ b/wtpy/apps/WtCCLoader.py @@ -5,16 +5,17 @@ import json -def httpGet(url, encoding:str='utf-8', proxy:str = None, headers:dict = {}) -> str: - +def httpGet(url, encoding: str = 'utf-8', proxy: str = None, headers=None) -> str: + if headers is None: + headers = {} headers['Accept-encoding'] = 'gzip' headers['User-Agent'] = 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko)' handler = None if proxy is not None: - proxy_value = "%(ip)s" % {"ip":proxy} + proxy_value = "%(ip)s" % {"ip": proxy} proxies = { - "http":proxy_value, - "https":proxy_value + "http": proxy_value, + "https": proxy_value } handler = urllib.request.ProxyHandler(proxies) @@ -32,8 +33,9 @@ def httpGet(url, encoding:str='utf-8', proxy:str = None, headers:dict = {}) -> s else: return "" -def wrap_category(iType:str): - #20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-杠杆Margin + +def wrap_category(iType: str): + # 20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-杠杆Margin if iType.upper() == "SPOT": return 20 elif iType.upper() == "SWAP": @@ -45,18 +47,19 @@ def wrap_category(iType:str): else: return 24 + class WtCCLoader: @staticmethod - def load_from_okex(filename:str, instTypes:list = ["SPOT"], proxy:str = None) -> bool: + def load_from_okex(filename: str, instTypes: list = ["SPOT"], proxy: str = None) -> bool: contracts = dict() for iType in instTypes: cat = wrap_category(iType) - tMode = 1 if iType=='SPOT' else 0 #0-多空, 1-做多, 2-做多T+1 - content = httpGet('https://www.okex.com/api/v5/public/instruments?instType='+iType, proxy=proxy, headers={ - "Accept":"application/json" + tMode = 1 if iType == 'SPOT' else 0 # 0-多空, 1-做多, 2-做多T+1 + content = httpGet('https://www.okex.com/api/v5/public/instruments?instType=' + iType, proxy=proxy, headers={ + "Accept": "application/json" }) if len(content) == 0: return False @@ -75,10 +78,10 @@ def load_from_okex(filename:str, instTypes:list = ["SPOT"], proxy:str = None) -> cInfo["code"] = item["instId"] cInfo["exchg"] = "OKEX" - #这些是wt不用的额外信息,做一个保存 + # 这些是wt不用的额外信息,做一个保存 extInfo = dict() extInfo["instType"] = iType - extInfo["baseCcy"] = item["baseCcy"] + extInfo["baseCcy"] = item["baseCcy"] extInfo["quoteCcy"] = item["quoteCcy"] extInfo["category"] = item["category"] extInfo["ctVal"] = item["ctVal"] @@ -91,42 +94,43 @@ def load_from_okex(filename:str, instTypes:list = ["SPOT"], proxy:str = None) -> ruleInfo["session"] = "ALLDAY" ruleInfo["holiday"] = "" - ruleInfo["covermode"] = 3 #0-开平, 1-区分平今, 3-不分开平 - ruleInfo["pricemode"] = 0 #0-支持限价市价, 1-只支持限价, 2-只支持市价 - ruleInfo["category"] = cat #20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin - ruleInfo["trademode"] = tMode #0-多空, 1-做多, 2-做多T+1 + ruleInfo["covermode"] = 3 # 0-开平, 1-区分平今, 3-不分开平 + ruleInfo["pricemode"] = 0 # 0-支持限价市价, 1-只支持限价, 2-只支持市价 + ruleInfo["category"] = cat # 20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin + ruleInfo["trademode"] = tMode # 0-多空, 1-做多, 2-做多T+1 ruleInfo["pricetick"] = float(item["tickSz"]) ruleInfo["lotstick"] = float(item["lotSz"]) ruleInfo["minlots"] = float(item["minSz"]) - ruleInfo["volscale"] = int(item["ctMult"]) if len(item["ctMult"])>0 else 1 + ruleInfo["volscale"] = int(item["ctMult"]) if len(item["ctMult"]) > 0 else 1 cInfo["rules"] = ruleInfo contracts[cInfo['code']] = cInfo - except: + except Exception as e: + print(e) continue # 这里将下载到的合约列表落地 f = open(filename, "w") - f.write(json.dumps({"OKEX":contracts}, indent=4, ensure_ascii=False)) + f.write(json.dumps({"OKEX": contracts}, indent=4, ensure_ascii=False)) f.close() - @staticmethod - def load_spots_from_binance(filename:str, proxy:str = None) -> bool: + def load_spots_from_binance(filename: str, proxy: str = None) -> bool: contracts = dict() - content = httpGet('https://api.binance.com/api/v3/exchangeInfo', proxy=proxy, headers={ - "Accept":"application/json" + content = httpGet('https://api.binance.com/api/v3/exchangeInfo', proxy=proxy, headers={ + "Accept": "application/json" }) if len(content) == 0: return False try: root = json.loads(content) - except: + except Exception as e: + print(e) print("加载合约列表出错") return False @@ -142,13 +146,13 @@ def load_spots_from_binance(filename:str, proxy:str = None) -> bool: iType = "SPOT" else: continue - - tMode = 1 if iType=='SPOT' else 0 #0-多空, 1-做多, 2-做多T+1 - #这些是wt不用的额外信息,做一个保存 + tMode = 1 if iType == 'SPOT' else 0 # 0-多空, 1-做多, 2-做多T+1 + + # 这些是wt不用的额外信息,做一个保存 extInfo = dict() extInfo["instType"] = iType - extInfo["baseAsset"] = item["baseAsset"] + extInfo["baseAsset"] = item["baseAsset"] extInfo["quoteAsset"] = item["quoteAsset"] extInfo["icebergAllowed"] = item["icebergAllowed"] extInfo["ocoAllowed"] = item["ocoAllowed"] @@ -164,10 +168,10 @@ def load_spots_from_binance(filename:str, proxy:str = None) -> bool: ruleInfo["session"] = "ALLDAY" ruleInfo["holiday"] = "" - ruleInfo["covermode"] = 3 #0-开平, 1-区分平今, 3-不分开平 - ruleInfo["pricemode"] = 0 #0-支持限价市价, 1-只支持限价, 2-只支持市价 - ruleInfo["category"] = wrap_category(iType) #20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin - ruleInfo["trademode"] = tMode #0-多空, 1-做多, 2-做多T+1 + ruleInfo["covermode"] = 3 # 0-开平, 1-区分平今, 3-不分开平 + ruleInfo["pricemode"] = 0 # 0-支持限价市价, 1-只支持限价, 2-只支持市价 + ruleInfo["category"] = wrap_category(iType) # 20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin + ruleInfo["trademode"] = tMode # 0-多空, 1-做多, 2-做多T+1 for fItem in item["filters"]: if fItem["filterType"] == "PRICE_FILTER": @@ -180,27 +184,27 @@ def load_spots_from_binance(filename:str, proxy:str = None) -> bool: cInfo["rules"] = ruleInfo contracts[cInfo['code']] = cInfo - # 这里将下载到的合约列表落地 f = open(filename, "w") - f.write(json.dumps({"BINANCE":contracts}, indent=4, ensure_ascii=False)) + f.write(json.dumps({"BINANCE": contracts}, indent=4, ensure_ascii=False)) f.close() @staticmethod - def load_fpairs_from_binance(filename:str, proxy:str = None) -> bool: + def load_fpairs_from_binance(filename: str, proxy: str = None) -> bool: contracts = dict() - content = httpGet('https://fapi.binance.com/fapi/v1/exchangeInfo', proxy=proxy, headers={ - "Accept":"application/json" + content = httpGet('https://fapi.binance.com/fapi/v1/exchangeInfo', proxy=proxy, headers={ + "Accept": "application/json" }) if len(content) == 0: return False try: root = json.loads(content) - except: + except Exception as e: + print(e) print("加载合约列表出错") return False @@ -217,12 +221,12 @@ def load_fpairs_from_binance(filename:str, proxy:str = None) -> bool: iType = "SWAP" else: iType = "FUTURES" - tMode = 0 #0-多空, 1-做多, 2-做多T+1 + tMode = 0 # 0-多空, 1-做多, 2-做多T+1 - #这些是wt不用的额外信息,做一个保存 + # 这些是wt不用的额外信息,做一个保存 extInfo = dict() extInfo["instType"] = iType - extInfo["baseAsset"] = item["baseAsset"] + extInfo["baseAsset"] = item["baseAsset"] extInfo["quoteAsset"] = item["quoteAsset"] extInfo["marginAsset"] = item["marginAsset"] extInfo["pricePrecision"] = item["pricePrecision"] @@ -242,10 +246,10 @@ def load_fpairs_from_binance(filename:str, proxy:str = None) -> bool: ruleInfo["session"] = "ALLDAY" ruleInfo["holiday"] = "" - ruleInfo["covermode"] = 3 #0-开平, 1-区分平今, 3-不分开平 - ruleInfo["pricemode"] = 0 #0-支持限价市价, 1-只支持限价, 2-只支持市价 - ruleInfo["category"] = wrap_category(iType) #20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin - ruleInfo["trademode"] = tMode #0-多空, 1-做多, 2-做多T+1 + ruleInfo["covermode"] = 3 # 0-开平, 1-区分平今, 3-不分开平 + ruleInfo["pricemode"] = 0 # 0-支持限价市价, 1-只支持限价, 2-只支持市价 + ruleInfo["category"] = wrap_category(iType) # 20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin + ruleInfo["trademode"] = tMode # 0-多空, 1-做多, 2-做多T+1 for fItem in item["filters"]: if fItem["filterType"] == "PRICE_FILTER": @@ -258,27 +262,27 @@ def load_fpairs_from_binance(filename:str, proxy:str = None) -> bool: cInfo["rules"] = ruleInfo contracts[cInfo['code']] = cInfo - # 这里将下载到的合约列表落地 f = open(filename, "w") - f.write(json.dumps({"BINANCE":contracts}, indent=4, ensure_ascii=False)) + f.write(json.dumps({"BINANCE": contracts}, indent=4, ensure_ascii=False)) f.close() @staticmethod - def load_dpairs_from_binance(filename:str, proxy:str = None) -> bool: + def load_dpairs_from_binance(filename: str, proxy: str = None) -> bool: contracts = dict() - content = httpGet('https://dapi.binance.com/dapi/v1/exchangeInfo', proxy=proxy, headers={ - "Accept":"application/json" + content = httpGet('https://dapi.binance.com/dapi/v1/exchangeInfo', proxy=proxy, headers={ + "Accept": "application/json" }) if len(content) == 0: return False try: root = json.loads(content) - except: + except Exception as e: + print(e) print("加载合约列表出错") return False @@ -296,12 +300,12 @@ def load_dpairs_from_binance(filename:str, proxy:str = None) -> bool: iType = "SWAP" else: iType = "FUTURES" - tMode = 0 #0-多空, 1-做多, 2-做多T+1 + tMode = 0 # 0-多空, 1-做多, 2-做多T+1 - #这些是wt不用的额外信息,做一个保存 + # 这些是wt不用的额外信息,做一个保存 extInfo = dict() extInfo["instType"] = iType - extInfo["baseAsset"] = item["baseAsset"] + extInfo["baseAsset"] = item["baseAsset"] extInfo["quoteAsset"] = item["quoteAsset"] extInfo["marginAsset"] = item["marginAsset"] extInfo["pricePrecision"] = item["pricePrecision"] @@ -321,10 +325,10 @@ def load_dpairs_from_binance(filename:str, proxy:str = None) -> bool: ruleInfo["session"] = "ALLDAY" ruleInfo["holiday"] = "" - ruleInfo["covermode"] = 3 #0-开平, 1-区分平今, 3-不分开平 - ruleInfo["pricemode"] = 0 #0-支持限价市价, 1-只支持限价, 2-只支持市价 - ruleInfo["category"] = wrap_category(iType) #20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin - ruleInfo["trademode"] = tMode #0-多空, 1-做多, 2-做多T+1 + ruleInfo["covermode"] = 3 # 0-开平, 1-区分平今, 3-不分开平 + ruleInfo["pricemode"] = 0 # 0-支持限价市价, 1-只支持限价, 2-只支持市价 + ruleInfo["category"] = wrap_category(iType) # 20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin + ruleInfo["trademode"] = tMode # 0-多空, 1-做多, 2-做多T+1 for fItem in item["filters"]: if fItem["filterType"] == "PRICE_FILTER": @@ -337,28 +341,27 @@ def load_dpairs_from_binance(filename:str, proxy:str = None) -> bool: cInfo["rules"] = ruleInfo contracts[cInfo['code']] = cInfo - # 这里将下载到的合约列表落地 f = open(filename, "w") - f.write(json.dumps({"BINANCE":contracts}, indent=4, ensure_ascii=False)) + f.write(json.dumps({"BINANCE": contracts}, indent=4, ensure_ascii=False)) f.close() - @staticmethod - def load_from_ftx(filename:str, instTypes:list = ["SPOT"], proxy:str = None) -> bool: + def load_from_ftx(filename: str, instTypes: list = ["SPOT"], proxy: str = None) -> bool: contracts = dict() - content = httpGet('https://ftx.com/api/markets', proxy=proxy, headers={ - "Accept":"application/json" + content = httpGet('https://ftx.com/api/markets', proxy=proxy, headers={ + "Accept": "application/json" }) if len(content) == 0: return False try: root = json.loads(content) - except: + except Exception as e: + print(e) print("加载合约列表出错") return False @@ -384,12 +387,12 @@ def load_from_ftx(filename:str, instTypes:list = ["SPOT"], proxy:str = None) -> if iType not in instTypes: continue - tMode = 1 if iType=='SPOT' else 0 #0-多空, 1-做多, 2-做多T+1 + tMode = 1 if iType == 'SPOT' else 0 # 0-多空, 1-做多, 2-做多T+1 - #这些是wt不用的额外信息,做一个保存 + # 这些是wt不用的额外信息,做一个保存 extInfo = dict() extInfo["instType"] = iType - extInfo["baseCurrency"] = item["baseCurrency"] + extInfo["baseCurrency"] = item["baseCurrency"] extInfo["quoteCurrency"] = item["quoteCurrency"] extInfo["underlying"] = item["underlying"] extInfo["postOnly"] = item["postOnly"] @@ -400,10 +403,10 @@ def load_from_ftx(filename:str, instTypes:list = ["SPOT"], proxy:str = None) -> ruleInfo["session"] = "ALLDAY" ruleInfo["holiday"] = "" - ruleInfo["covermode"] = 3 #0-开平, 1-区分平今, 3-不分开平 - ruleInfo["pricemode"] = 0 #0-支持限价市价, 1-只支持限价, 2-只支持市价 - ruleInfo["category"] = wrap_category(iType) #20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin - ruleInfo["trademode"] = tMode #0-多空, 1-做多, 2-做多T+1 + ruleInfo["covermode"] = 3 # 0-开平, 1-区分平今, 3-不分开平 + ruleInfo["pricemode"] = 0 # 0-支持限价市价, 1-只支持限价, 2-只支持市价 + ruleInfo["category"] = wrap_category(iType) # 20-币币SPOT, 21-永续SWAP, 22-期货Future, 23-币币杠杆Margin + ruleInfo["trademode"] = tMode # 0-多空, 1-做多, 2-做多T+1 ruleInfo["pricetick"] = item["priceIncrement"] ruleInfo["lotstick"] = item["sizeIncrement"] @@ -413,10 +416,9 @@ def load_from_ftx(filename:str, instTypes:list = ["SPOT"], proxy:str = None) -> cInfo["rules"] = ruleInfo contracts[cInfo['code']] = cInfo - + print(all_types) # 这里将下载到的合约列表落地 f = open(filename, "w") - f.write(json.dumps({"FTX":contracts}, indent=4, ensure_ascii=False)) + f.write(json.dumps({"FTX": contracts}, indent=4, ensure_ascii=False)) f.close() - diff --git a/wtpy/apps/WtCtaGAOptimizer.py b/wtpy/apps/WtCtaGAOptimizer.py index e0d84f79..be8add3a 100644 --- a/wtpy/apps/WtCtaGAOptimizer.py +++ b/wtpy/apps/WtCtaGAOptimizer.py @@ -28,9 +28,9 @@ def fmtNAN(val, defVal=0): class ParamInfo: - ''' + """ 参数信息类 - ''' + """ def __init__(self, name: str, start_val=None, end_val=None, step_val=None, ndigits=1, val_list: list = None): self.name = name # 参数名 @@ -63,14 +63,14 @@ def gen_array(self): class WtCtaGAOptimizer: - ''' + """ 参数优化器\n 主要用于做策略参数优化的 - ''' + """ def __init__(self, worker_num: int = 2, MU: int = 80, population_size: int = 100, ngen_size: int = 20, cx_prb: float = 0.9, mut_prb: float = 0.1): - ''' + """ 构造函数\n @worker_num 工作进程个数,默认为2,可以根据CPU核心数设置,由于计算回测值是从文件里读取,因此进程过多可能会出现冲突\n @@ -79,7 +79,7 @@ def __init__(self, worker_num: int = 2, MU: int = 80, population_size: int = 100 @ngen_size 进化代数\n @cx_prb 交叉概率\n @mut_prb 变异概率 - ''' + """ self.worker_num = worker_num self.running_worker = 0 self.mutable_params = dict() @@ -102,7 +102,7 @@ def __init__(self, worker_num: int = 2, MU: int = 80, population_size: int = 100 self.cache_dict = multiprocessing.Manager().dict() # 缓存中间结果 def add_mutable_param(self, name: str, start_val, end_val, step_val, ndigits=1): - ''' + """ 添加可变参数\n @name 参数名\n @@ -110,30 +110,30 @@ def add_mutable_param(self, name: str, start_val, end_val, step_val, ndigits=1): @end_val 结束值\n @step_val 步长\n @ndigits 小数位 - ''' + """ self.mutable_params[name] = ParamInfo(name=name, start_val=start_val, end_val=end_val, step_val=step_val, ndigits=ndigits) def add_listed_param(self, name: str, val_list: list): - ''' + """ 添加限定范围的可变参数\n @name 参数名\n @val_list 参数值列表 - ''' + """ self.mutable_params[name] = ParamInfo(name=name, val_list=val_list) def add_fixed_param(self, name: str, val): - ''' + """ 添加固定参数\n @name 参数名\n @val 值\n - ''' + """ self.fixed_params[name] = val def generate_settings(self): - ''' 生成优化参数组合 ''' + """ 生成优化参数组合 """ # 参数名列表 name_list = self.mutable_params.keys() @@ -152,11 +152,11 @@ def generate_settings(self): return settings def set_optimizing_target(self, target: str): - ''' 设置优化目标名称,可从summary中已有数据中选取优化目标 ''' + """ 设置优化目标名称,可从summary中已有数据中选取优化目标 """ self.optimizing_target = target def set_optimizing_func(self, calculator, target_name: str = None): - ''' 根据summary数据自定义优化目标值 ''' + """ 根据summary数据自定义优化目标值 """ self.optimizing_target_func = calculator if target_name is None: @@ -180,7 +180,7 @@ def mututate_individual(self, individual, indpb): individual[i] = settings[i] return individual, - def evaluate_func(self, start_time, end_time, cache_dict: dict, params, capital = 5000000, rf = 0, period = 240): + def evaluate_func(self, start_time, end_time, cache_dict: dict, params, capital=5000000, rf=0, period=240): """ 适应度函数 :return: @@ -265,24 +265,24 @@ def evaluate_func(self, start_time, end_time, cache_dict: dict, params, capital return result def set_strategy(self, typeName: type, name_prefix: str): - ''' + """ 设置策略\n @typeName 策略类名\n @name_prefix 命名前缀,用于自动命名用,一般为格式为"前缀_参数1名_参数1值_参数2名_参数2值" - ''' + """ self.strategy_type = typeName self.name_prefix = name_prefix return def set_cpp_strategy(self, module: str, type_name: type, name_prefix: str): - ''' + """ 设置CPP策略\n @module 模块文件\n @typeName 策略类名\n @name_prefix 命名前缀,用于自动命名用,一般为格式为"前缀_参数1名_参数1值_参数2名_参数2值" - ''' + """ self.cpp_stra_module = module self.cpp_stra_type = type_name self.name_prefix = name_prefix @@ -290,14 +290,14 @@ def set_cpp_strategy(self, module: str, type_name: type, name_prefix: str): def config_backtest_env(self, deps_dir: str, cfgfile: str = "configbt.yaml", storage_type: str = "csv", storage_path: str = None, storage: dict = None): - ''' + """ 配置回测环境\n @deps_dir 依赖文件目录\n @cfgfile 配置文件名\n @storage_type 存储类型,csv/bin等\n @storage_path 存储路径 - ''' + """ self.env_params["deps_dir"] = deps_dir self.env_params["cfgfile"] = cfgfile self.env_params["storage_type"] = storage_type @@ -305,21 +305,21 @@ def config_backtest_env(self, deps_dir: str, cfgfile: str = "configbt.yaml", sto self.env_params["storage_path"] = storage_path def config_backtest_time(self, start_time: int, end_time: int): - ''' + """ 配置回测时间,可多次调用配置多个回测时间区间\n @start_time 开始时间,精确到分钟,格式如201909100930\n @end_time 结束时间,精确到分钟,格式如201909100930 - ''' + """ if "time_ranges" not in self.env_params: self.env_params["time_ranges"] = [] self.env_params["time_ranges"].append([start_time, end_time]) def gen_params(self, markerfile: str = "strategies.json"): - ''' + """ 生成回测任务 - ''' + """ # name_list = self.mutable_params.keys() param_list = self.generate_settings() @@ -349,13 +349,14 @@ def gen_params(self, markerfile: str = "strategies.json"): f.close() return param_group - def __ayalyze_result__(self, strName: str, time_range: tuple, params: dict, capital = 5000000, rf = 0, period = 240): - folder = "./outputs_bt/%s/" % (strName) + def __ayalyze_result__(self, strName: str, time_range: tuple, params: dict, capital=5000000, rf=0, period=240): + folder = "./outputs_bt/%s/" % strName try: df_closes = pd.read_csv(folder + "closes.csv", engine="python") df_funds = pd.read_csv(folder + "funds.csv", engine="python") except Exception as e: # 如果读取csv文件出现异常,则按文本格式读取 + print(e) df_closes = read_closes(folder + "closes.csv") df_funds = read_funds(folder + "funds.csv") @@ -422,9 +423,9 @@ def __ayalyze_result__(self, strName: str, time_range: tuple, params: dict, capi summary["毛盈利"] = float(winamout) summary["毛亏损"] = float(loseamount) summary["交易净盈亏"] = float(trdnetprofit) - summary["逐笔胜率%"] = winrate*100 + summary["逐笔胜率%"] = winrate * 100 summary["逐笔平均盈亏"] = avgprof - summary["逐笔平均净盈亏"] = accnetprofit/totaltimes + summary["逐笔平均净盈亏"] = accnetprofit / totaltimes summary["逐笔平均盈利"] = avgprof_win summary["逐笔逐笔亏损"] = avgprof_lose summary["逐笔盈亏比"] = winloseratio @@ -439,7 +440,7 @@ def __ayalyze_result__(self, strName: str, time_range: tuple, params: dict, capi return summary - def run_ga_optimizer(self, params: dict = None, capital = 5000000, rf = 0, period = 240): + def run_ga_optimizer(self, params: dict = None, capital=5000000, rf=0, period=240): """ 执行GA优化 """ # 遗传算法参数空间 buffer = self.generate_settings() @@ -453,7 +454,8 @@ def generate_parameter(): toolbox.register("individual", tools.initIterate, creator.Individual, generate_parameter) toolbox.register("population", tools.initRepeat, list, toolbox.individual) - toolbox.register("evaluate", self.evaluate_func, params["start_time"], params["end_time"], self.cache_dict, capital=capital, rf=rf, period=period) + toolbox.register("evaluate", self.evaluate_func, params["start_time"], params["end_time"], self.cache_dict, + capital=capital, rf=rf, period=period) toolbox.register("mate", tools.cxTwoPoint) toolbox.register("mutate", self.mututate_individual, indpb=0.05) toolbox.register("select", tools.selNSGA2) @@ -494,11 +496,11 @@ def generate_parameter(): # return def go(self, out_marker_file: str = "strategies.json", - out_summary_file: str = "total_summary.csv", capital = 5000000, rf = 0, period = 240): - ''' + out_summary_file: str = "total_summary.csv", capital=5000000, rf=0, period=240): + """ 启动优化器\n @markerfile 标记文件名,回测完成以后分析会用到 - ''' + """ params = self.gen_params(out_marker_file) self.run_ga_optimizer(params, capital, rf, period) @@ -517,7 +519,7 @@ def go(self, out_marker_file: str = "strategies.json", obj_stras = json.loads(content) total_summary = list() for straName in obj_stras: - filename = "./outputs_bt/%s/summary.json" % (straName) + filename = "./outputs_bt/%s/summary.json" % straName if not os.path.exists(filename): # print("%s不存在,请检查数据" % (filename)) continue @@ -542,7 +544,8 @@ def go(self, out_marker_file: str = "strategies.json", df_summary.to_csv(out_summary_file, encoding='utf-8-sig', index=False) print(f'优化目标: {self.optimizing_target}, 优化最大值:{df_summary[self.optimizing_target][0]}') - def analyze(self, out_marker_file: str = "strategies.json", out_summary_file: str = "total_summary.csv", capital = 5000000, rf = 0, period = 240): + def analyze(self, out_marker_file: str = "strategies.json", out_summary_file: str = "total_summary.csv", + capital=5000000, rf=0, period=240): # 获取所有的值 results = list(self.cache_dict.values()) header = list(results[0].keys()) @@ -559,7 +562,7 @@ def analyze(self, out_marker_file: str = "strategies.json", out_summary_file: st obj_stras = json.loads(content) for straName in obj_stras: params = obj_stras[straName] - filename = "./outputs_bt/%s/summary.json" % (straName) + filename = "./outputs_bt/%s/summary.json" % straName if not os.path.exists(filename): # print("%s不存在,请检查数据" % (filename)) continue @@ -593,7 +596,8 @@ def analyzer(self, out_marker_file: str = "strategies.json", init_capital=500000 analyst.add_strategy(straname, folder="./outputs_bt/%s/" % straname, init_capital=init_capital, rf=rf, annual_trading_days=annual_trading_days) analyst.run() - except: + except Exception as e: + print(e) pass diff --git a/wtpy/apps/WtCtaOptimizer.py b/wtpy/apps/WtCtaOptimizer.py index 2eb3a9ac..d0163ae5 100644 --- a/wtpy/apps/WtCtaOptimizer.py +++ b/wtpy/apps/WtCtaOptimizer.py @@ -10,87 +10,92 @@ import pandas as pd from pandas import DataFrame as df -from wtpy import WtBtEngine,EngineType +from wtpy import WtBtEngine, EngineType from wtpy.apps import WtBtAnalyst from wtpy.apps.WtBtAnalyst import summary_analyze -from wtpy.WtMsgQue import WtMsgQue,WtMQServer +from wtpy.WtMsgQue import WtMsgQue, WtMQServer -def fmtNAN(val, defVal = 0): + +def fmtNAN(val, defVal=0): if math.isnan(val): return defVal return val + mq = WtMsgQue() + + class OptimizeNotifier: - def __init__(self, url:str): + def __init__(self, url: str): self._url = url - self._server:WtMQServer = None + self._server: WtMQServer = None def run(self): self._server = mq.add_mq_server(self._url) - def publish(self, topic:str, message:str): + def publish(self, topic: str, message: str): if self._server is None: return self._server.publish_message(topic, message) - def on_start(self, pgroups:int): + def on_start(self, pgroups: int): data = { - "pgroups":pgroups + "pgroups": pgroups } self.publish("OPT_START", json.dumps(data)) - def on_stop(self, pgroups:int, elapse:int): + def on_stop(self, pgroups: int, elapse: int): data = { "pgroups": pgroups, - "elapse":int(elapse) + "elapse": int(elapse) } self.publish("OPT_STOP", json.dumps(data)) - def on_state(self, pgroups:int, done:int, progress:float, elapse:float): + def on_state(self, pgroups: int, done: int, progress: float, elapse: float): data = { "pgroups": pgroups, "done": done, - "progress":progress, - "elapse":int(elapse) + "progress": progress, + "elapse": int(elapse) } self.publish("OPT_STATE", json.dumps(data)) -def ayalyze_result(strName:str, time_range:tuple, params:dict, capital = 5000000, rf = 0, period = 240): - folder = "./outputs_bt/%s/" % (strName) + +def ayalyze_result(strName: str, time_range: tuple, params: dict, capital=5000000, rf=0, period=240): + folder = "./outputs_bt/%s/" % strName df_closes = pd.read_csv(folder + "closes.csv") df_funds = pd.read_csv(folder + "funds.csv") - df_wins = df_closes[df_closes["profit"]>0] - df_loses = df_closes[df_closes["profit"]<=0] + df_wins = df_closes[df_closes["profit"] > 0] + df_loses = df_closes[df_closes["profit"] <= 0] - ay_WinnerBarCnts = df_wins["closebarno"]-df_wins["openbarno"] - ay_LoserBarCnts = df_loses["closebarno"]-df_loses["openbarno"] + ay_WinnerBarCnts = df_wins["closebarno"] - df_wins["openbarno"] + ay_LoserBarCnts = df_loses["closebarno"] - df_loses["openbarno"] total_winbarcnts = ay_WinnerBarCnts.sum() total_losebarcnts = ay_LoserBarCnts.sum() total_fee = df_funds.iloc[-1]["fee"] - totaltimes = len(df_closes) # 总交易次数 - wintimes = len(df_wins) # 盈利次数 - losetimes = len(df_loses) # 亏损次数 - winamout = df_wins["profit"].sum() #毛盈利 - loseamount = df_loses["profit"].sum() #毛亏损 - trdnetprofit = winamout + loseamount #交易净盈亏 - accnetprofit = trdnetprofit - total_fee #账户净盈亏 - winrate = wintimes / totaltimes if totaltimes>0 else 0 # 胜率 - avgprof = trdnetprofit/totaltimes if totaltimes>0 else 0 # 单次平均盈亏 - avgprof_win = winamout/wintimes if wintimes>0 else 0 # 单次盈利均值 - avgprof_lose = loseamount/losetimes if losetimes>0 else 0 # 单次亏损均值 - winloseratio = abs(avgprof_win/avgprof_lose) if avgprof_lose!=0 else "N/A" # 单次盈亏均值比 - - max_consecutive_wins = 0 # 最大连续盈利次数 - max_consecutive_loses = 0 # 最大连续亏损次数 - - avg_bars_in_winner = total_winbarcnts/wintimes if wintimes>0 else "N/A" - avg_bars_in_loser = total_losebarcnts/losetimes if losetimes>0 else "N/A" + totaltimes = len(df_closes) # 总交易次数 + wintimes = len(df_wins) # 盈利次数 + losetimes = len(df_loses) # 亏损次数 + winamout = df_wins["profit"].sum() # 毛盈利 + loseamount = df_loses["profit"].sum() # 毛亏损 + trdnetprofit = winamout + loseamount # 交易净盈亏 + accnetprofit = trdnetprofit - total_fee # 账户净盈亏 + winrate = wintimes / totaltimes if totaltimes > 0 else 0 # 胜率 + avgprof = trdnetprofit / totaltimes if totaltimes > 0 else 0 # 单次平均盈亏 + avgprof_win = winamout / wintimes if wintimes > 0 else 0 # 单次盈利均值 + avgprof_lose = loseamount / losetimes if losetimes > 0 else 0 # 单次亏损均值 + winloseratio = abs(avgprof_win / avgprof_lose) if avgprof_lose != 0 else "N/A" # 单次盈亏均值比 + + max_consecutive_wins = 0 # 最大连续盈利次数 + max_consecutive_loses = 0 # 最大连续亏损次数 + + avg_bars_in_winner = total_winbarcnts / wintimes if wintimes > 0 else "N/A" + avg_bars_in_loser = total_losebarcnts / losetimes if losetimes > 0 else "N/A" consecutive_wins = 0 consecutive_loses = 0 @@ -102,7 +107,7 @@ def ayalyze_result(strName:str, time_range:tuple, params:dict, capital = 5000000 else: consecutive_wins = 0 consecutive_loses += 1 - + max_consecutive_wins = max(max_consecutive_wins, consecutive_wins) max_consecutive_loses = max(max_consecutive_loses, consecutive_loses) @@ -125,9 +130,9 @@ def ayalyze_result(strName:str, time_range:tuple, params:dict, capital = 5000000 summary["毛盈利"] = float(winamout) summary["毛亏损"] = float(loseamount) summary["账户净盈亏"] = float(accnetprofit) - summary["逐笔胜率%"] = winrate*100 + summary["逐笔胜率%"] = winrate * 100 summary["逐笔平均盈亏"] = avgprof - summary["逐笔平均净盈亏"] = accnetprofit/totaltimes + summary["逐笔平均净盈亏"] = accnetprofit / totaltimes summary["逐笔平均盈利"] = avgprof_win summary["逐笔逐笔亏损"] = avgprof_lose summary["逐笔盈亏比"] = winloseratio @@ -136,18 +141,19 @@ def ayalyze_result(strName:str, time_range:tuple, params:dict, capital = 5000000 summary["平均盈利周期"] = avg_bars_in_winner summary["平均亏损周期"] = avg_bars_in_loser - f = open(folder+"summary.json", mode="w") + f = open(folder + "summary.json", mode="w") f.write(json.dumps(obj=summary, indent=4)) f.close() return -def start_task_group(env_params, gpName:str, params:list, counter, capital = 5000000, rf = 0, period = 240, - strategy_type = None, cpp_stra_module = None, cpp_stra_type = None): - ''' + +def start_task_group(env_params, gpName: str, params: list, counter, capital=5000000, rf=0, period=240, + strategy_type=None, cpp_stra_module=None, cpp_stra_type=None): + """ 启动多个回测任务,来回测一组参数,这里共用一个engine,因此可以避免多次io @params 参数组 - ''' + """ is_yaml = True fname = "./logcfg_tpl.yaml" if not os.path.exists(fname): @@ -160,7 +166,7 @@ def start_task_group(env_params, gpName:str, params:list, counter, capital = 500 content = "{}" else: f = open(fname, "r") - content =f.read() + content = f.read() f.close() content = content.replace("$NAME$", gpName) if is_yaml: @@ -169,16 +175,17 @@ def start_task_group(env_params, gpName:str, params:list, counter, capital = 500 engine = WtBtEngine(eType=EngineType.ET_CTA, logCfg=content, isFile=False) # 配置类型的参数相对固定 if env_params["iscfgfile"]: - engine.init(env_params["deps_dir"], env_params["cfgfile"], - env_params["deps_files"]["commfile"], env_params["deps_files"]["contractfile"], - env_params["deps_files"]["sessionfile"], env_params["deps_files"]["holidayfile"], - env_params["deps_files"]["hotfile"], env_params["deps_files"]["secondfile"]) + engine.init(env_params["deps_dir"], env_params["cfgfile"], + env_params["deps_files"]["commfile"], env_params["deps_files"]["contractfile"], + env_params["deps_files"]["sessionfile"], env_params["deps_files"]["holidayfile"], + env_params["deps_files"]["hotfile"], env_params["deps_files"]["secondfile"]) else: engine.init_with_config(env_params["deps_dir"], env_params["cfgfile"], - env_params["deps_files"]["commfile"], env_params["deps_files"]["contractfile"], - env_params["deps_files"]["sessionfile"], env_params["deps_files"]["holidayfile"], - env_params["deps_files"]["hotfile"], env_params["deps_files"]["secondfile"]) - engine.configBTStorage(mode=env_params["storage_type"], path=env_params["storage_path"], storage=env_params["storage"]) + env_params["deps_files"]["commfile"], env_params["deps_files"]["contractfile"], + env_params["deps_files"]["sessionfile"], env_params["deps_files"]["holidayfile"], + env_params["deps_files"]["hotfile"], env_params["deps_files"]["secondfile"]) + engine.configBTStorage(mode=env_params["storage_type"], path=env_params["storage_path"], + storage=env_params["storage"]) # 遍历参数组 total = len(params) cnt = 0 @@ -186,7 +193,7 @@ def start_task_group(env_params, gpName:str, params:list, counter, capital = 500 cnt += 1 print(f"{gpName} 正在回测{cnt}/{total}") name = param["name"] - + engine.configBacktest(param["start_time"], param["end_time"]) time_range = (param["start_time"], param["end_time"]) # 去掉多余的参数 @@ -204,17 +211,19 @@ def start_task_group(env_params, gpName:str, params:list, counter, capital = 500 counter.value += 1 engine.release_backtest() + class ParamInfo: - ''' + """ 参数信息类 - ''' - def __init__(self, name:str, start_val = None, end_val = None, step_val = None, ndigits = 1, val_list:list = None): - self.name = name #参数名 - self.start_val = start_val #起始值 - self.end_val = end_val #结束值 - self.step_val = step_val #变化步长 - self.ndigits = ndigits #小数位 - self.val_list = val_list #指定参数 + """ + + def __init__(self, name: str, start_val=None, end_val=None, step_val=None, ndigits=1, val_list: list = None): + self.name = name # 参数名 + self.start_val = start_val # 起始值 + self.end_val = end_val # 结束值 + self.step_val = step_val # 变化步长 + self.ndigits = ndigits # 小数位 + self.val_list = val_list # 指定参数 def gen_array(self): if self.val_list is not None: @@ -233,17 +242,19 @@ def gen_array(self): values.append(round(curVal, self.ndigits)) return values + class WtCtaOptimizer: - ''' + """ 参数优化器\n 主要用于做策略参数优化的 - ''' - def __init__(self, worker_num:int = 8, notifier:OptimizeNotifier = None): - ''' + """ + + def __init__(self, worker_num: int = 8, notifier: OptimizeNotifier = None): + """ 构造函数\n @worker_num 工作进程个数,默认为8,可以根据CPU核心数设置 - ''' + """ self.worker_num = worker_num self.running_worker = 0 self.mutable_params = dict() @@ -253,11 +264,11 @@ def __init__(self, worker_num:int = 8, notifier:OptimizeNotifier = None): self.cpp_stra_module = None self.cpp_stra_type = None - self.notifier = notifier + self.notifier = notifier return - def add_mutable_param(self, name:str, start_val, end_val, step_val, ndigits = 1): - ''' + def add_mutable_param(self, name: str, start_val, end_val, step_val, ndigits=1): + """ 添加可变参数\n @name 参数名\n @@ -265,104 +276,105 @@ def add_mutable_param(self, name:str, start_val, end_val, step_val, ndigits = 1) @end_val 结束值\n @step_val 步长\n @ndigits 小数位 - ''' - self.mutable_params[name] = ParamInfo(name=name, start_val=start_val, end_val=end_val, step_val=step_val, ndigits=ndigits) + """ + self.mutable_params[name] = ParamInfo(name=name, start_val=start_val, end_val=end_val, step_val=step_val, + ndigits=ndigits) - def add_listed_param(self, name:str, val_list:list): - ''' + def add_listed_param(self, name: str, val_list: list): + """ 添加限定范围的可变参数\n @name 参数名\n @val_list 参数值列表 - ''' + """ self.mutable_params[name] = ParamInfo(name=name, val_list=val_list) - def add_fixed_param(self, name:str, val): - ''' + def add_fixed_param(self, name: str, val): + """ 添加固定参数\n @name 参数名\n @val 值\n - ''' + """ self.fixed_params[name] = val return - - def set_strategy(self, typeName:type, name_prefix:str): - ''' + + def set_strategy(self, typeName: type, name_prefix: str): + """ 设置策略\n @typeName 策略类名\n @name_prefix 命名前缀,用于自动命名用,一般为格式为"前缀_参数1名_参数1值_参数2名_参数2值" - ''' + """ self.strategy_type = typeName self.name_prefix = name_prefix return - def set_cpp_strategy(self, module:str, type_name:type, name_prefix:str): - ''' + def set_cpp_strategy(self, module: str, type_name: type, name_prefix: str): + """ 设置CPP策略\n @module 模块文件\n @typeName 策略类名\n @name_prefix 命名前缀,用于自动命名用,一般为格式为"前缀_参数1名_参数1值_参数2名_参数2值" - ''' + """ self.cpp_stra_module = module self.cpp_stra_type = type_name self.name_prefix = name_prefix return - def config_backtest_env(self, deps_dir:str, - cfgfile:str="configbt.yaml", - storage_type:str="csv", - storage_path:str = None, - storage:dict = None, - commfile:str = None, - contractfile:str = None, - sessionfile:str = None, - holidayfile:str= None, - hotfile:str = None, - secondfile:str = None, - iscfgfile:bool = True): - ''' + def config_backtest_env(self, deps_dir: str, + cfgfile: str = "configbt.yaml", + storage_type: str = "csv", + storage_path: str = None, + storage: dict = None, + commfile: str = None, + contractfile: str = None, + sessionfile: str = None, + holidayfile: str = None, + hotfile: str = None, + secondfile: str = None, + iscfgfile: bool = True): + """ 配置回测环境\n @deps_dir 依赖文件目录\n @cfgfile 配置文件名\n @storage_type 存储类型,csv/bin等\n @storage_path 存储路径 - ''' + """ self.env_params["deps_dir"] = deps_dir self.env_params["deps_files"] = { - "commfile":commfile, - "contractfile":contractfile, - "sessionfile":sessionfile, - "holidayfile":holidayfile, - "hotfile":hotfile, - "secondfile":secondfile + "commfile": commfile, + "contractfile": contractfile, + "sessionfile": sessionfile, + "holidayfile": holidayfile, + "hotfile": hotfile, + "secondfile": secondfile } - + self.env_params["cfgfile"] = cfgfile self.env_params["iscfgfile"] = iscfgfile self.env_params["storage_type"] = storage_type self.env_params["storage"] = storage self.env_params["storage_path"] = storage_path - def config_backtest_time(self, start_time:int, end_time:int): - ''' + def config_backtest_time(self, start_time: int, end_time: int): + """ 配置回测时间,可多次调用配置多个回测时间区间\n @start_time 开始时间,精确到分钟,格式如201909100930\n @end_time 结束时间,精确到分钟,格式如201909100930 - ''' + """ if "time_ranges" not in self.env_params: self.env_params["time_ranges"] = [] - self.env_params["time_ranges"].append([start_time,end_time]) + self.env_params["time_ranges"].append([start_time, end_time]) - def __gen_tasks__(self, markerfile:str = "strategies.json", order_by_field:str=""): - ''' + def __gen_tasks__(self, markerfile: str = "strategies.json", order_by_field: str = ""): + """ 生成回测任务 - ''' + """ param_names = self.mutable_params.keys() if order_by_field != "" and order_by_field in param_names: param_names = [order_by_field] + param_names.remove(order_by_field) @@ -377,7 +389,7 @@ def __gen_tasks__(self, markerfile:str = "strategies.json", order_by_field:str=" param_values[name] = values total_groups *= len(values) - #再生成最终每一组的参数dict + # 再生成最终每一组的参数dict param_groups = list() stra_names = dict() time_ranges = self.env_params["time_ranges"] @@ -386,27 +398,27 @@ def __gen_tasks__(self, markerfile:str = "strategies.json", order_by_field:str=" end_time = time_range[1] for i in range(total_groups): k = i - thisGrp = self.fixed_params.copy() #复制固定参数 + thisGrp = self.fixed_params.copy() # 复制固定参数 endix = '' for name in param_names: cnt = len(param_values[name]) - curVal = param_values[name][k%cnt] + curVal = param_values[name][k % cnt] tname = type(curVal) if tname.__name__ == "list": - val_str = '' + val_str = '' for item in curVal: val_str += str(item) val_str += "_" val_str = val_str[:-1] thisGrp[name] = curVal - endix += name + endix += name endix += "_" endix += val_str endix += "_" else: thisGrp[name] = curVal - endix += name + endix += name endix += "_" endix += str(curVal) endix += "_" @@ -420,23 +432,24 @@ def __gen_tasks__(self, markerfile:str = "strategies.json", order_by_field:str=" thisGrp["end_time"] = end_time stra_names[straName] = thisGrp param_groups.append(thisGrp) - + # 将每一组参数和对应的策略ID落地到文件中,方便后续的分析 f = open(markerfile, "w") f.write(json.dumps(obj=stra_names, sort_keys=True, indent=4)) f.close() return param_groups - def go(self, order_by_field:str = "", out_marker_file:str = "strategies.json", out_summary_file:str = "total_summary.csv", capital = 5000000, rf = 0, period = 240): - ''' + def go(self, order_by_field: str = "", out_marker_file: str = "strategies.json", + out_summary_file: str = "total_summary.csv", capital=5000000, rf=0, period=240): + """ 启动优化器\n @order_by_field 参数排序字段 @markerfile 标记文件名,回测完成以后分析会用到 - ''' + """ self.tasks = self.__gen_tasks__(out_marker_file, order_by_field) if self.notifier is not None: self.notifier.on_start(len(self.tasks)) - + stime = datetime.datetime.now() self.counter = multiprocessing.Manager().Value(ctypes.c_int, 0) pool = [] @@ -446,7 +459,7 @@ def go(self, order_by_field:str = "", out_marker_file:str = "strategies.json", o if gpSize == 0: gpSize = 1 work_id = 0 - max_cnt = min(self.worker_num,total_size) + max_cnt = min(self.worker_num, total_size) fromIdx = 0 for i in range(max_cnt): work_id = i + 1 @@ -458,15 +471,16 @@ def go(self, order_by_field:str = "", out_marker_file:str = "strategies.json", o params = self.tasks[fromIdx: fromIdx + thisCnt] p = multiprocessing.Process( - target=start_task_group, - args=(self.env_params, work_name, params, self.counter, capital, rf, period, self.strategy_type, self.cpp_stra_module, self.cpp_stra_type), + target=start_task_group, + args=(self.env_params, work_name, params, self.counter, capital, rf, period, self.strategy_type, + self.cpp_stra_module, self.cpp_stra_type), name=work_name) p.start() print(f"{work_name} 开始工作") pool.append(p) fromIdx += thisCnt - + alive_cnt = len(pool) while alive_cnt > 0: for task in pool: @@ -477,10 +491,11 @@ def go(self, order_by_field:str = "", out_marker_file:str = "strategies.json", o if self.notifier is not None: elapse = datetime.datetime.now() - stime - self.notifier.on_state(total_size, self.counter.value, self.counter.value*100/total_size, int(elapse.total_seconds()*1000)) + self.notifier.on_state(total_size, self.counter.value, self.counter.value * 100 / total_size, + int(elapse.total_seconds() * 1000)) time.sleep(0.5) - #开始汇总回测结果 + # 开始汇总回测结果 f = open(out_marker_file, "r") content = f.read() f.close() @@ -492,7 +507,7 @@ def go(self, order_by_field:str = "", out_marker_file:str = "strategies.json", o if not os.path.exists(filename): print(f"{filename}不存在,请检查数据") continue - + f = open(filename, "r") content = f.read() f.close() @@ -505,10 +520,10 @@ def go(self, order_by_field:str = "", out_marker_file:str = "strategies.json", o if self.notifier is not None: elapse = datetime.datetime.now() - stime - self.notifier.on_stop(total_size, elapse.total_seconds()*1000) + self.notifier.on_stop(total_size, elapse.total_seconds() * 1000) - def analyze(self, out_marker_file:str = "strategies.json", out_summary_file:str = "total_summary.csv"): - #开始汇总回测结果 + def analyze(self, out_marker_file: str = "strategies.json", out_summary_file: str = "total_summary.csv"): + # 开始汇总回测结果 f = open(out_marker_file, "r") content = f.read() f.close() @@ -519,12 +534,12 @@ def analyze(self, out_marker_file:str = "strategies.json", out_summary_file:str params = obj_stras[straName] filename = f"./outputs_bt/{straName}/summary.json" if not os.path.exists(filename): - print("%s不存在,请检查数据" % (filename)) + print("%s不存在,请检查数据" % filename) continue - - time_range = (params["start_time"],params["end_time"]) + + time_range = (params["start_time"], params["end_time"]) self.__ayalyze_result__(straName, time_range, params) - + f = open(filename, "r") content = f.read() f.close() @@ -535,11 +550,13 @@ def analyze(self, out_marker_file:str = "strategies.json", out_summary_file:str df_summary = df_summary.drop(labels=["name"], axis='columns') df_summary.to_csv(out_summary_file) - def analyzer(self, out_marker_file:str = "strategies.json", init_capital=500000, rf=0.02, annual_trading_days=240): + def analyzer(self, out_marker_file: str = "strategies.json", init_capital=500000, rf=0.02, annual_trading_days=240): for straname in json.load(open(out_marker_file, mode='r')).keys(): try: analyst = WtBtAnalyst() - analyst.add_strategy(straname, folder=f"./outputs_bt/{straname}/", init_capital=init_capital, rf=rf, annual_trading_days=annual_trading_days) + analyst.add_strategy(straname, folder=f"./outputs_bt/{straname}/", init_capital=init_capital, rf=rf, + annual_trading_days=annual_trading_days) analyst.run() - except: - pass \ No newline at end of file + except Exception as e: + print(e) + pass diff --git a/wtpy/apps/WtHftOptimizer.py b/wtpy/apps/WtHftOptimizer.py index 5db96e16..4c2f2885 100644 --- a/wtpy/apps/WtHftOptimizer.py +++ b/wtpy/apps/WtHftOptimizer.py @@ -1,11 +1,11 @@ -''' +""" Descripttion: HFT参数寻优模块 version: Author: HeJ Date: 2022-06-22 14:03:33 LastEditors: Wesley LastEditTime: 2022-06-22 14:03:33 -''' +""" import multiprocessing import json import yaml @@ -16,26 +16,29 @@ import pandas as pd from pandas import DataFrame as df import datetime -from wtpy import WtBtEngine,EngineType +from wtpy import WtBtEngine, EngineType from wtpy.apps import WtBtAnalyst -def fmtNAN(val, defVal = 0): + +def fmtNAN(val, defVal=0): if math.isnan(val): return defVal return val + class ParamInfo: - ''' + """ 参数信息类 - ''' - def __init__(self, name:str, start_val = None, end_val = None, step_val = None, ndigits = 1, val_list:list = None): - self.name = name #参数名 - self.start_val = start_val #起始值 - self.end_val = end_val #结束值 - self.step_val = step_val #变化步长 - self.ndigits = ndigits #小数位 - self.val_list = val_list #指定参数 + """ + + def __init__(self, name: str, start_val=None, end_val=None, step_val=None, ndigits=1, val_list: list = None): + self.name = name # 参数名 + self.start_val = start_val # 起始值 + self.end_val = end_val # 结束值 + self.step_val = step_val # 变化步长 + self.ndigits = ndigits # 小数位 + self.val_list = val_list # 指定参数 def gen_array(self): if self.val_list is not None: @@ -54,17 +57,19 @@ def gen_array(self): values.append(round(curVal, self.ndigits)) return values + class WtHftOptimizer: - ''' + """ 参数优化器\n 主要用于做策略参数优化的 - ''' - def __init__(self, worker_num:int = 8): - ''' + """ + + def __init__(self, worker_num: int = 8): + """ 构造函数\n @worker_num 工作进程个数,默认为8,可以根据CPU核心数设置 - ''' + """ self.worker_num = worker_num self.running_worker = 0 self.mutable_params = dict() @@ -74,8 +79,8 @@ def __init__(self, worker_num:int = 8): self.cpp_stra_module = None return - def add_mutable_param(self, name:str, start_val, end_val, step_val, ndigits = 1): - ''' + def add_mutable_param(self, name: str, start_val, end_val, step_val, ndigits=1): + """ 添加可变参数\n @name 参数名\n @@ -83,83 +88,85 @@ def add_mutable_param(self, name:str, start_val, end_val, step_val, ndigits = 1) @end_val 结束值\n @step_val 步长\n @ndigits 小数位 - ''' - self.mutable_params[name] = ParamInfo(name=name, start_val=start_val, end_val=end_val, step_val=step_val, ndigits=ndigits) + """ + self.mutable_params[name] = ParamInfo(name=name, start_val=start_val, end_val=end_val, step_val=step_val, + ndigits=ndigits) - def add_listed_param(self, name:str, val_list:list): - ''' + def add_listed_param(self, name: str, val_list: list): + """ 添加限定范围的可变参数\n @name 参数名\n @val_list 参数值列表 - ''' + """ self.mutable_params[name] = ParamInfo(name=name, val_list=val_list) - def add_fixed_param(self, name:str, val): - ''' + def add_fixed_param(self, name: str, val): + """ 添加固定参数\n @name 参数名\n @val 值\n - ''' + """ self.fixed_params[name] = val return - - def set_strategy(self, typeName:type, name_prefix:str): - ''' + + def set_strategy(self, typeName: type, name_prefix: str): + """ 设置策略\n @typeName 策略类名\n @name_prefix 命名前缀,用于自动命名用,一般为格式为"前缀_参数1名_参数1值_参数2名_参数2值" - ''' + """ self.strategy_type = typeName self.name_prefix = name_prefix return - def set_cpp_strategy(self, module:str, type_name:type, name_prefix:str): - ''' + def set_cpp_strategy(self, module: str, type_name: type, name_prefix: str): + """ 设置CPP策略\n @module 模块文件\n @typeName 策略类名\n @name_prefix 命名前缀,用于自动命名用,一般为格式为"前缀_参数1名_参数1值_参数2名_参数2值" - ''' + """ self.cpp_stra_module = module self.cpp_stra_type = type_name self.name_prefix = name_prefix return - def config_backtest_env(self, deps_dir:str, cfgfile:str="configbt.yaml", storage_type:str="csv", storage_path:str = None, storage:dict = None): - ''' + def config_backtest_env(self, deps_dir: str, cfgfile: str = "configbt.yaml", storage_type: str = "csv", + storage_path: str = None, storage: dict = None): + """ 配置回测环境\n @deps_dir 依赖文件目录\n @cfgfile 配置文件名\n @storage_type 存储类型,csv/bin等\n @storage_path 存储路径 - ''' + """ self.env_params["deps_dir"] = deps_dir self.env_params["cfgfile"] = cfgfile self.env_params["storage_type"] = storage_type self.env_params["storage"] = storage self.env_params["storage_path"] = storage_path - def config_backtest_time(self, start_time:int, end_time:int): - ''' + def config_backtest_time(self, start_time: int, end_time: int): + """ 配置回测时间,可多次调用配置多个回测时间区间\n @start_time 开始时间,精确到分钟,格式如201909100930\n @end_time 结束时间,精确到分钟,格式如201909100930 - ''' + """ if "time_ranges" not in self.env_params: self.env_params["time_ranges"] = [] - self.env_params["time_ranges"].append([start_time,end_time]) + self.env_params["time_ranges"].append([start_time, end_time]) - def __gen_tasks__(self, markerfile:str = "strategies.json"): - ''' + def __gen_tasks__(self, markerfile: str = "strategies.json"): + """ 生成回测任务 - ''' + """ param_names = self.mutable_params.keys() param_values = dict() # 先生成各个参数的变量数组 @@ -171,7 +178,7 @@ def __gen_tasks__(self, markerfile:str = "strategies.json"): param_values[name] = values total_groups *= len(values) - #再生成最终每一组的参数dict + # 再生成最终每一组的参数dict param_groups = list() stra_names = dict() time_ranges = self.env_params["time_ranges"] @@ -180,27 +187,27 @@ def __gen_tasks__(self, markerfile:str = "strategies.json"): end_time = time_range[1] for i in range(total_groups): k = i - thisGrp = self.fixed_params.copy() #复制固定参数 + thisGrp = self.fixed_params.copy() # 复制固定参数 endix = '' for name in param_names: cnt = len(param_values[name]) - curVal = param_values[name][k%cnt] + curVal = param_values[name][k % cnt] tname = type(curVal) if tname.__name__ == "list": - val_str = '' + val_str = '' for item in curVal: val_str += str(item) val_str += "_" val_str = val_str[:-1] thisGrp[name] = curVal - endix += name + endix += name endix += "_" endix += val_str endix += "_" else: thisGrp[name] = curVal - endix += name + endix += name endix += "_" endix += str(curVal) endix += "_" @@ -214,48 +221,50 @@ def __gen_tasks__(self, markerfile:str = "strategies.json"): thisGrp["end_time"] = end_time stra_names[straName] = thisGrp param_groups.append(thisGrp) - + # 将每一组参数和对应的策略ID落地到文件中,方便后续的分析 f = open(markerfile, "w") f.write(json.dumps(obj=stra_names, sort_keys=True, indent=4)) f.close() return param_groups - def __ayalyze_result__(self, strName:str, time_range:tuple, params:dict): - folder = "./outputs_bt/%s/" % (strName) - df_closes = pd.read_csv(folder + "closes.csv",encoding="gbk") + def __ayalyze_result__(self, strName: str, time_range: tuple, params: dict): + folder = "./outputs_bt/%s/" % strName + df_closes = pd.read_csv(folder + "closes.csv", encoding="gbk") df_funds = pd.read_csv(folder + "funds.csv") - - df_closes["opentime"] = df_closes["opentime"].astype("str").apply(lambda dt: datetime.datetime.strptime(dt,"%Y%m%d%H%M%S%f")) - df_closes["closetime"] = df_closes["closetime"].astype("str").apply(lambda dt: datetime.datetime.strptime(dt,"%Y%m%d%H%M%S%f")) + + df_closes["opentime"] = df_closes["opentime"].astype("str").apply( + lambda dt: datetime.datetime.strptime(dt, "%Y%m%d%H%M%S%f")) + df_closes["closetime"] = df_closes["closetime"].astype("str").apply( + lambda dt: datetime.datetime.strptime(dt, "%Y%m%d%H%M%S%f")) df_closes["holdperiod"] = (df_closes["closetime"] - df_closes["opentime"]).astype('timedelta64[s]') - - df_wins = df_closes[df_closes["profit"]>0] - df_loses = df_closes[df_closes["profit"]<=0] + + df_wins = df_closes[df_closes["profit"] > 0] + df_loses = df_closes[df_closes["profit"] <= 0] total_win_holdperiod = df_wins["holdperiod"].sum() total_lose_holdperiod = df_loses["holdperiod"].sum() total_fee = df_funds.iloc[-1]["fee"] - totaltimes = len(df_closes) # 总交易次数 - wintimes = len(df_wins) # 盈利次数 - losetimes = len(df_loses) # 亏损次数 - winamout = df_wins["profit"].sum() #毛盈利 - loseamount = df_loses["profit"].sum() #毛亏损 - trdnetprofit = winamout + loseamount #交易净盈亏 - accnetprofit = trdnetprofit - total_fee #账户净盈亏 - winrate = wintimes / totaltimes if totaltimes>0 else 0 # 胜率 - avgprof = trdnetprofit/totaltimes if totaltimes>0 else 0 # 单次平均盈亏 - avgprof_win = winamout/wintimes if wintimes>0 else 0 # 单次盈利均值 - avgprof_lose = loseamount/losetimes if losetimes>0 else 0 # 单次亏损均值 - winloseratio = abs(avgprof_win/avgprof_lose) if avgprof_lose!=0 else "N/A" # 单次盈亏均值比 - - max_consecutive_wins = 0 # 最大连续盈利次数 - max_consecutive_loses = 0 # 最大连续亏损次数 - - avg_time_in_winner = total_win_holdperiod/wintimes if wintimes>0 else "N/A" - avg_time_in_loser = total_lose_holdperiod/losetimes if losetimes>0 else "N/A" + totaltimes = len(df_closes) # 总交易次数 + wintimes = len(df_wins) # 盈利次数 + losetimes = len(df_loses) # 亏损次数 + winamout = df_wins["profit"].sum() # 毛盈利 + loseamount = df_loses["profit"].sum() # 毛亏损 + trdnetprofit = winamout + loseamount # 交易净盈亏 + accnetprofit = trdnetprofit - total_fee # 账户净盈亏 + winrate = wintimes / totaltimes if totaltimes > 0 else 0 # 胜率 + avgprof = trdnetprofit / totaltimes if totaltimes > 0 else 0 # 单次平均盈亏 + avgprof_win = winamout / wintimes if wintimes > 0 else 0 # 单次盈利均值 + avgprof_lose = loseamount / losetimes if losetimes > 0 else 0 # 单次亏损均值 + winloseratio = abs(avgprof_win / avgprof_lose) if avgprof_lose != 0 else "N/A" # 单次盈亏均值比 + + max_consecutive_wins = 0 # 最大连续盈利次数 + max_consecutive_loses = 0 # 最大连续亏损次数 + + avg_time_in_winner = total_win_holdperiod / wintimes if wintimes > 0 else "N/A" + avg_time_in_loser = total_lose_holdperiod / losetimes if losetimes > 0 else "N/A" consecutive_wins = 0 consecutive_loses = 0 @@ -267,22 +276,23 @@ def __ayalyze_result__(self, strName:str, time_range:tuple, params:dict): else: consecutive_wins = 0 consecutive_loses += 1 - + max_consecutive_wins = max(max_consecutive_wins, consecutive_wins) max_consecutive_loses = max(max_consecutive_loses, consecutive_loses) total_fee = df_funds["fee"].sum() - yearRet = round(df_funds["closeprofit"].mean() * np.float(244),3) + yearRet = round(df_funds["closeprofit"].mean() * np.float(244), 3) vol = (df_funds["closeprofit"].std() * np.sqrt(244)).round(3) sr = yearRet / vol sr = sr.round(3) - df_funds["dd"] = (df_funds['dynbalance'].rolling(len(df_funds)+1,min_periods=1).max()- df_funds['dynbalance']) + df_funds["dd"] = ( + df_funds['dynbalance'].rolling(len(df_funds) + 1, min_periods=1).max() - df_funds['dynbalance']) maxdd = df_funds["dd"].max() - maxdd_t = df_funds[df_funds['dd']==maxdd]["date"].values[0] + maxdd_t = df_funds[df_funds['dd'] == maxdd]["date"].values[0] maxdd = maxdd + 0.1 - cr = yearRet/ maxdd - cr = np.round(cr,3) - + cr = yearRet / maxdd + cr = np.round(cr, 3) + summary = params.copy() summary["开始时间"] = time_range[0] summary["结束时间"] = time_range[1] @@ -293,7 +303,7 @@ def __ayalyze_result__(self, strName:str, time_range:tuple, params:dict): summary["毛盈利"] = float(winamout) summary["毛亏损"] = float(loseamount) summary["交易净盈亏"] = float(trdnetprofit) - summary["胜率"] = winrate*100 + summary["胜率"] = winrate * 100 summary["单次平均盈亏"] = avgprof summary["单次盈利均值"] = avgprof_win summary["单次亏损均值"] = avgprof_lose @@ -302,7 +312,7 @@ def __ayalyze_result__(self, strName:str, time_range:tuple, params:dict): summary["最大连续亏损次数"] = max_consecutive_loses summary["平均盈利周期(s)"] = avg_time_in_winner summary["平均亏损周期(s)"] = avg_time_in_loser - summary["平均账户收益率"] = accnetprofit/totaltimes + summary["平均账户收益率"] = accnetprofit / totaltimes summary["年均收入"] = yearRet summary["波动率"] = vol summary["夏普"] = sr @@ -310,20 +320,20 @@ def __ayalyze_result__(self, strName:str, time_range:tuple, params:dict): summary["最大回撤发生时间"] = str(maxdd_t) summary["卡玛"] = cr - f = open(folder+"summary.json", mode="w") + f = open(folder + "summary.json", mode="w") f.write(json.dumps(obj=summary, indent=4)) f.close() return - def __execute_task__(self, params:dict): - ''' + def __execute_task__(self, params: dict): + """ 执行单个回测任务\n @params kv形式的参数 - ''' + """ name = params["name"] - + is_yaml = True fname = "logcfg_tpl.yaml" if not os.path.exists(fname): @@ -331,7 +341,7 @@ def __execute_task__(self, params:dict): fname = "logcfg_tpl.json" f = open(fname, "r") - content =f.read() + content = f.read() f.close() content = content.replace("$NAME$", name) if is_yaml: @@ -339,14 +349,15 @@ def __execute_task__(self, params:dict): engine = WtBtEngine(eType=EngineType.ET_HFT, logCfg=content, isFile=False) engine.init(self.env_params["deps_dir"], self.env_params["cfgfile"]) engine.configBacktest(params["start_time"], params["end_time"]) - engine.configBTStorage(mode=self.env_params["storage_type"], path=self.env_params["storage_path"], storage=self.env_params["storage"]) + engine.configBTStorage(mode=self.env_params["storage_type"], path=self.env_params["storage_path"], + storage=self.env_params["storage"]) time_range = (params["start_time"], params["end_time"]) # 去掉多余的参数 params.pop("start_time") params.pop("end_time") - + if self.cpp_stra_module is not None: params.pop("name") engine.setExternalHftStrategy(name, self.cpp_stra_module, self.cpp_stra_type, params) @@ -360,26 +371,26 @@ def __execute_task__(self, params:dict): self.__ayalyze_result__(name, time_range, params) - def __start_task__(self, params:dict): - ''' + def __start_task__(self, params: dict): + """ 启动单个回测任务\n 这里用线程启动子进程的目的是为了可以控制总的工作进程个数\n 可以在线程中join等待子进程结束,再更新running_worker变量\n 如果在__execute_task__中修改running_worker,因为在不同进程中,数据并不同步\n @params kv形式的参数 - ''' + """ p = multiprocessing.Process(target=self.__execute_task__, args=(params,)) p.start() p.join() self.running_worker -= 1 - print("工作进程%d个" % (self.running_worker)) - - def __start_task_group__(self, params:list): - ''' + print("工作进程%d个" % self.running_worker) + + def __start_task_group__(self, params: list): + """ 启动多个回测任务,来回测一组参数,这里共用一个engine,因此可以避免多次io @params 参数组 - ''' + """ is_yaml = True fname = "logcfg_tpl.yaml" if not os.path.exists(fname): @@ -387,19 +398,20 @@ def __start_task_group__(self, params:list): fname = "logcfg_tpl.json" f = open(fname, "r") - content =f.read() + content = f.read() f.close() engine = WtBtEngine(eType=EngineType.ET_CTA, logCfg=content, isFile=False) # 配置类型的参数相对固定 engine.init(self.env_params["deps_dir"], self.env_params["cfgfile"]) - engine.configBTStorage(mode=self.env_params["storage_type"], path=self.env_params["storage_path"], storage=self.env_params["storage"]) + engine.configBTStorage(mode=self.env_params["storage_type"], path=self.env_params["storage_path"], + storage=self.env_params["storage"]) # 遍历参数组 for param in params: name = param["name"] param_content = content.replace("$NAME$", name) if is_yaml: param_content = json.dumps(yaml.full_load(param_content)) - + engine.configBacktest(param["start_time"], param["end_time"]) time_range = (param["start_time"], param["end_time"]) # 去掉多余的参数 @@ -416,32 +428,33 @@ def __start_task_group__(self, params:list): self.__ayalyze_result__(name, time_range, param) engine.release_backtest() - def go(self, interval:float = 0.2, out_marker_file:str = "strategies.json", out_summary_file:str = "total_summary.csv"): - ''' + def go(self, interval: float = 0.2, out_marker_file: str = "strategies.json", + out_summary_file: str = "total_summary.csv"): + """ 启动优化器\n @interval 时间间隔,单位秒 @markerfile 标记文件名,回测完成以后分析会用到 - ''' + """ self.tasks = self.__gen_tasks__(out_marker_file) pool = [] total_size = len(self.tasks) one_size = round(total_size / self.worker_num) work_id = 0 - for i in range(0,total_size,one_size): + for i in range(0, total_size, one_size): work_id += 1 - params = self.tasks[i:i+one_size] + params = self.tasks[i:i + one_size] p = multiprocessing.Process(target=self.__start_task_group__, args=(params,)) p.start() - print("工人%d开始工作" % (work_id)) + print("工人%d开始工作" % work_id) pool.append(p) - + work_id = 0 for task in pool: work_id += 1 task.join() - print("工人%d结束工作" % (work_id)) + print("工人%d结束工作" % work_id) - #开始汇总回测结果 + # 开始汇总回测结果 f = open(out_marker_file, "r") content = f.read() f.close() @@ -449,11 +462,11 @@ def go(self, interval:float = 0.2, out_marker_file:str = "strategies.json", out_ obj_stras = json.loads(content) total_summary = list() for straName in obj_stras: - filename = "./outputs_bt/%s/summary.json" % (straName) + filename = "./outputs_bt/%s/summary.json" % straName if not os.path.exists(filename): - print("%s不存在,请检查数据" % (filename)) + print("%s不存在,请检查数据" % filename) continue - + f = open(filename, "r") content = f.read() f.close() @@ -464,8 +477,8 @@ def go(self, interval:float = 0.2, out_marker_file:str = "strategies.json", out_ # df_summary = df_summary.drop(labels=["name"], axis='columns') df_summary.to_csv(out_summary_file, encoding='utf-8-sig') - def analyze(self, out_marker_file:str = "strategies.json", out_summary_file:str = "total_summary.csv"): - #开始汇总回测结果 + def analyze(self, out_marker_file: str = "strategies.json", out_summary_file: str = "total_summary.csv"): + # 开始汇总回测结果 f = open(out_marker_file, "r") content = f.read() f.close() @@ -474,14 +487,14 @@ def analyze(self, out_marker_file:str = "strategies.json", out_summary_file:str obj_stras = json.loads(content) for straName in obj_stras: params = obj_stras[straName] - filename = "./outputs_bt/%s/summary.json" % (straName) + filename = "./outputs_bt/%s/summary.json" % straName if not os.path.exists(filename): - print("%s不存在,请检查数据" % (filename)) + print("%s不存在,请检查数据" % filename) continue - - time_range = (params["start_time"],params["end_time"]) + + time_range = (params["start_time"], params["end_time"]) self.__ayalyze_result__(straName, time_range, params) - + f = open(filename, "r") content = f.read() f.close() @@ -492,13 +505,13 @@ def analyze(self, out_marker_file:str = "strategies.json", out_summary_file:str df_summary = df_summary.drop(labels=["name"], axis='columns') df_summary.to_csv(out_summary_file) - def analyzer(self, out_marker_file:str = "strategies.json", init_capital=500000, rf=0.02, annual_trading_days=240): + def analyzer(self, out_marker_file: str = "strategies.json", init_capital=500000, rf=0.02, annual_trading_days=240): for straname in json.load(open(out_marker_file, mode='r')).keys(): try: analyst = WtBtAnalyst() - analyst.add_strategy(straname, folder="./outputs_bt/%s/"%straname, init_capital=init_capital, rf=rf, annual_trading_days=annual_trading_days) + analyst.add_strategy(straname, folder="./outputs_bt/%s/" % straname, init_capital=init_capital, rf=rf, + annual_trading_days=annual_trading_days) analyst.run() - except: + except Exception as e: + print(e) pass - - diff --git a/wtpy/apps/WtHotPicker.py b/wtpy/apps/WtHotPicker.py index cda6baf7..031a2eca 100644 --- a/wtpy/apps/WtHotPicker.py +++ b/wtpy/apps/WtHotPicker.py @@ -1,4 +1,3 @@ - import datetime import time import json @@ -12,10 +11,11 @@ from pyquery import PyQuery as pq import re + class DayData: - ''' + """ 每日行情数据 - ''' + """ def __init__(self): self.pid = '' @@ -25,15 +25,16 @@ def __init__(self): self.volume = 0 # 成交量(手) self.hold = 0 # 空盘量(总持?持仓量) + def extractPID(code): - for idx in range(0, len(code)): c = code[idx] - if '0' <= c and c <= '9': + if '0' <= c <= '9': break - + return code[:idx] + def readFileContent(filename): if not os.path.exists(filename): return "" @@ -42,29 +43,31 @@ def readFileContent(filename): f.close() return content -def cmp_alg_01(left:DayData, right:DayData): + +def cmp_alg_01(left: DayData, right: DayData): if left.month > right.month: - if left.hold > right.hold and left.volume > right.volume/3: + if left.hold > right.hold and left.volume > right.volume / 3: return 1 else: return -1 else: - if left.hold <= right.hold or left.volume <= right.volume/3: + if left.hold <= right.hold or left.volume <= right.volume / 3: return -1 else: return 1 -def countFridays(curDate:datetime.datetime): - ''' + +def countFridays(curDate: datetime.datetime): + """ 计算截止到当周的周五的天数 - ''' + """ wd = curDate.weekday() checkDate = datetime.datetime(year=curDate.year, month=curDate.month, day=1) count = 0 while checkDate < curDate: if checkDate.weekday() == 4: count += 1 - + checkDate += datetime.timedelta(days=1) if wd < 4: @@ -72,6 +75,7 @@ def countFridays(curDate:datetime.datetime): return count + def httpGet(url, encoding='utf-8'): request = urllib.request.Request(url) request.add_header('Accept-encoding', 'gzip') @@ -86,9 +90,11 @@ def httpGet(url, encoding='utf-8'): f = gzip.GzipFile(fileobj=cs) return f.read().decode(encoding) - except: + except Exception as e: + print(e) return "" + def httpPost(url, datas, encoding='utf-8'): headers = { 'User-Agent': 'Mozilla/4.0 (compatible; MSIE 5.5; Windows NT)', @@ -105,55 +111,60 @@ def httpPost(url, datas, encoding='utf-8'): f = gzip.GzipFile(fileobj=cs) return f.read().decode(encoding) - except: + except Exception as e: + print(e) return "" + class WtCacheMon: - ''' + """ 缓存管理器基类 - ''' + """ + def __init__(self): self.day_cache = dict() - def get_cache(self, exchg, curDT:datetime.datetime): + def get_cache(self, exchg, curDT: datetime.datetime): pass + class WtCacheMonExchg(WtCacheMon): - ''' + """ 交易所行情缓存器 通过到交易所官网上拉取当日的行情快照,缓存当日行情数据 - ''' + """ @staticmethod - def getCffexData(curDT:datetime.datetime) -> dict: - ''' + def getCffexData(curDT: datetime.datetime) -> dict: + """ 读取CFFEX指定日期的行情快照 @curDT 指定的日期 - ''' + """ dtStr = curDT.strftime('%Y%m%d') dtNum = int(dtStr) - path = "http://www.cffex.com.cn/fzjy/mrhq/%d/%02d/index.xml" % (dtNum/100, dtNum % 100) + path = "http://www.cffex.com.cn/fzjy/mrhq/%d/%02d/index.xml" % (dtNum / 100, dtNum % 100) content = httpGet(path) if len(content) == 0: return None try: dom = xml.dom.minidom.parseString(content) - except: - logging.info("[CFFEX]%s无数据,跳过" % (dtStr)) + except Exception as e: + print(e) + logging.info("[CFFEX]%s无数据,跳过" % dtStr) return None root = dom.documentElement - + items = {} days = root.getElementsByTagName("dailydata") for day in days: pid = day.getElementsByTagName( "productid")[0].firstChild.data.strip() - if pid not in ["IF","IH","IC","T",'TF','TS']: + if pid not in ["IF", "IH", "IC", "T", 'TF', 'TS']: continue item = DayData() @@ -173,18 +184,18 @@ def getCffexData(curDT:datetime.datetime) -> dict: return items @staticmethod - def getShfeData(curDT:datetime.datetime) -> dict: - ''' + def getShfeData(curDT: datetime.datetime) -> dict: + """ 读取SHFE指定日期的行情快照 @curDT 指定的日期 - ''' + """ dtStr = curDT.strftime('%Y%m%d') - content = httpGet("http://www.shfe.com.cn/data/dailydata/kx/kx%s.dat" % (dtStr)) + content = httpGet("http://www.shfe.com.cn/data/dailydata/kx/kx%s.dat" % dtStr) if len(content) == 0: return None - + items = {} root = json.loads(content) for day in root['o_curinstrument']: @@ -211,12 +222,12 @@ def getShfeData(curDT:datetime.datetime) -> dict: return items @staticmethod - def getCzceData(curDT:datetime.datetime) -> dict: - ''' + def getCzceData(curDT: datetime.datetime) -> dict: + """ 读取CZCE指定日期的行情快照 @curDT 指定的日期 - ''' + """ dtStr = curDT.strftime('%Y%m%d') url = 'http://www.czce.com.cn/cn/DFSStaticFiles/Future/%s/%s/FutureDataDaily.htm' % (dtStr[0:4], dtStr) @@ -241,7 +252,7 @@ def getCzceData(curDT:datetime.datetime) -> dict: # tr行数 trcount = len(lis) # 遍历行 - for tr in range(0, trcount-1): + for tr in range(0, trcount - 1): item = DayData() tdlis = doc(lis[tr])('td') @@ -250,19 +261,19 @@ def getCzceData(curDT:datetime.datetime) -> dict: if len(ay) == 0: continue - item.pid = ay[0] + item.pid = ay[0] close = doc(tdlis[5]).text() if close != '': - item.close = float(close.replace(",","")) + item.close = float(close.replace(",", "")) volume = doc(tdlis[9]).text() if volume != '': - item.volume = int(volume.replace(",","")) + item.volume = int(volume.replace(",", "")) hold = doc(tdlis[10]).text() if hold != '': - item.hold = int(hold.replace(",","")) + item.hold = int(hold.replace(",", "")) item.month = item.code[len(item.pid):] # 这个逻辑是有点问题的,但是没好的办法 @@ -278,12 +289,12 @@ def getCzceData(curDT:datetime.datetime) -> dict: return dataitems @staticmethod - def getDceData(curDT:datetime.datetime) -> dict: - ''' + def getDceData(curDT: datetime.datetime) -> dict: + """ 读取DCE指定日期的行情快照 @curDT 指定的日期 - ''' + """ pname_map = { "聚乙烯": "l", @@ -302,21 +313,17 @@ def getDceData(curDT:datetime.datetime) -> dict: "聚丙烯": "pp", "聚氯乙烯": "v", "豆油": "y", - "乙二醇":"eg", - "粳米":"rr", - "苯乙烯":"eb", - "液化石油气":"pg", - "生猪":"lh" + "乙二醇": "eg", + "粳米": "rr", + "苯乙烯": "eb", + "液化石油气": "pg", + "生猪": "lh" } url = 'http://www.dce.com.cn/publicweb/quotesdata/dayQuotesCh.html' try: - data = {} - data['dayQuotes.variety'] = 'all' - data['dayQuotes.trade_type'] = 0 - data['year'] = curDT.year - data['month'] = curDT.month - 1 - data['day'] = curDT.day + data = {'dayQuotes.variety': 'all', 'dayQuotes.trade_type': 0, 'year': curDT.year, 'month': curDT.month - 1, + 'day': curDT.day} html = httpPost(url, data) except urllib.error.HTTPError as httperror: print(httperror) @@ -359,14 +366,14 @@ def getDceData(curDT:datetime.datetime) -> dict: return dataitems @staticmethod - def getIneData(curDT:datetime.datetime) -> dict: - ''' + def getIneData(curDT: datetime.datetime) -> dict: + """ 读取INE指定日期的行情快照 @curDT 指定的日期 - ''' + """ dtStr = curDT.strftime('%Y%m%d') - content = httpGet("http://www.ine.cn/data/dailydata/kx/kx%s.dat" % (dtStr)) + content = httpGet("http://www.ine.cn/data/dailydata/kx/kx%s.dat" % dtStr) if len(content) == 0: return None @@ -382,19 +389,18 @@ def getIneData(curDT:datetime.datetime) -> dict: item.code = pid + dm item.hold = int(day['OPENINTEREST']) item.close = float(day['CLOSEPRICE']) - item.volume = int(day['VOLUME']) if day['VOLUME']!='' else 0 + item.volume = int(day['VOLUME']) if day['VOLUME'] != '' else 0 item.month = item.code[len(item.pid):] items[item.code] = item return items - - def cache_by_date(self, exchg:str, curDT:datetime.datetime): - ''' + def cache_by_date(self, exchg: str, curDT: datetime.datetime): + """ 缓存指定日期指定交易所的行数据 @exchg 交易所代码 @curDT 指定日期 - ''' + """ dtStr = curDT.strftime('%Y%m%d') if dtStr not in self.day_cache: @@ -403,24 +409,24 @@ def cache_by_date(self, exchg:str, curDT:datetime.datetime): cacheItem = self.day_cache[dtStr] if exchg == 'CFFEX': cacheItem[exchg] = WtCacheMonExchg.getCffexData(curDT) - elif exchg == 'SHFE': + elif exchg == 'SHFE': cacheItem[exchg] = WtCacheMonExchg.getShfeData(curDT) - elif exchg == 'DCE': + elif exchg == 'DCE': cacheItem[exchg] = WtCacheMonExchg.getDceData(curDT) - elif exchg == 'CZCE': + elif exchg == 'CZCE': cacheItem[exchg] = WtCacheMonExchg.getCzceData(curDT) - elif exchg == 'INE': + elif exchg == 'INE': cacheItem[exchg] = WtCacheMonExchg.getIneData(curDT) else: raise Exception("未知交易所代码" + exchg) - def get_cache(self, exchg:str, curDT:datetime.datetime): - ''' + def get_cache(self, exchg: str, curDT: datetime.datetime): + """ 获取指定日期的某个交易所合约的快照数据 @exchg 交易所代码 @curDT 指定日期 - ''' + """ dtStr = curDT.strftime('%Y%m%d') if dtStr not in self.day_cache or exchg not in self.day_cache[dtStr]: self.cache_by_date(exchg, curDT) @@ -432,23 +438,24 @@ def get_cache(self, exchg:str, curDT:datetime.datetime): return None return self.day_cache[dtStr][exchg] + class WtCacheMonSS(WtCacheMon): - ''' + """ 快照缓存管理器 通过读取wtpy的datakit当日生成的快照文件,缓存当日行情数据 一般目录为"数据存储目录/his/snapshots/xxxxxxx.csv" - ''' + """ - def __init__(self, snapshot_path:str): + def __init__(self, snapshot_path: str): WtCacheMon.__init__(self) self.snapshot_path = snapshot_path - def cache_snapshot(self, curDT:datetime): - ''' + def cache_snapshot(self, curDT: datetime): + """ 缓存指定日期的快照数据 @curDT 指定的日期 - ''' + """ dtStr = curDT.strftime('%Y%m%d') filename = "%s%s.csv" % (self.snapshot_path, dtStr) @@ -464,7 +471,7 @@ def cache_snapshot(self, curDT:datetime): if len(line) == 0: break items = line.split(",") - + exchg = items[1] if exchg not in cacheItem: cacheItem[exchg] = dict() @@ -480,19 +487,19 @@ def cache_snapshot(self, curDT:datetime): day.hold = float(items[10]) day.month = day.code[len(day.pid):] if len(day.month) == 3: - if day.month[0] >= '0' and day.month[0] <= '5': + if '0' <= day.month[0] <= '5': day.month = "2" + day.month else: day.month = "1" + day.month cacheItem[exchg][day.code] = day - def get_cache(self, exchg, curDT:datetime): - ''' + def get_cache(self, exchg, curDT: datetime): + """ 获取指定日期的某个交易所合约的快照数据 @exchg 交易所代码 @curDT 指定日期 - ''' + """ dtStr = curDT.strftime('%Y%m%d') if dtStr not in self.day_cache: @@ -505,34 +512,38 @@ def get_cache(self, exchg, curDT:datetime): return None return self.day_cache[dtStr][exchg] + class WtMailNotifier: - ''' + """ 邮件通知器 - ''' - def __init__(self, user:str, pwd:str, sender:str=None, host:str="smtp.exmail.qq.com", port=465, isSSL:bool = True): + """ + + def __init__(self, user: str, pwd: str, sender: str = None, host: str = "smtp.exmail.qq.com", port=465, + isSSL: bool = True): self.user = user self.pwd = pwd - self.sender = sender if sender is not None else "WtHotNotifier<%s>" % (user) + self.sender = sender if sender is not None else "WtHotNotifier<%s>" % user self.receivers = list() self.mail_host = host self.mail_port = port self.mail_ssl = isSSL - def add_receiver(self, name:str, addr:str): - ''' + def add_receiver(self, name: str, addr: str): + """ 添加收件人 @name 收件人姓名 @addr 收件人邮箱地址 - ''' + """ self.receivers.append({ - "name":name, - "addr":addr + "name": name, + "addr": addr }) - def notify(self, hot_changes:dict, sec_changes:dict, nextDT:datetime.datetime, hotFile:str, hotMap:str, secFile:str, secMap:str): - ''' + def notify(self, hot_changes: dict, sec_changes: dict, nextDT: datetime.datetime, hotFile: str, hotMap: str, + secFile: str, secMap: str): + """ 通知主力切换事件 @hot_changes 当日主力切换的规则列表 @@ -540,9 +551,9 @@ def notify(self, hot_changes:dict, sec_changes:dict, nextDT:datetime.datetime, h @nextDT 生效日期 @hotFile 主力规则文件 @hotMap 主力映射文件 - ''' + """ dtStr = nextDT.strftime('%Y.%m.%d') - + import smtplib from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart @@ -556,41 +567,43 @@ def notify(self, hot_changes:dict, sec_changes:dict, nextDT:datetime.datetime, h for exchg in hot_changes: for pid in hot_changes[exchg]: item = hot_changes[exchg][pid][-1] - content += '品种%s.%s的主力合约已切换,下个交易日(%s)生效, %s -> %s\n' % (exchg, pid, dtStr, item["from"], item["to"]) + content += '品种%s.%s的主力合约已切换,下个交易日(%s)生效, %s -> %s\n' % ( + exchg, pid, dtStr, item["from"], item["to"]) content += '\n' for exchg in sec_changes: for pid in sec_changes[exchg]: item = sec_changes[exchg][pid][-1] - content += '品种%s.%s的次主力合约已切换,下个交易日(%s)生效, %s -> %s\n' % (exchg, pid, dtStr, item["from"], item["to"]) + content += '品种%s.%s的次主力合约已切换,下个交易日(%s)生效, %s -> %s\n' % ( + exchg, pid, dtStr, item["from"], item["to"]) msg_mp = MIMEMultipart() msg_mp['From'] = sender # 发送者 - - subject = '主力合约换月邮件<%s>' % (dtStr) + + subject = '主力合约换月邮件<%s>' % dtStr msg_mp['Subject'] = Header(subject, 'utf-8') content = MIMEText(content, 'plain', 'utf-8') msg_mp.attach(content) - xlspart = MIMEApplication(open(hotFile,'rb').read()) + xlspart = MIMEApplication(open(hotFile, 'rb').read()) xlspart["Content-Type"] = 'application/octet-stream' - xlspart.add_header('Content-Disposition','attachment', filename=os.path.basename(hotFile)) + xlspart.add_header('Content-Disposition', 'attachment', filename=os.path.basename(hotFile)) msg_mp.attach(xlspart) - xlspart = MIMEApplication(open(hotMap,'rb').read()) + xlspart = MIMEApplication(open(hotMap, 'rb').read()) xlspart["Content-Type"] = 'application/octet-stream' - xlspart.add_header('Content-Disposition','attachment', filename=os.path.basename(hotMap)) + xlspart.add_header('Content-Disposition', 'attachment', filename=os.path.basename(hotMap)) msg_mp.attach(xlspart) - xlspart = MIMEApplication(open(secFile,'rb').read()) + xlspart = MIMEApplication(open(secFile, 'rb').read()) xlspart["Content-Type"] = 'application/octet-stream' - xlspart.add_header('Content-Disposition','attachment', filename=os.path.basename(secFile)) + xlspart.add_header('Content-Disposition', 'attachment', filename=os.path.basename(secFile)) msg_mp.attach(xlspart) - xlspart = MIMEApplication(open(secMap,'rb').read()) + xlspart = MIMEApplication(open(secMap, 'rb').read()) xlspart["Content-Type"] = 'application/octet-stream' - xlspart.add_header('Content-Disposition','attachment', filename=os.path.basename(secMap)) + xlspart.add_header('Content-Disposition', 'attachment', filename=os.path.basename(secMap)) msg_mp.attach(xlspart) if self.mail_ssl: @@ -600,56 +613,58 @@ def notify(self, hot_changes:dict, sec_changes:dict, nextDT:datetime.datetime, h try: smtpObj.ehlo() - smtpObj.login(self.user, self.pwd) + smtpObj.login(self.user, self.pwd) logging.info("%s 登录成功 %s:%d", self.user, self.mail_host, self.mail_port) except smtplib.SMTPException as ex: logging.error("邮箱初始化失败:{}".format(ex)) for item in receivers: to = "%s<%s>" % (item["name"], item["addr"]) - msg_mp['To'] = Header(to, 'utf-8') # 接收者 + msg_mp['To'] = Header(to, 'utf-8') # 接收者 try: smtpObj.sendmail(sender, item["addr"], msg_mp.as_string()) logging.info("邮件发送失败,收件人: %s", to) except smtplib.SMTPException as ex: logging.error("邮件发送失败,收件人:{}, {}".format(to, ex)) + class WtHotPicker: - ''' + """ 主力选择器 - ''' - def __init__(self, markerFile:str = "./marker.json", hotFile:str = "../Common/hots.json", secFile:str = None): + """ + + def __init__(self, markerFile: str = "./marker.json", hotFile: str = "../Common/hots.json", secFile: str = None): self.marker_file = markerFile self.hot_file = hotFile self.sec_file = secFile - self.mail_notifier:WtMailNotifier = None - self.cache_monitor:WtCacheMon = None + self.mail_notifier: WtMailNotifier = None + self.cache_monitor: WtCacheMon = None self.current_hots = None self.current_secs = None - def set_cacher(self, cacher:WtCacheMon): - ''' + def set_cacher(self, cacher: WtCacheMon): + """ 设置日行情缓存器 - ''' + """ self.cache_monitor = cacher - - def set_mail_notifier(self, notifier:WtMailNotifier): - ''' + + def set_mail_notifier(self, notifier: WtMailNotifier): + """ 设置邮件通知器 - ''' + """ self.mail_notifier = notifier - def pick_exchg_hots(self, exchg:str, beginDT:datetime.datetime, endDT:datetime.datetime, alg:int = 0): - ''' + def pick_exchg_hots(self, exchg: str, beginDT: datetime.datetime, endDT: datetime.datetime, alg: int = 0): + """ 确定指定市场的主力合约 @exchg 交易所代码 @beginDT 开始日期 @endDT 截止日期 @alg 切换规则算法,0-除中金所外,按成交量确定,1-中金所,按照成交量和总持共同确定 - ''' + """ cacheMon = self.cache_monitor current_hots = self.current_hots @@ -678,9 +693,9 @@ def pick_exchg_hots(self, exchg:str, beginDT:datetime.datetime, endDT:datetime.d wd = curDT.weekday() fri_cnt = countFridays(curDT) cur_month = curDT.strftime('%Y%m')[2:] - next_month = int(cur_month)+1 + next_month = int(cur_month) + 1 if next_month % 100 == 13: - next_month = str(int(cur_month[:2])+1)+"01" + next_month = str(int(cur_month[:2]) + 1) + "01" else: next_month = str(next_month) @@ -697,29 +712,29 @@ def pick_exchg_hots(self, exchg:str, beginDT:datetime.datetime, endDT:datetime.d for pid in items_by_pid: ay = items_by_pid[pid] if alg == 1: - #ay.sort(key=functools.cmp_to_key(cmp_alg_01)) #按总持排序 - ay.sort(key=lambda x : x.volume) #按成交量 + # ay.sort(key=functools.cmp_to_key(cmp_alg_01)) #按总持排序 + ay.sort(key=lambda x: x.volume) # 按成交量 elif alg == 0: - ay.sort(key=lambda x : x.hold) #按总持 - + ay.sort(key=lambda x: x.hold) # 按总持 + if len(ay) > 1: hot = ay[-1] sec = ay[-2] - #中金所算法,如果是当月第三个周三,并且主力合约月份小于次主力合约月份, - #说明没有根据数据自动换月,强制进行换月 - if alg == 1 and wd == 2 and fri_cnt == 3 and hot.month==cur_month: + # 中金所算法,如果是当月第三个周三,并且主力合约月份小于次主力合约月份, + # 说明没有根据数据自动换月,强制进行换月 + if alg == 1 and wd == 2 and fri_cnt == 3 and hot.month == cur_month: for item in ay: if item.month == next_month: hot = item break - #如果主力合约月份大于等于次主力合约,则次主力递延一位 - if hot.month >= sec.month and len(ay)>=3: + # 如果主力合约月份大于等于次主力合约,则次主力递延一位 + if hot.month >= sec.month and len(ay) >= 3: sec = ay[-3] - - for i in range(-2,-len(ay),-1): + + for i in range(-2, -len(ay), -1): sec = ay[i] - #次主力合约月份大于等于次主力合约才可以 + # 次主力合约月份大于等于次主力合约才可以 if hot.month < sec.month: break if sec is not None and hot.month < sec.month: @@ -728,20 +743,17 @@ def pick_exchg_hots(self, exchg:str, beginDT:datetime.datetime, endDT:datetime.d else: # 如果这一天只有一个合约的信息,就没办法实现同时跟换主次月,跳过这一天,否则会出现主力换月,次主力没有换月的情况,导致某一天的主力次主力是同一个合约 continue - + # 生成换月表 for key in hots.keys(): nextDT = curDT + datetime.timedelta(days=1) if key not in lastHots: - item = {} - item["date"] = int(curDT.strftime('%Y%m%d')) - item["from"] = "" - item["to"] = hots[key] - item["oldclose"] = 0.0 - item["newclose"] = items[hots[key]].close + item = {"date": int(curDT.strftime('%Y%m%d')), "from": "", "to": hots[key], "oldclose": 0.0, + "newclose": items[hots[key]].close} hot_switches[key] = [item] lastHots[key] = hots[key] - logging.info("[%s]品种%s主力确认, 确认日期: %s, %s", exchg,key, nextDT.strftime('%Y%m%d'), hots[key]) + logging.info("[%s]品种%s主力确认, 确认日期: %s, %s", exchg, key, nextDT.strftime('%Y%m%d'), + hots[key]) else: oldcode = lastHots[key] newcode = hots[key] @@ -750,10 +762,7 @@ def pick_exchg_hots(self, exchg:str, beginDT:datetime.datetime, endDT:datetime.d oldItem = items[oldcode] newItem = items[newcode] if oldItem is None or newItem.month > oldItem.month: - item = {} - item["date"] = int(nextDT.strftime('%Y%m%d')) - item["from"] = oldcode - item["to"] = newcode + item = {"date": int(nextDT.strftime('%Y%m%d')), "from": oldcode, "to": newcode} if oldcode in items: item["oldclose"] = items[oldcode].close else: @@ -763,21 +772,19 @@ def pick_exchg_hots(self, exchg:str, beginDT:datetime.datetime, endDT:datetime.d if key not in hot_switches: hot_switches[key] = list() hot_switches[key].append(item) - logging.info("[%s]品种%s主力切换 切换日期: %s,%s -> %s", exchg, key, nextDT.strftime('%Y%m%d'), lastHots[key], hots[key]) + logging.info("[%s]品种%s主力切换 切换日期: %s,%s -> %s", exchg, key, + nextDT.strftime('%Y%m%d'), lastHots[key], hots[key]) lastHots[key] = hots[key] for key in seconds.keys(): nextDT = curDT + datetime.timedelta(days=1) if key not in lastSecs: - item = {} - item["date"] = int(curDT.strftime('%Y%m%d')) - item["from"] = "" - item["to"] = seconds[key] - item["oldclose"] = 0.0 - item["newclose"] = items[seconds[key]].close + item = {"date": int(curDT.strftime('%Y%m%d')), "from": "", "to": seconds[key], "oldclose": 0.0, + "newclose": items[seconds[key]].close} sec_switches[key] = [item] lastSecs[key] = seconds[key] - logging.info("[%s]品种%s次主力确认, 确认日期: %s, %s", exchg,key, nextDT.strftime('%Y%m%d'), seconds[key]) + logging.info("[%s]品种%s次主力确认, 确认日期: %s, %s", exchg, key, nextDT.strftime('%Y%m%d'), + seconds[key]) else: oldcode = lastSecs[key] newcode = seconds[key] @@ -786,10 +793,7 @@ def pick_exchg_hots(self, exchg:str, beginDT:datetime.datetime, endDT:datetime.d oldItem = items[oldcode] newItem = items[newcode] if oldItem is None or newItem.month > oldItem.month: - item = {} - item["date"] = int(nextDT.strftime('%Y%m%d')) - item["from"] = oldcode - item["to"] = newcode + item = {"date": int(nextDT.strftime('%Y%m%d')), "from": oldcode, "to": newcode} if oldcode in items: item["oldclose"] = items[oldcode].close else: @@ -799,25 +803,26 @@ def pick_exchg_hots(self, exchg:str, beginDT:datetime.datetime, endDT:datetime.d if key not in sec_switches: sec_switches[key] = list() sec_switches[key].append(item) - logging.info("[%s]品种%s次主力切换 切换日期: %s,%s -> %s", exchg, key, nextDT.strftime('%Y%m%d'), lastSecs[key], seconds[key]) + logging.info("[%s]品种%s次主力切换 切换日期: %s,%s -> %s", exchg, key, + nextDT.strftime('%Y%m%d'), lastSecs[key], seconds[key]) lastSecs[key] = seconds[key] # 日期递增 curDT = curDT + datetime.timedelta(days=1) - return hot_switches,sec_switches - + return hot_switches, sec_switches + def merge_switch_list(self, total, exchg, switch_list): - ''' + """ 合并主力切换规则 @total 已有的全部切换规则 @exchg 交易所代码 @switcg_list 新的切换规则 - ''' + """ if exchg not in total: total[exchg] = switch_list - logging.info("[%s]全市场主力切换规则重构" % (exchg)) + logging.info("[%s]全市场主力切换规则重构" % exchg) return True, total - + bChanged = False for pid in switch_list: if pid not in total[exchg]: @@ -830,8 +835,9 @@ def merge_switch_list(self, total, exchg, switch_list): bChanged = True return bChanged, total - def execute_rebuild(self, beginDate:datetime.datetime = None, endDate:datetime.datetime = None, exchanges = ["CFFEX", "SHFE", "CZCE", "DCE", "INE"], wait=False): - ''' + def execute_rebuild(self, beginDate: datetime.datetime = None, endDate: datetime.datetime = None, + exchanges=["CFFEX", "SHFE", "CZCE", "DCE", "INE"], wait=False): + """ 重构全部的主力切换规则 不依赖现有数据,全部重新确定主力合约的切换规则 @@ -839,13 +845,13 @@ def execute_rebuild(self, beginDate:datetime.datetime = None, endDate:datetime.d @endDate 截止日期 @exchanges 要重构的交易所列表 @wait 每个日期切换是否等待,等待时间1s,主要针对从交易所官网拉取,防止被拉黑名单 - ''' + """ if endDate is None: endDate = datetime.datetime.now() if beginDate is None: beginDate = datetime.datetime.strptime("2016-01-01", '%Y-%m-%d') - + total_hots = dict() total_secs = dict() @@ -855,7 +861,7 @@ def execute_rebuild(self, beginDate:datetime.datetime = None, endDate:datetime.d for exchg in exchanges: self.current_hots[exchg] = dict() self.current_secs[exchg] = dict() - + hot_changes = dict() sec_changes = dict() curDate = beginDate @@ -863,18 +869,18 @@ def execute_rebuild(self, beginDate:datetime.datetime = None, endDate:datetime.d if wait: time.sleep(1) for exchg in exchanges: - alg = 1 if exchg=='CFFEX' else 0 # 中金所的换月算法和其他交易所不同 - hotRules,secRules = self.pick_exchg_hots(exchg, curDate, curDate, alg=alg) + alg = 1 if exchg == 'CFFEX' else 0 # 中金所的换月算法和其他交易所不同 + hotRules, secRules = self.pick_exchg_hots(exchg, curDate, curDate, alg=alg) if len(hotRules.keys()) > 0: - hasChange,total_hots = self.merge_switch_list(total_hots, exchg, hotRules) + hasChange, total_hots = self.merge_switch_list(total_hots, exchg, hotRules) if exchg not in hot_changes: hot_changes[exchg] = dict() hot_changes[exchg].update(hotRules) if len(secRules.keys()) > 0: - hasChange,total_secs = self.merge_switch_list(total_secs, exchg, secRules) + hasChange, total_secs = self.merge_switch_list(total_secs, exchg, secRules) if exchg not in sec_changes: sec_changes[exchg] = dict() @@ -882,45 +888,45 @@ def execute_rebuild(self, beginDate:datetime.datetime = None, endDate:datetime.d curDate = curDate + datetime.timedelta(days=1) - #日期标记要保存 + # 日期标记要保存 marker = dict() marker["date"] = int(endDate.strftime('%Y%m%d')) output = open(self.marker_file, 'w') - output.write(json.dumps(marker, sort_keys=True, indent = 4)) + output.write(json.dumps(marker, sort_keys=True, indent=4)) output.close() - + logging.info("主力切换规则已更新") output = open(self.hot_file, 'w') - output.write(json.dumps(total_hots, sort_keys=True, indent = 4)) + output.write(json.dumps(total_hots, sort_keys=True, indent=4)) output.close() if self.sec_file is not None: output = open(self.sec_file, 'w') - output.write(json.dumps(total_secs, sort_keys=True, indent = 4)) + output.write(json.dumps(total_secs, sort_keys=True, indent=4)) output.close() output = open("hotmap.json", 'w') - output.write(json.dumps(self.current_hots, sort_keys=True, indent = 4)) + output.write(json.dumps(self.current_hots, sort_keys=True, indent=4)) output.close() output = open("secmap.json", 'w') - output.write(json.dumps(self.current_secs, sort_keys=True, indent = 4)) + output.write(json.dumps(self.current_secs, sort_keys=True, indent=4)) output.close() if self.mail_notifier is not None: self.mail_notifier.notify(hot_changes, sec_changes, endDate, hotFile, "hotmap.json", secFile, "secmap.json") - return total_hots,total_secs - - def execute_increment(self, endDate:datetime.datetime = None, exchanges = ["CFFEX", "SHFE", "CZCE", "DCE", "INE"]): - ''' + return total_hots, total_secs + + def execute_increment(self, endDate: datetime.datetime = None, exchanges=["CFFEX", "SHFE", "CZCE", "DCE", "INE"]): + """ 增量更新主力切换规则 会自动加载marker.json取得上次更新的日期,并读取hots.json确定当前的映射规则 @endDate 截止日期 @exchanges 要重构的交易所列表 - ''' + """ if endDate is None: endDate = datetime.datetime.now() @@ -929,7 +935,7 @@ def execute_increment(self, endDate:datetime.datetime = None, exchanges = ["CFFE hotFile = self.hot_file secFile = self.sec_file - marker = {"date":"0"} + marker = {"date": "0"} c = readFileContent(markerFile) if len(c) > 0: marker = json.loads(c) @@ -956,7 +962,7 @@ def execute_increment(self, endDate:datetime.datetime = None, exchanges = ["CFFE beginDT = datetime.datetime.strptime(lastDate, "%Y%m%d") + datetime.timedelta(days=1) else: beginDT = datetime.datetime.strptime("2016-01-01", '%Y-%m-%d') - + self.current_hots = dict() self.current_secs = dict() @@ -975,53 +981,53 @@ def execute_increment(self, endDate:datetime.datetime = None, exchanges = ["CFFE for pid in total_secs[exchg]: ay = total_secs[exchg][pid] self.current_secs[exchg][pid] = ay[-1]["to"] - + bChanged = False hot_changes = dict() sec_changes = dict() for exchg in exchanges: logging.info("[%s]开始分析主力换月数据" % exchg) - alg = 1 if exchg=='CFFEX' else 0 # 中金所的换月算法和其他交易所不同 - hotRules,secRules = self.pick_exchg_hots(exchg, beginDT, endDate, alg=alg) + alg = 1 if exchg == 'CFFEX' else 0 # 中金所的换月算法和其他交易所不同 + hotRules, secRules = self.pick_exchg_hots(exchg, beginDT, endDate, alg=alg) if len(hotRules.keys()) > 0: - hasChange,total_hots = self.merge_switch_list(total_hots, exchg, hotRules) - bChanged = bChanged or hasChange + hasChange, total_hots = self.merge_switch_list(total_hots, exchg, hotRules) + bChanged = bChanged or hasChange hot_changes[exchg] = hotRules if len(secRules.keys()) > 0: - hasChange,total_secs = self.merge_switch_list(total_secs, exchg, secRules) - bChanged = bChanged or hasChange + hasChange, total_secs = self.merge_switch_list(total_secs, exchg, secRules) + bChanged = bChanged or hasChange sec_changes[exchg] = secRules - - #日期标记要保存 + # 日期标记要保存 marker = dict() marker["date"] = int(endDate.strftime('%Y%m%d')) output = open(markerFile, 'w') - output.write(json.dumps(marker, sort_keys=True, indent = 4)) + output.write(json.dumps(marker, sort_keys=True, indent=4)) output.close() - + if bChanged: logging.info("主力切换规则已更新") output = open(hotFile, 'w') - output.write(json.dumps(total_hots, sort_keys=True, indent = 4)) + output.write(json.dumps(total_hots, sort_keys=True, indent=4)) output.close() output = open(secFile, 'w') - output.write(json.dumps(total_secs, sort_keys=True, indent = 4)) + output.write(json.dumps(total_secs, sort_keys=True, indent=4)) output.close() output = open("hotmap.json", 'w') - output.write(json.dumps(self.current_hots, sort_keys=True, indent = 4)) + output.write(json.dumps(self.current_hots, sort_keys=True, indent=4)) output.close() output = open("secmap.json", 'w') - output.write(json.dumps(self.current_secs, sort_keys=True, indent = 4)) + output.write(json.dumps(self.current_secs, sort_keys=True, indent=4)) output.close() if self.mail_notifier is not None: - self.mail_notifier.notify(hot_changes, sec_changes, endDate, hotFile, "hotmap.json", secFile, "secmap.json") + self.mail_notifier.notify(hot_changes, sec_changes, endDate, hotFile, "hotmap.json", secFile, + "secmap.json") else: logging.info("主力切换规则未更新,不保存数据") diff --git a/wtpy/monitor/DataMgr.py b/wtpy/monitor/DataMgr.py index 14e6f826..90c0feff 100644 --- a/wtpy/monitor/DataMgr.py +++ b/wtpy/monitor/DataMgr.py @@ -6,6 +6,7 @@ import datetime from .WtLogger import WtLogger + def backup_file(filename): if not os.path.exists(filename): return @@ -20,20 +21,21 @@ def backup_file(filename): import shutil shutil.copy(filename, target) + class DataMgr: - def __init__(self, datafile:str="mondata.db", logger:WtLogger=None): + def __init__(self, datafile: str = "mondata.db", logger: WtLogger = None): self.__grp_cache__ = dict() self.__logger__ = logger self.__db_conn__ = sqlite3.connect(datafile, check_same_thread=False) self.__check_db__() - #加载组合列表 + # 加载组合列表 cur = self.__db_conn__.cursor() self.__config__ = { - "groups":{}, - "users":{} + "groups": {}, + "users": {} } for row in cur.execute("SELECT * FROM groups;"): @@ -48,7 +50,8 @@ def __init__(self, datafile:str="mondata.db", logger:WtLogger=None): grpInfo["mqurl"] = row[8] self.__config__["groups"][grpInfo["id"]] = grpInfo - for row in cur.execute("SELECT id,loginid,name,role,passwd,iplist,remark,createby,createtime,modifyby,modifytime,products FROM users;"): + for row in cur.execute( + "SELECT id,loginid,name,role,passwd,iplist,remark,createby,createtime,modifyby,modifytime,products FROM users;"): usrInfo = dict() usrInfo["loginid"] = row[1] usrInfo["name"] = row[2] @@ -79,7 +82,7 @@ def __check_db__(self): tables = [] for row in cur.execute("select name from sqlite_master where type='table' order by name"): tables.append(row[0]) - + if "actions" not in tables: sql = "CREATE TABLE [actions] (\n" sql += "[id] INTEGER PRIMARY KEY autoincrement, \n" @@ -166,7 +169,7 @@ def __check_cache__(self, grpid, grpInfo): bNeedReset = True else: td = now - cache_time - if td.total_seconds() >= 60:# 上次缓存时间超过60s,则重新读取 + if td.total_seconds() >= 60: # 上次缓存时间超过60s,则重新读取 bNeedReset = True if bNeedReset: @@ -186,21 +189,22 @@ def __check_cache__(self, grpid, grpInfo): f.close() self.__grp_cache__[grpid] = { - "strategies":marker["marks"], - "channels":marker["channels"] - } + "strategies": marker["marks"], + "channels": marker["channels"] + } if "executers" in marker: self.__grp_cache__[grpid]["executers"] = marker["executers"] else: self.__grp_cache__[grpid]["executers"] = [] - except: + except Exception as e: + print(e) self.__grp_cache__[grpid] = { - "strategies":[], - "channels":[], - "executers":[] - } + "strategies": [], + "channels": [], + "executers": [] + } if "strategies" in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["strategies"].sort() @@ -213,7 +217,7 @@ def __check_cache__(self, grpid, grpInfo): self.__grp_cache__[grpid]["cachetime"] = now - def get_groups(self, tpfilter:str=''): + def get_groups(self, tpfilter: str = ''): ret = [] for grpid in self.__config__["groups"]: grpinfo = self.__config__["groups"][grpid] @@ -221,19 +225,19 @@ def get_groups(self, tpfilter:str=''): ret.append(grpinfo) elif grpinfo["gtype"] == tpfilter: ret.append(grpinfo) - + return ret - def has_group(self, grpid:str): - return (grpid in self.__config__["groups"]) + def has_group(self, grpid: str): + return grpid in self.__config__["groups"] - def get_group(self, grpid:str) -> dict: + def get_group(self, grpid: str) -> dict: if grpid in self.__config__["groups"]: return self.__config__["groups"][grpid] else: return None - def get_group_cfg(self, grpid:str): + def get_group_cfg(self, grpid: str): if grpid not in self.__config__["groups"]: return "{}" else: @@ -252,7 +256,7 @@ def get_group_cfg(self, grpid:str): else: return json.loads(content) - def set_group_cfg(self, grpid:str, config:dict): + def set_group_cfg(self, grpid: str, config: dict): if grpid not in self.__config__["groups"]: return False else: @@ -271,7 +275,7 @@ def set_group_cfg(self, grpid:str, config:dict): f.close() return True - def get_group_entry(self, grpid:str): + def get_group_entry(self, grpid: str): if grpid not in self.__config__["groups"]: return "{}" else: @@ -283,7 +287,7 @@ def get_group_entry(self, grpid:str): f.close() return content - def set_group_entry(self, grpid:str, content:str): + def set_group_entry(self, grpid: str, content: str): if grpid not in self.__config__["groups"]: return False else: @@ -296,7 +300,7 @@ def set_group_entry(self, grpid:str, content:str): f.close() return True - def add_group(self, grpInfo:dict): + def add_group(self, grpInfo: dict): grpid = grpInfo["id"] isNewGrp = not (grpid in self.__config__["groups"]) @@ -306,10 +310,12 @@ def add_group(self, grpInfo:dict): sql = '' if isNewGrp: sql = "INSERT INTO groups(groupid,name,path,info,gtype,datmod,env,mqurl) VALUES('%s','%s','%s','%s','%s','%s','%s','%s');" \ - % (grpid, grpInfo["name"], grpInfo["path"], grpInfo["info"], grpInfo["gtype"], grpInfo["datmod"], grpInfo["env"], grpInfo["mqurl"]) + % (grpid, grpInfo["name"], grpInfo["path"], grpInfo["info"], grpInfo["gtype"], grpInfo["datmod"], + grpInfo["env"], grpInfo["mqurl"]) else: sql = "UPDATE groups SET name='%s',path='%s',info='%s',gtype='%s',datmod='%s',env='%s',mqurl='%s',modifytime=datetime('now','localtime') WHERE groupid='%s';" \ - % (grpInfo["name"], grpInfo["path"], grpInfo["info"], grpInfo["gtype"], grpInfo["datmod"], grpInfo["env"], grpInfo["mqurl"], grpid) + % (grpInfo["name"], grpInfo["path"], grpInfo["info"], grpInfo["gtype"], grpInfo["datmod"], + grpInfo["env"], grpInfo["mqurl"], grpid) cur.execute(sql) self.__db_conn__.commit() bSucc = True @@ -321,20 +327,20 @@ def add_group(self, grpInfo:dict): return bSucc - def del_group(self, grpid:str): + def del_group(self, grpid: str): if grpid in self.__config__["groups"]: self.__config__["groups"].pop(grpid) - + cur = self.__db_conn__.cursor() - cur.execute("DELETE FROM groups WHERE groupid='%s';" % (grpid)) + cur.execute("DELETE FROM groups WHERE groupid='%s';" % grpid) self.__db_conn__.commit() def get_users(self): ret = [] for loginid in self.__config__["users"]: usrInfo = self.__config__["users"][loginid] - ret.append(usrInfo.copy()) - + ret.append(usrInfo.copy()) + return ret def add_user(self, usrInfo, admin): @@ -344,39 +350,42 @@ def add_user(self, usrInfo, admin): cur = self.__db_conn__.cursor() now = datetime.datetime.now() if isNewUser: - encpwd = hashlib.md5((loginid+usrInfo["passwd"]).encode("utf-8")).hexdigest() + encpwd = hashlib.md5((loginid + usrInfo["passwd"]).encode("utf-8")).hexdigest() usrInfo["passwd"] = encpwd usrInfo["createby"] = admin usrInfo["modifyby"] = admin usrInfo["createtime"] = now.strftime("%Y-%m-%d %H:%M:%S") usrInfo["modifytime"] = now.strftime("%Y-%m-%d %H:%M:%S") - cur.execute("INSERT INTO users(loginid,name,role,passwd,iplist,products,remark,createby,modifyby) VALUES(?,?,?,?,?,?,?,?,?);", - (loginid, usrInfo["name"], usrInfo["role"], encpwd, usrInfo["iplist"], usrInfo["products"], usrInfo["remark"], admin, admin)) + cur.execute( + "INSERT INTO users(loginid,name,role,passwd,iplist,products,remark,createby,modifyby) VALUES(?,?,?,?,?,?,?,?,?);", + (loginid, usrInfo["name"], usrInfo["role"], encpwd, usrInfo["iplist"], usrInfo["products"], + usrInfo["remark"], admin, admin)) else: usrInfo["modifyby"] = admin usrInfo["modifytime"] = now.strftime("%Y-%m-%d %H:%M:%S") - cur.execute("UPDATE users SET name=?,role=?,iplist=?,products=?,remark=?,modifyby=?,modifytime=datetime('now','localtime') WHERE loginid=?;", - (usrInfo["name"], usrInfo["role"], usrInfo["iplist"], usrInfo["products"], usrInfo["remark"], admin, loginid)) + cur.execute( + "UPDATE users SET name=?,role=?,iplist=?,products=?,remark=?,modifyby=?,modifytime=datetime('now','localtime') WHERE loginid=?;", + (usrInfo["name"], usrInfo["role"], usrInfo["iplist"], usrInfo["products"], usrInfo["remark"], admin, + loginid)) self.__db_conn__.commit() if loginid in self.__config__["users"]: self.__config__["users"][loginid]["modifyby"] = admin self.__config__["users"][loginid]["modifytime"] = usrInfo["modifytime"] - def mod_user_pwd(self, loginid:str, newpwd:str, admin:str): + def mod_user_pwd(self, loginid: str, newpwd: str, admin: str): cur = self.__db_conn__.cursor() - cur.execute("UPDATE users SET passwd=?,modifyby=?,modifytime=datetime('now','localtime') WHERE loginid=?;", - (newpwd,admin,loginid)) + cur.execute("UPDATE users SET passwd=?,modifyby=?,modifytime=datetime('now','localtime') WHERE loginid=?;", + (newpwd, admin, loginid)) self.__db_conn__.commit() - self.__config__["users"][loginid]["passwd"]=newpwd - + self.__config__["users"][loginid]["passwd"] = newpwd def del_user(self, loginid, admin): if loginid in self.__config__["users"]: self.__config__["users"].pop(loginid) - + cur = self.__db_conn__.cursor() - cur.execute("DELETE FROM users WHERE loginid='%s';" % (loginid)) + cur.execute("DELETE FROM users WHERE loginid='%s';" % loginid) self.__db_conn__.commit() return True else: @@ -385,28 +394,28 @@ def del_user(self, loginid, admin): def log_action(self, adminInfo, atype, remark): cur = self.__db_conn__.cursor() sql = "INSERT INTO actions(loginid,actiontime,actionip,actiontype,remark) VALUES('%s',datetime('now','localtime'),'%s','%s','%s');" % ( - adminInfo["loginid"], adminInfo["loginip"], atype, remark) + adminInfo["loginid"], adminInfo["loginip"], atype, remark) cur.execute(sql) self.__db_conn__.commit() - def get_user(self, loginid:str): + def get_user(self, loginid: str): if loginid in self.__config__["users"]: return self.__config__["users"][loginid].copy() elif loginid == 'superman': return { - "loginid":loginid, - "name":"超管", - "role":"superman", - "passwd":"25ed305a56504e95fd1ca9900a1da174", - "iplist":"", - "remark":"内置超管账号", - 'builtin':True, - 'products':'' + "loginid": loginid, + "name": "超管", + "role": "superman", + "passwd": "25ed305a56504e95fd1ca9900a1da174", + "iplist": "", + "remark": "内置超管账号", + 'builtin': True, + 'products': '' } else: return None - def get_strategies(self, grpid:str): + def get_strategies(self, grpid: str): if grpid not in self.__config__["groups"]: return [] @@ -415,22 +424,22 @@ def get_strategies(self, grpid:str): if "strategies" not in self.__grp_cache__[grpid]: return [] - + return self.__grp_cache__[grpid]["strategies"] - def get_channels(self, grpid:str): + def get_channels(self, grpid: str): if grpid not in self.__config__["groups"]: return [] grpInfo = self.__config__["groups"][grpid] self.__check_cache__(grpid, grpInfo) - + if "channels" not in self.__grp_cache__[grpid]: return [] return self.__grp_cache__[grpid]["channels"] - def get_trades(self, grpid:str, straid:str, limit:int = 200): + def get_trades(self, grpid: str, straid: str, limit: int = 200): if grpid not in self.__config__["groups"]: return [] @@ -439,15 +448,15 @@ def get_trades(self, grpid:str, straid:str, limit:int = 200): if "strategies" not in self.__grp_cache__[grpid]: return [] - + if straid not in self.__grp_cache__[grpid]["strategies"]: return [] if "trades" not in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["trades"] = dict() - + if straid not in self.__grp_cache__[grpid]["trades"]: - filepath = "./generated/outputs/%s/trades.csv" % (straid) + filepath = "./generated/outputs/%s/trades.csv" % straid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): return [] @@ -463,7 +472,7 @@ def get_trades(self, grpid:str, straid:str, limit:int = 200): last_row = trdCache["lastrow"] lines = f.readlines() f.close() - lines = lines[1+last_row:] + lines = lines[1 + last_row:] for line in lines: cells = line.split(",") @@ -471,7 +480,7 @@ def get_trades(self, grpid:str, straid:str, limit:int = 200): continue tItem = { - "strategy":straid, + "strategy": straid, "code": cells[0], "time": int(cells[1]), "direction": cells[2], @@ -487,10 +496,10 @@ def get_trades(self, grpid:str, straid:str, limit:int = 200): trdCache["trades"].append(tItem) trdCache["lastrow"] += 1 - + return trdCache["trades"][-limit:] - def get_funds(self, grpid:str, straid:str): + def get_funds(self, grpid: str, straid: str): if grpid not in self.__config__["groups"]: return [] @@ -499,16 +508,16 @@ def get_funds(self, grpid:str, straid:str): if "strategies" not in self.__grp_cache__[grpid]: return [] - + if straid != "all": if straid not in self.__grp_cache__[grpid]["strategies"]: return [] if "funds" not in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["funds"] = dict() - + if straid not in self.__grp_cache__[grpid]["funds"]: - filepath = "./generated/outputs/%s/funds.csv" % (straid) + filepath = "./generated/outputs/%s/funds.csv" % straid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): return [] @@ -525,7 +534,7 @@ def get_funds(self, grpid:str, straid:str): last_row = trdCache["lastrow"] lines = f.readlines() f.close() - lines = lines[1+last_row:] + lines = lines[1 + last_row:] for line in lines: cells = line.split(",") @@ -533,7 +542,7 @@ def get_funds(self, grpid:str, straid:str): continue tItem = { - "strategy":straid, + "strategy": straid, "date": int(cells[0]), "closeprofit": float(cells[1]), "dynprofit": float(cells[2]), @@ -555,7 +564,7 @@ def get_funds(self, grpid:str, straid:str): last_date = 0 # 这里再更新一条实时数据 - filepath = "./generated/stradata/%s.json" % (straid) + filepath = "./generated/stradata/%s.json" % straid filepath = os.path.join(grpInfo["path"], filepath) f = open(filepath, "r") try: @@ -564,28 +573,29 @@ def get_funds(self, grpid:str, straid:str): fund = json_data["fund"] if fund["tdate"] > last_date: ret.append({ - "strategy":straid, + "strategy": straid, "date": fund["tdate"], "closeprofit": fund["total_profit"], "dynprofit": fund["total_dynprofit"], "dynbalance": fund["total_profit"] + fund["total_dynprofit"] - fund["total_fees"], "fee": fund["total_fees"] }) - except: + except Exception as e: + print(e) pass f.close() - + return ret else: ret = list() for straid in self.__grp_cache__[grpid]["strategies"]: - filepath = "./generated/outputs/%s/funds.csv" % (straid) + filepath = "./generated/outputs/%s/funds.csv" % straid filepath = os.path.join(grpInfo["path"], filepath) f = open(filepath, "r") lines = f.readlines() f.close() - filepath = "./generated/stradata/%s.json" % (straid) + filepath = "./generated/stradata/%s.json" % straid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): continue @@ -594,12 +604,12 @@ def get_funds(self, grpid:str, straid:str): content = f.read() f.close() try: - + json_data = json.loads(content) fund = json_data["fund"] curDate = fund["tdate"] item = { - "strategy":straid, + "strategy": straid, "date": fund["tdate"], "closeprofit": fund["total_profit"], "dynprofit": fund["total_dynprofit"], @@ -617,16 +627,17 @@ def get_funds(self, grpid:str, straid:str): prebalance = float(cells[3]) prefee = float(cells[4]) - item['profit'] = item['closeprofit']-preprof + item['profit'] = item['closeprofit'] - preprof item['thisfee'] = item['fee'] - prefee - item['addition'] = item['dynbalance']-prebalance + item['addition'] = item['dynbalance'] - prebalance ret.append(item) - except: + except Exception as e: + print(e) pass return ret - def get_signals(self, grpid:str, straid:str, limit:int = 200): + def get_signals(self, grpid: str, straid: str, limit: int = 200): if grpid not in self.__config__["groups"]: return [] @@ -635,15 +646,15 @@ def get_signals(self, grpid:str, straid:str, limit:int = 200): if "strategies" not in self.__grp_cache__[grpid]: return [] - + if straid not in self.__grp_cache__[grpid]["strategies"]: return [] if "signals" not in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["signals"] = dict() - + if straid not in self.__grp_cache__[grpid]["signals"]: - filepath = "./generated/outputs/%s/signals.csv" % (straid) + filepath = "./generated/outputs/%s/signals.csv" % straid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): return [] @@ -660,13 +671,13 @@ def get_signals(self, grpid:str, straid:str, limit:int = 200): last_row = trdCache["lastrow"] lines = f.readlines() f.close() - lines = lines[1+last_row:] + lines = lines[1 + last_row:] for line in lines: cells = line.split(",") tItem = { - "strategy":straid, + "strategy": straid, "code": cells[0], "target": float(cells[1]), "sigprice": float(cells[2]), @@ -676,10 +687,10 @@ def get_signals(self, grpid:str, straid:str, limit:int = 200): trdCache["signals"].append(tItem) - trdCache["lastrow"] += len(lines) + trdCache["lastrow"] += len(lines) return trdCache["signals"][-limit:] - def get_rounds(self, grpid:str, straid:str, limit:int = 200): + def get_rounds(self, grpid: str, straid: str, limit: int = 200): if grpid not in self.__config__["groups"]: return [] @@ -688,15 +699,15 @@ def get_rounds(self, grpid:str, straid:str, limit:int = 200): if "strategies" not in self.__grp_cache__[grpid]: return [] - + if straid not in self.__grp_cache__[grpid]["strategies"]: return [] if "rounds" not in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["rounds"] = dict() - + if straid not in self.__grp_cache__[grpid]["rounds"]: - filepath = "./generated/outputs/%s/closes.csv" % (straid) + filepath = "./generated/outputs/%s/closes.csv" % straid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): return [] @@ -712,13 +723,13 @@ def get_rounds(self, grpid:str, straid:str, limit:int = 200): last_row = trdCache["lastrow"] lines = f.readlines() f.close() - lines = lines[1+last_row:] + lines = lines[1 + last_row:] for line in lines: cells = line.split(",") tItem = { - "strategy":straid, + "strategy": straid, "code": cells[0], "direct": cells[1], "opentime": int(cells[2]), @@ -733,10 +744,10 @@ def get_rounds(self, grpid:str, straid:str, limit:int = 200): trdCache["rounds"].append(tItem) trdCache["lastrow"] += len(lines) - + return trdCache["rounds"][-limit:] - def get_positions(self, grpid:str, straid:str): + def get_positions(self, grpid: str, straid: str): if grpid not in self.__config__["groups"]: return [] @@ -745,17 +756,17 @@ def get_positions(self, grpid:str, straid:str): if "strategies" not in self.__grp_cache__[grpid]: return [] - + ret = list() if straid != "all": if straid not in self.__grp_cache__[grpid]["strategies"]: return [] - - filepath = "./generated/stradata/%s.json" % (straid) + + filepath = "./generated/stradata/%s.json" % straid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): return [] - + f = open(filepath, "r") try: content = f.read() @@ -774,17 +785,18 @@ def get_positions(self, grpid:str, straid:str): dItem["volume"] = dItem["volumn"] dItem.pop("volumn") ret.append(dItem) - except: + except Exception as e: + print(e) pass f.close() else: for straid in self.__grp_cache__[grpid]["strategies"]: - filepath = "./generated/stradata/%s.json" % (straid) + filepath = "./generated/stradata/%s.json" % straid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): continue - + f = open(filepath, "r") try: content = f.read() @@ -803,13 +815,14 @@ def get_positions(self, grpid:str, straid:str): dItem["volume"] = dItem["volumn"] dItem.pop("volumn") ret.append(dItem) - except: + except Exception as e: + print(e) pass f.close() return ret - def get_channel_orders(self, grpid:str, chnlid:str, limit:int = 200): + def get_channel_orders(self, grpid: str, chnlid: str, limit: int = 200): if grpid not in self.__config__["groups"]: return [] @@ -818,15 +831,15 @@ def get_channel_orders(self, grpid:str, chnlid:str, limit:int = 200): if "channels" not in self.__grp_cache__[grpid]: return [] - + if chnlid not in self.__grp_cache__[grpid]["channels"]: return [] if "corders" not in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["corders"] = dict() - + if chnlid not in self.__grp_cache__[grpid]["corders"]: - filepath = "./generated/traders/%s/orders.csv" % (chnlid) + filepath = "./generated/traders/%s/orders.csv" % chnlid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): return [] @@ -839,19 +852,19 @@ def get_channel_orders(self, grpid:str, chnlid:str, limit:int = 200): trdCache = self.__grp_cache__[grpid]["corders"][chnlid] - f = open(trdCache["file"], "r",encoding="gb2312",errors="ignore") + f = open(trdCache["file"], "r", encoding="gb2312", errors="ignore") last_row = trdCache["lastrow"] lines = f.readlines() f.close() - lines = lines[1+last_row:] + lines = lines[1 + last_row:] for line in lines: cells = line.split(",") tItem = { - "channel":chnlid, - "localid":int(cells[0]), - "time":int(cells[2]), + "channel": chnlid, + "localid": int(cells[0]), + "time": int(cells[2]), "code": cells[3], "action": cells[4], "total": float(cells[5]), @@ -863,10 +876,10 @@ def get_channel_orders(self, grpid:str, chnlid:str, limit:int = 200): } trdCache["corders"].append(tItem) - + return trdCache["corders"][-limit:] - def get_channel_trades(self, grpid:str, chnlid:str, limit:int = 200): + def get_channel_trades(self, grpid: str, chnlid: str, limit: int = 200): if grpid not in self.__config__["groups"]: return [] @@ -875,15 +888,15 @@ def get_channel_trades(self, grpid:str, chnlid:str, limit:int = 200): if "channels" not in self.__grp_cache__[grpid]: return [] - + if chnlid not in self.__grp_cache__[grpid]["channels"]: return [] if "ctrades" not in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["ctrades"] = dict() - + if chnlid not in self.__grp_cache__[grpid]["ctrades"]: - filepath = "./generated/traders/%s/trades.csv" % (chnlid) + filepath = "./generated/traders/%s/trades.csv" % chnlid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): return [] @@ -896,19 +909,19 @@ def get_channel_trades(self, grpid:str, chnlid:str, limit:int = 200): trdCache = self.__grp_cache__[grpid]["ctrades"][chnlid] - f = open(trdCache["file"], "r",encoding="gb2312") + f = open(trdCache["file"], "r", encoding="gb2312") last_row = trdCache["lastrow"] lines = f.readlines() f.close() - lines = lines[1+last_row:] + lines = lines[1 + last_row:] for line in lines: cells = line.split(",") tItem = { - "channel":chnlid, - "localid":int(cells[0]), - "time":int(cells[2]), + "channel": chnlid, + "localid": int(cells[0]), + "time": int(cells[2]), "code": cells[3], "action": cells[4], "volume": float(cells[5]), @@ -918,10 +931,10 @@ def get_channel_trades(self, grpid:str, chnlid:str, limit:int = 200): } trdCache["ctrades"].append(tItem) - + return trdCache["ctrades"][-limit:] - def get_channel_positions(self, grpid:str, chnlid:str): + def get_channel_positions(self, grpid: str, chnlid: str): if self.__config__ is None: return [] @@ -936,7 +949,7 @@ def get_channel_positions(self, grpid:str, chnlid:str): if "channels" not in self.__grp_cache__[grpid]: return [] - + ret = list() channels = list() if chnlid != 'all': @@ -947,12 +960,12 @@ def get_channel_positions(self, grpid:str, chnlid:str): for cid in channels: if cid not in self.__grp_cache__[grpid]["channels"]: continue - - filepath = "./generated/traders/%s/rtdata.json" % (cid) + + filepath = "./generated/traders/%s/rtdata.json" % cid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): return [] - + f = open(filepath, "r") try: content = f.read() @@ -962,13 +975,14 @@ def get_channel_positions(self, grpid:str, chnlid:str): for pItem in positions: pItem["channel"] = cid ret.append(pItem) - except: + except Exception as e: + print(e) pass f.close() return ret - def get_channel_funds(self, grpid:str, chnlid:str): + def get_channel_funds(self, grpid: str, chnlid: str): if self.__config__ is None: return [] @@ -983,7 +997,7 @@ def get_channel_funds(self, grpid:str, chnlid:str): if "channels" not in self.__grp_cache__[grpid]: return None - + ret = dict() channels = list() if chnlid != 'all': @@ -994,12 +1008,12 @@ def get_channel_funds(self, grpid:str, chnlid:str): for cid in channels: if cid not in self.__grp_cache__[grpid]["channels"]: continue - - filepath = "./generated/traders/%s/rtdata.json" % (cid) + + filepath = "./generated/traders/%s/rtdata.json" % cid filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): continue - + f = open(filepath, "r") try: content = f.read() @@ -1007,7 +1021,8 @@ def get_channel_funds(self, grpid:str, chnlid:str): funds = json_data["funds"] ret[cid] = funds - except: + except Exception as e: + print(e) pass f.close() @@ -1017,7 +1032,9 @@ def get_actions(self, sdate, edate): ret = list() cur = self.__db_conn__.cursor() - for row in cur.execute("SELECT id,loginid,actiontime,actionip,actiontype,remark FROM actions WHERE actiontime>=? and actiontime<=?;", (sdate, edate)): + for row in cur.execute( + "SELECT id,loginid,actiontime,actionip,actiontype,remark FROM actions WHERE actiontime>=? and actiontime<=?;", + (sdate, edate)): aInfo = dict() aInfo["id"] = row[0] aInfo["loginid"] = row[1] @@ -1030,7 +1047,7 @@ def get_actions(self, sdate, edate): return ret - def get_group_trades(self, grpid:str): + def get_group_trades(self, grpid: str): if grpid not in self.__config__["groups"]: return [] @@ -1039,7 +1056,7 @@ def get_group_trades(self, grpid:str): if "grptrades" not in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["grptrades"] = dict() - + filepath = "./generated/portfolio/trades.csv" filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): @@ -1057,7 +1074,7 @@ def get_group_trades(self, grpid:str): last_row = trdCache["lastrow"] lines = f.readlines() f.close() - lines = lines[1+last_row:] + lines = lines[1 + last_row:] for line in lines: cells = line.split(",") @@ -1074,10 +1091,10 @@ def get_group_trades(self, grpid:str): trdCache["trades"].append(tItem) trdCache["lastrow"] += 1 - + return trdCache["trades"] - def get_group_rounds(self, grpid:str): + def get_group_rounds(self, grpid: str): if grpid not in self.__config__["groups"]: return [] @@ -1086,7 +1103,7 @@ def get_group_rounds(self, grpid:str): if "grprounds" not in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["grprounds"] = dict() - + filepath = "./generated/portfolio/closes.csv" filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): @@ -1104,7 +1121,7 @@ def get_group_rounds(self, grpid:str): last_row = trdCache["lastrow"] lines = f.readlines() f.close() - lines = lines[1+last_row:] + lines = lines[1 + last_row:] for line in lines: cells = line.split(",") @@ -1122,10 +1139,10 @@ def get_group_rounds(self, grpid:str): trdCache["rounds"].append(tItem) trdCache["lastrow"] += 1 - + return trdCache["rounds"] - def get_group_funds(self, grpid:str): + def get_group_funds(self, grpid: str): if grpid not in self.__config__["groups"]: return [] @@ -1134,7 +1151,7 @@ def get_group_funds(self, grpid:str): if "grpfunds" not in self.__grp_cache__[grpid]: self.__grp_cache__[grpid]["grpfunds"] = dict() - + filepath = "./generated/portfolio/funds.csv" filepath = os.path.join(grpInfo["path"], filepath) if os.path.exists(filepath): @@ -1154,7 +1171,7 @@ def get_group_funds(self, grpid:str): last_row = trdCache["lastrow"] lines = f.readlines() f.close() - lines = lines[1+last_row:] + lines = lines[1 + last_row:] for line in lines: cells = line.split(",") @@ -1179,7 +1196,7 @@ def get_group_funds(self, grpid:str): trdCache["funds"].append(tItem) trdCache["lastrow"] += 1 - + ret = trdCache["funds"].copy() if len(ret) > 0: @@ -1216,18 +1233,19 @@ def get_group_funds(self, grpid:str): "mdminbalance": fund["minmd"]["dyn_balance"], "mdmindate": fund["minmd"]["date"] }) - except: + except Exception as e: + print(e) pass f.close() return ret - def get_group_positions(self, grpid:str): + def get_group_positions(self, grpid: str): if grpid not in self.__config__["groups"]: return [] grpInfo = self.__config__["groups"][grpid] self.__check_cache__(grpid, grpInfo) - + filepath = "./generated/portfolio/datas.json" filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): @@ -1247,20 +1265,21 @@ def get_group_positions(self, grpid:str): for dItem in pItem["details"]: dItem["code"] = pItem["code"] ret.append(dItem) - except: + except Exception as e: + print(e) pass f.close() return ret - def get_group_performances(self, grpid:str): + def get_group_performances(self, grpid: str): if grpid not in self.__config__["groups"]: return {} grpInfo = self.__config__["groups"][grpid] self.__check_cache__(grpid, grpInfo) - - filepath = "./generated/portfolio/datas.json" + + filepath = "./generated/portfolio/datas.json" filepath = os.path.join(grpInfo["path"], filepath) if not os.path.exists(filepath): return {} @@ -1277,39 +1296,40 @@ def get_group_performances(self, grpid:str): ay = code.split(".") pid = code if len(ay) > 2: - if ay[1] not in ['IDX','STK','ETF']: + if ay[1] not in ['IDX', 'STK', 'ETF']: pid = ay[0] + "." + ay[1] else: pid = ay[0] + "." + ay[2] if pid not in perf: perf[pid] = { - 'closeprofit':0, - 'dynprofit':0 + 'closeprofit': 0, + 'dynprofit': 0 } perf[pid]['closeprofit'] += pItem['closeprofit'] perf[pid]['dynprofit'] += pItem['dynprofit'] - - except: + + except Exception as e: + print(e) pass f.close() return perf - def get_group_filters(self, grpid:str): + def get_group_filters(self, grpid: str): if grpid not in self.__config__["groups"]: return {} grpInfo = self.__config__["groups"][grpid] self.__check_cache__(grpid, grpInfo) - + filepath = os.path.join(grpInfo["path"], 'filters.json') isYaml = False if not os.path.exists(filepath): filepath = os.path.join(grpInfo["path"], 'filters.yaml') isYaml = True - + if not os.path.exists(filepath): filters = {} else: @@ -1321,7 +1341,8 @@ def get_group_filters(self, grpid:str): filters = yaml.full_load(content) else: filters = json.loads(content) - except: + except Exception as e: + print(e) pass f.close() @@ -1338,23 +1359,23 @@ def get_group_filters(self, grpid:str): for sid in gpCache["strategies"]: if sid not in filters['strategy_filters']: filters['strategy_filters'][sid] = False - + if "executers" in gpCache: for eid in gpCache["executers"]: if eid not in filters['executer_filters']: filters['executer_filters'][eid] = False for id in filters['strategy_filters'].keys(): - if type(filters['strategy_filters'][id]) != bool: + if not isinstance(filters['strategy_filters'][id], bool): filters['strategy_filters'][id] = True for id in filters['code_filters'].keys(): - if type(filters['code_filters'][id]) != bool: + if not isinstance(filters['code_filters'][id], bool): filters['code_filters'][id] = True return filters - def set_group_filters(self, grpid:str, filters:dict): + def set_group_filters(self, grpid: str, filters: dict): if grpid not in self.__config__["groups"]: return False @@ -1362,30 +1383,30 @@ def set_group_filters(self, grpid:str, filters:dict): self.__check_cache__(grpid, grpInfo) realfilters = { - "strategy_filters":{}, - "code_filters":{}, - "executer_filters":{} + "strategy_filters": {}, + "code_filters": {}, + "executer_filters": {} } if "strategy_filters" in filters: for sid in filters["strategy_filters"]: if filters["strategy_filters"][sid]: realfilters["strategy_filters"][sid] = { - "action":"redirect", - "target":0 + "action": "redirect", + "target": 0 } if "code_filters" in filters: for sid in filters["code_filters"]: if filters["code_filters"][sid]: realfilters["code_filters"][sid] = { - "action":"redirect", - "target":0 + "action": "redirect", + "target": 0 } if "executer_filters" in filters: realfilters["executer_filters"] = filters["executer_filters"] - + filepath = os.path.join(grpInfo["path"], 'filters.json') isYaml = False if not os.path.exists(filepath): @@ -1399,4 +1420,3 @@ def set_group_filters(self, grpid:str, filters:dict): f.write(json.dumps(realfilters, indent=4)) f.close() return True - \ No newline at end of file diff --git a/wtpy/monitor/EventReceiver.py b/wtpy/monitor/EventReceiver.py index 4faf2669..49e68118 100644 --- a/wtpy/monitor/EventReceiver.py +++ b/wtpy/monitor/EventReceiver.py @@ -3,32 +3,34 @@ from wtpy import WtMsgQue, WtMQClient -TOPIC_RT_TRADE = "TRD_TRADE" # 生产环境下的成交通知 -TOPIC_RT_ORDER = "TRD_ORDER" # 生产环境下的订单通知 +TOPIC_RT_TRADE = "TRD_TRADE" # 生产环境下的成交通知 +TOPIC_RT_ORDER = "TRD_ORDER" # 生产环境下的订单通知 TOPIC_RT_NOTIFY = "TRD_NOTIFY" # 生产环境下的普通通知 -TOPIC_RT_LOG = "LOG" # 生产环境下的日志通知 -TOPIC_TIMEOUT = "TIMEOUT" # 消息超时通知 +TOPIC_RT_LOG = "LOG" # 生产环境下的日志通知 +TOPIC_TIMEOUT = "TIMEOUT" # 消息超时通知 + class EventSink: def __init__(self): pass - def on_order(self, chnl:str, ordInfo:dict): + def on_order(self, chnl: str, ordInfo: dict): pass - def on_trade(self, chnl:str, trdInfo:dict): + def on_trade(self, chnl: str, trdInfo: dict): pass - - def on_notify(self, chnl:str, message:str): + + def on_notify(self, chnl: str, message: str): pass - def on_log(self, tag:str, time:int, message:str): + def on_log(self, tag: str, time: int, message: str): pass def on_timeout(self): pass -def decode_bytes(data:bytes): + +def decode_bytes(data: bytes): ret = chardet.detect(data) if ret is not None: encoding = ret["encoding"] @@ -39,9 +41,10 @@ def decode_bytes(data:bytes): else: return data.decode() + class EventReceiver(WtMQClient): - def __init__(self, url:str, topics:list = [], sink:EventSink = None, logger = None): + def __init__(self, url: str, topics: list = [], sink: EventSink = None, logger=None): self.url = url self.logger = logger mq = WtMsgQue(logger) @@ -53,13 +56,13 @@ def __init__(self, url:str, topics:list = [], sink:EventSink = None, logger = No self._worker = None self._sink = sink - def on_mq_message(self, topic:str, message:str, dataLen:int): + def on_mq_message(self, topic: str, message: str, dataLen: int): topic = decode_bytes(topic) if dataLen > 0: message = decode_bytes(message[:dataLen]) else: message = None - + if self._sink is not None: if topic == TOPIC_RT_TRADE: msgObj = json.loads(message) @@ -87,29 +90,32 @@ def run(self): def release(self): mq.destroy_mq_client(self) -TOPIC_BT_EVENT = "BT_EVENT" # 回测环境下的事件,主要通知回测的启动和结束 -TOPIC_BT_STATE = "BT_STATE" # 回测的状态 -TOPIC_BT_FUND = "BT_FUND" # 每日资金变化 + +TOPIC_BT_EVENT = "BT_EVENT" # 回测环境下的事件,主要通知回测的启动和结束 +TOPIC_BT_STATE = "BT_STATE" # 回测的状态 +TOPIC_BT_FUND = "BT_FUND" # 每日资金变化 + class BtEventSink: def __init__(self): pass - + def on_begin(self): pass - + def on_finish(self): pass - def on_fund(self, fundInfo:dict): + def on_fund(self, fundInfo: dict): pass - def on_state(self, statInfo:float): + def on_state(self, statInfo: float): pass + class BtEventReceiver(WtMQClient): - def __init__(self, url:str, topics:list = [], sink:BtEventSink = None, logger = None): + def __init__(self, url: str, topics: list = [], sink: BtEventSink = None, logger=None): self.url = url self.logger = logger mq.add_mq_client(url, self) @@ -120,7 +126,7 @@ def __init__(self, url:str, topics:list = [], sink:BtEventSink = None, logger = self._worker = None self._sink = sink - def on_mq_message(self, topic:str, message:str, dataLen:int): + def on_mq_message(self, topic: str, message: str, dataLen: int): topic = decode_bytes(topic) message = decode_bytes(message[:dataLen]) if self._sink is not None: diff --git a/wtpy/monitor/PushSvr.py b/wtpy/monitor/PushSvr.py index c8b114f0..dc2f7088 100644 --- a/wtpy/monitor/PushSvr.py +++ b/wtpy/monitor/PushSvr.py @@ -6,9 +6,10 @@ import threading import time + class PushServer: - def __init__(self, app:FastAPI, dataMgr, logger:WtLogger = None): + def __init__(self, app: FastAPI, dataMgr, logger: WtLogger = None): self.app = app self.dataMgr = dataMgr self.logger = logger @@ -20,7 +21,7 @@ def __init__(self, app:FastAPI, dataMgr, logger:WtLogger = None): self.mutex = threading.Lock() self.messages = list() - self.worker:threading.Thread = None + self.worker: threading.Thread = None self.stopped = False async def connect(self, ws: WebSocket): @@ -47,15 +48,16 @@ async def send_personal_message(data: dict, ws: WebSocket): # 发送个人消息 await ws.send_json(data) - def broadcast(self, data: dict, groupid:str=""): + def broadcast(self, data: dict, groupid: str = ""): self.lock.acquire() loop = None try: loop = asyncio.get_event_loop() - except: + except Exception as e: + print(e) loop = None - + if loop is None or loop.is_closed(): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -63,17 +65,17 @@ def broadcast(self, data: dict, groupid:str=""): tasks = [] # 广播消息 for ws in self.active_connections: - if len(groupid)!=0 and "groupid" in ws.session and ws.session["groupid"]!=groupid: + if len(groupid) != 0 and "groupid" in ws.session and ws.session["groupid"] != groupid: continue tasks.append(asyncio.ensure_future(ws.send_json(data))) - - if len(tasks) > 0: + + if len(tasks) > 0: loop.run_until_complete(asyncio.gather(*tasks)) loop.close() self.lock.release() - def on_subscribe_group(self, ws:WebSocket, data:dict): + def on_subscribe_group(self, ws: WebSocket, data: dict): if ws not in self.active_connections: return @@ -82,12 +84,14 @@ def on_subscribe_group(self, ws:WebSocket, data:dict): tokenInfo = ws.session["tokeninfo"] ws.session["groupid"] = data["groupid"] - self.logger.info("{}@{} subscribed group {}".format(tokenInfo["loginid"], tokenInfo["loginip"] , data["groupid"])) + self.logger.info( + "{}@{} subscribed group {}".format(tokenInfo["loginid"], tokenInfo["loginip"], data["groupid"])) def run(self): app = self.app + @app.websocket("/") - async def ws_listen(ws:WebSocket): + async def ws_listen(ws: WebSocket): await self.connect(ws) try: while True: @@ -96,15 +100,17 @@ async def ws_listen(ws:WebSocket): req = json.loads(data) tp = req["type"] if tp == 'subscribe': - self.on_subscribe_group(ws,req) + self.on_subscribe_group(ws, req) await self.send_personal_message(req, ws) elif tp == 'heartbeat': - await self.send_personal_message({"type":"heartbeat", "message":"pong"}, ws) - except: + await self.send_personal_message({"type": "heartbeat", "message": "pong"}, ws) + except Exception as e: + print(e) continue except WebSocketDisconnect: self.disconnect(ws) + self.ready = True self.worker = threading.Thread(target=self.loop, daemon=True) @@ -115,7 +121,7 @@ def loop(self): if len(self.messages) == 0: time.sleep(1) continue - + self.mutex.acquire() messages = self.messages.copy() self.messages = [] @@ -127,12 +133,12 @@ def loop(self): else: self.broadcast(msg) - def notifyGrpLog(self, groupid, tag:str, time:int, message): + def notifyGrpLog(self, groupid, tag: str, time: int, message): if not self.ready: return self.mutex.acquire() - self.messages.append({"type":"gplog", "groupid":groupid, "tag":tag, "time":time, "message":message}) + self.messages.append({"type": "gplog", "groupid": groupid, "tag": tag, "time": time, "message": message}) self.mutex.release() def notifyGrpEvt(self, groupid, evttype): @@ -140,7 +146,7 @@ def notifyGrpEvt(self, groupid, evttype): return self.mutex.acquire() - self.messages.append({"type":"gpevt", "groupid":groupid, "evttype":evttype}) + self.messages.append({"type": "gpevt", "groupid": groupid, "evttype": evttype}) self.mutex.release() def notifyGrpChnlEvt(self, groupid, chnlid, evttype, data): @@ -148,5 +154,6 @@ def notifyGrpChnlEvt(self, groupid, chnlid, evttype, data): return self.mutex.acquire() - self.messages.append({"type":"chnlevt", "groupid":groupid, "channel":chnlid, "data":data, "evttype":evttype}) - self.mutex.release() \ No newline at end of file + self.messages.append( + {"type": "chnlevt", "groupid": groupid, "channel": chnlid, "data": data, "evttype": evttype}) + self.mutex.release() diff --git a/wtpy/monitor/WatchDog.py b/wtpy/monitor/WatchDog.py index d7d1cc14..7d52ef72 100644 --- a/wtpy/monitor/WatchDog.py +++ b/wtpy/monitor/WatchDog.py @@ -14,61 +14,65 @@ from enum import Enum + def isWindows(): if "windows" in platform.system().lower(): return True return False + class WatcherSink: def __init__(self): pass - def on_start(self, appid:str): + def on_start(self, appid: str): pass - def on_stop(self, appid:str, isErr:bool = False): + def on_stop(self, appid: str, isErr: bool = False): pass - def on_output(self, appid:str, tag:str, time:int, message:str): + def on_output(self, appid: str, tag: str, time: int, message: str): pass - def on_order(self, appid:str, chnl:str, ordInfo:dict): + def on_order(self, appid: str, chnl: str, ordInfo: dict): pass - def on_trade(self, appid:str, chnl:str, trdInfo:dict): + def on_trade(self, appid: str, chnl: str, trdInfo: dict): pass - - def on_notify(self, appid:str, chnl:str, message:str): + + def on_notify(self, appid: str, chnl: str, message: str): pass - def on_timeout(self, appid:str): + def on_timeout(self, appid: str): pass class ActionType(Enum): - ''' + """ 操作类型 枚举变量 - ''' - AT_START = 0 - AT_STOP = 1 - AT_RESTART = 2 + """ + AT_START = 0 + AT_STOP = 1 + AT_RESTART = 2 + class AppState(Enum): - ''' + """ app状态 枚举变量 - ''' - AS_NotExist = 901 - AS_NotRunning = 902 - AS_Running = 903 - AS_Closed = 904 - AS_Closing = 905 + """ + AS_NotExist = 901 + AS_NotRunning = 902 + AS_Running = 903 + AS_Closed = 904 + AS_Closing = 905 + class AppInfo(EventSink): - def __init__(self, appConf:dict, sink:WatcherSink = None, logger:WtLogger=None): + def __init__(self, appConf: dict, sink: WatcherSink = None, logger: WtLogger = None): self.__info__ = appConf self._cmd_line = None @@ -98,7 +102,7 @@ def __init__(self, appConf:dict, sink:WatcherSink = None, logger:WtLogger=None): if not os.path.exists(appConf["folder"]) or not os.path.exists(appConf["path"]): self._state == AppState.AS_NotExist - def applyConf(self, appConf:dict): + def applyConf(self, appConf: dict): self._lock.acquire() self.__info__ = appConf self._check_span = appConf["span"] @@ -140,7 +144,7 @@ def is_running(self, pids) -> bool: if self._state == AppState.AS_Closed: return False - + bNeedCheck = (self._procid is None) or (not psutil.pid_exists(self._procid)) if bNeedCheck: for pid in pids: @@ -155,7 +159,7 @@ def is_running(self, pids) -> bool: self._procid = pid self._mem = pInfo.memory_info().rss self.__logger__.info("应用%s挂载成功,进程ID: %d" % (self._id, self._procid)) - + if self._mq_url != '': # 如果事件接收器为空或者url发生了改变,则需要重新创建 bNeedCreate = self._evt_receiver is None or self._evt_receiver.url != self._mq_url @@ -166,7 +170,8 @@ def is_running(self, pids) -> bool: self._evt_receiver.run() self.__logger__.info("应用%s开始接收%s的通知信息" % (self._id, self._mq_url)) return True - except: + except Exception as e: + print(e) pass return False else: @@ -182,27 +187,30 @@ def run(self): if self._mq_url != '': # 每次启动都重新创建接收器 if self._evt_receiver is not None: - self.__logger__.info("应用%s正在释放原有事件接收器..." % (self._id)) + self.__logger__.info("应用%s正在释放原有事件接收器..." % self._id) self._evt_receiver.release() self._evt_receiver = EventReceiver(url=self._mq_url, logger=self.__logger__, sink=self) self._evt_receiver.run() self.__logger__.info("应用%s开始接收%s的通知信息" % (self._id, self._mq_url)) try: - args = self.__info__["param"].split(" ") if self.__info__["param"] != "" else [] + args = self.__info__["param"].split(" ") if self.__info__["param"] != "" else [] args.insert(0, self.__info__["path"]) if isWindows(): self._procid = subprocess.Popen(args, - cwd=self.__info__["folder"], creationflags=subprocess.CREATE_NEW_CONSOLE).pid + cwd=self.__info__["folder"], + creationflags=subprocess.CREATE_NEW_CONSOLE).pid else: - self._procid = subprocess.Popen(args, - cwd=self.__info__["folder"]).pid - - self._cmd_line = (self.__info__["path"] + " " + self.__info__["param"]) if self.__info__["param"] != "" else self.__info__["path"] + self._procid = subprocess.Popen(args, + cwd=self.__info__["folder"]).pid + + self._cmd_line = (self.__info__["path"] + " " + self.__info__["param"]) if self.__info__["param"] != "" else \ + self.__info__["path"] self.__logger__.info(f"cmdline: {self._cmd_line}, cwd:{self.__info__['folder']}") - except: - self.__logger__.info("应用%s启动异常" % (self._id)) + except Exception as e: + print(e) + self.__logger__.info("应用%s启动异常" % self._id) self._state = AppState.AS_Running @@ -221,7 +229,7 @@ def stop(self): else: os.system("kill -9 " + str(self._procid)) except SystemError as e: - self.__logger__.error("关闭异常: {}" % (e)) + self.__logger__.error("关闭异常: {}" % e) pass self._state = AppState.AS_Closed @@ -233,7 +241,7 @@ def stop(self): def restart(self): if self._procid is not None: self.stop() - + self.run() def update_state(self, pids): @@ -241,7 +249,7 @@ def update_state(self, pids): self._state = AppState.AS_Running elif self._state == AppState.AS_Running: self._state = AppState.AS_NotRunning - self.__logger__.info("应用%s的已停止" % (self._id)) + self.__logger__.info("应用%s的已停止" % self._id) self._procid = None self._mem = 0 if self._sink is not None: @@ -253,7 +261,7 @@ def tick(self, pids): if self._ticks == self._check_span: self.update_state(pids) if self._state == AppState.AS_NotRunning and self._guard: - self.__logger__.info("应用%s未启动,正在自动重启" % (self._id)) + self.__logger__.info("应用%s未启动,正在自动重启" % self._id) thrd = threading.Thread(target=self.run, daemon=True) thrd.start() # self.run() @@ -261,7 +269,7 @@ def tick(self, pids): self.__schedule__() self._ticks = 0 - + def __schedule__(self): weekflag = self._weekflag @@ -282,7 +290,7 @@ def __schedule__(self): for tInfo in self.__info__["schedule"]["tasks"]: if not tInfo["active"]: continue - + if "lastDate" in tInfo: lastDate = tInfo["lastDate"] else: @@ -331,29 +339,30 @@ def on_timeout(self): self._sink.on_timeout(self._id) # EventSink.on_order - def on_order(self, chnl:str, ordInfo:dict): + def on_order(self, chnl: str, ordInfo: dict): if self._sink is not None: self._sink.on_order(self._id, chnl, ordInfo) # EventSink.on_trade - def on_trade(self, chnl:str, trdInfo:dict): + def on_trade(self, chnl: str, trdInfo: dict): if self._sink is not None: self._sink.on_trade(self._id, chnl, trdInfo) - + # EventSink.on_notify - def on_notify(self, chnl:str, message:str): + def on_notify(self, chnl: str, message: str): if self._sink is not None: self._sink.on_notify(self._id, chnl, message) # EventSink.on_log - def on_log(self, tag:str, time:int, message:str): + def on_log(self, tag: str, time: int, message: str): if self._sink is not None: self._sink.on_output(self._id, tag, time, message) pass + class WatchDog: - def __init__(self, db, sink:WatcherSink = None, logger:WtLogger=None): + def __init__(self, db, sink: WatcherSink = None, logger: WtLogger = None): self.__db_conn__ = db self.__apps__ = dict() self.__app_conf__ = dict() @@ -364,7 +373,7 @@ def __init__(self, db, sink:WatcherSink = None, logger:WtLogger=None): mq = WtMsgQue(logger) - #加载调度列表 + # 加载调度列表 cur = self.__db_conn__.cursor() for row in cur.execute("SELECT * FROM schedules;"): appConf = dict() @@ -374,11 +383,11 @@ def __init__(self, db, sink:WatcherSink = None, logger:WtLogger=None): appConf["param"] = row[4] appConf["type"] = row[5] appConf["span"] = row[6] - appConf["guard"] = row[7]=='true' - appConf["redirect"] = row[8]=='true' + appConf["guard"] = row[7] == 'true' + appConf["redirect"] = row[8] == 'true' appConf["mqurl"] = row[11] appConf["schedule"] = dict() - appConf["schedule"]["active"] = row[9]=='true' + appConf["schedule"]["active"] = row[9] == 'true' appConf["schedule"]["weekflag"] = row[10] appConf["schedule"]["tasks"] = list() appConf["schedule"]["tasks"].append(json.loads(row[12])) @@ -390,7 +399,6 @@ def __init__(self, db, sink:WatcherSink = None, logger:WtLogger=None): self.__app_conf__[appConf["id"]] = appConf self.__apps__[appConf["id"]] = AppInfo(appConf, sink, self.__logger__) - def __watch_impl__(self): while not self.__stopped__: time.sleep(1) @@ -416,7 +424,7 @@ def run(self): self.__worker__.start() self.__logger__.info("自动调度服务已启动") - def start(self, appid:str): + def start(self, appid: str): if appid not in self.__apps__: return @@ -426,7 +434,7 @@ def start(self, appid:str): # thrd.start() appInfo.run() - def stop(self, appid:str): + def stop(self, appid: str): if appid not in self.__apps__: return @@ -436,10 +444,10 @@ def stop(self, appid:str): # thrd.start() appInfo.stop() - def has_app(self, appid:str): + def has_app(self, appid: str): return appid in self.__apps__ - def restart(self, appid:str): + def restart(self, appid: str): if appid not in self.__apps__: return @@ -447,22 +455,22 @@ def restart(self, appid:str): # thrd = threading.Thread(target=appInfo.restart, daemon=True) # thrd.start() appInfo.restart() - - def isRunning(self, appid:str): + + def isRunning(self, appid: str): if appid not in self.__apps__: return False appInfo = self.__apps__[appid] return appInfo.isRunning() - def getAppConf(self, appid:str): + def getAppConf(self, appid: str): if appid not in self.__apps__: return None - + appInfo = self.__apps__[appid] return appInfo.getConf() - def delApp(self, appid:str): + def delApp(self, appid: str): if appid not in self.__apps__: return @@ -473,7 +481,7 @@ def delApp(self, appid:str): self.__db_conn__.commit() self.__logger__.info("应用%s自动调度已删除" % (appid)) - def updateMQURL(self, appid:str, mqurl:str): + def updateMQURL(self, appid: str, mqurl: str): if appid not in self.__apps__: return @@ -481,14 +489,15 @@ def updateMQURL(self, appid:str, mqurl:str): appConf = self.__app_conf__[appid] appInst = self.__apps__[appid] appInst.applyConf(appConf) - + cur = self.__db_conn__.cursor() - sql = "UPDATE schedules SET mqurl='%s',modifytime=datetime('now','localtime') WHERE appid='%s';" % (mqurl, appid) + sql = "UPDATE schedules SET mqurl='%s',modifytime=datetime('now','localtime') WHERE appid='%s';" % ( + mqurl, appid) print(sql) cur.execute(sql) self.__db_conn__.commit() - def applyAppConf(self, appConf:dict, isGroup:bool = False): + def applyAppConf(self, appConf: dict, isGroup: bool = False): appid = appConf["id"] self.__app_conf__[appid] = appConf isNewApp = False @@ -514,16 +523,22 @@ def applyAppConf(self, appConf:dict, isGroup:bool = False): if isNewApp: sql = "INSERT INTO schedules(appid,path,folder,param,type,span,guard,redirect,schedule,weekflag,task1,task2,task3,task4,task5,task6,mqurl) \ VALUES('%s','%s','%s','%s',%d, %d,'%s','%s','%s','%s','%s','%s','%s','%s','%s','%s','%s');" % ( - appid, appConf["path"], appConf["folder"], appConf["param"], stype, appConf["span"], guard, redirect, schedule, appConf["schedule"]["weekflag"], - json.dumps(appConf["schedule"]["tasks"][0]),json.dumps(appConf["schedule"]["tasks"][1]),json.dumps(appConf["schedule"]["tasks"][2]), - json.dumps(appConf["schedule"]["tasks"][3]),json.dumps(appConf["schedule"]["tasks"][4]),json.dumps(appConf["schedule"]["tasks"][5]), - mqurl) + appid, appConf["path"], appConf["folder"], appConf["param"], stype, appConf["span"], guard, redirect, + schedule, appConf["schedule"]["weekflag"], + json.dumps(appConf["schedule"]["tasks"][0]), json.dumps(appConf["schedule"]["tasks"][1]), + json.dumps(appConf["schedule"]["tasks"][2]), + json.dumps(appConf["schedule"]["tasks"][3]), json.dumps(appConf["schedule"]["tasks"][4]), + json.dumps(appConf["schedule"]["tasks"][5]), + mqurl) else: sql = "UPDATE schedules SET path='%s',folder='%s',param='%s',type=%d,span='%s',guard='%s',redirect='%s',schedule='%s',weekflag='%s',task1='%s',task2='%s',\ task3='%s',task4='%s',task5='%s',task6='%s',mqurl='%s',modifytime=datetime('now','localtime') WHERE appid='%s';" % ( - appConf["path"], appConf["folder"], appConf["param"], stype, appConf["span"], guard, redirect, schedule, appConf["schedule"]["weekflag"], - json.dumps(appConf["schedule"]["tasks"][0]),json.dumps(appConf["schedule"]["tasks"][1]),json.dumps(appConf["schedule"]["tasks"][2]), - json.dumps(appConf["schedule"]["tasks"][3]),json.dumps(appConf["schedule"]["tasks"][4]),json.dumps(appConf["schedule"]["tasks"][5]), - mqurl, appid) + appConf["path"], appConf["folder"], appConf["param"], stype, appConf["span"], guard, redirect, schedule, + appConf["schedule"]["weekflag"], + json.dumps(appConf["schedule"]["tasks"][0]), json.dumps(appConf["schedule"]["tasks"][1]), + json.dumps(appConf["schedule"]["tasks"][2]), + json.dumps(appConf["schedule"]["tasks"][3]), json.dumps(appConf["schedule"]["tasks"][4]), + json.dumps(appConf["schedule"]["tasks"][5]), + mqurl, appid) cur.execute(sql) self.__db_conn__.commit() diff --git a/wtpy/monitor/WtBtMon.py b/wtpy/monitor/WtBtMon.py index 41cbff86..0f59222d 100644 --- a/wtpy/monitor/WtBtMon.py +++ b/wtpy/monitor/WtBtMon.py @@ -15,63 +15,71 @@ from .WtLogger import WtLogger from .EventReceiver import BtEventReceiver, BtEventSink + def isWindows(): if "windows" in platform.system().lower(): return True return False -def md5_str(v:str) -> str: + +def md5_str(v: str) -> str: return hashlib.md5(v.encode()).hexdigest() -def gen_btid(user:str, straid:str) -> str: + +def gen_btid(user: str, straid: str) -> str: now = datetime.datetime.now() s = user + "_" + straid + "_" + str(now.timestamp()) return md5_str(s) -def gen_straid(user:str) -> str: + +def gen_straid(user: str) -> str: now = datetime.datetime.now() s = user + "_" + str(now.timestamp()) return md5_str(s) + class BtTaskSink: def __init__(self): pass - def on_start(self, user:str, straid:str, btid:str): + def on_start(self, user: str, straid: str, btid: str): pass - def on_stop(self, user:str, straid:str, btid:str): + def on_stop(self, user: str, straid: str, btid: str): pass - def on_state(self, user:str, straid:str, btid:str, statInfo:dict): + def on_state(self, user: str, straid: str, btid: str, statInfo: dict): pass - def on_fund(self, user:str, straid:str, btid:str, fundInfo:dict): + def on_fund(self, user: str, straid: str, btid: str, fundInfo: dict): pass + class WtBtTask(BtEventSink): - ''' + """ 回测任务类 - ''' - def __init__(self, user:str, straid:str, btid:str, folder:str, logger:WtLogger = None, sink:BtTaskSink = None): + """ + + def __init__(self, user: str, straid: str, btid: str, folder: str, logger: WtLogger = None, + sink: BtTaskSink = None): self.user = user self.straid = straid self.btid = btid self.logger = logger self.folder = folder self.sink = sink - + self._cmd_line = None - self._mq_url = "ipc:///wtpy/bt_%s.ipc" % (btid) + self._mq_url = "ipc:///wtpy/bt_%s.ipc" % btid self._ticks = 0 self._state = 0 self._procid = None self._evt_receiver = None def __check__(self): - while True: + while True: time.sleep(1) pids = psutil.pids() if psutil.pid_exists(self._procid): @@ -94,14 +102,15 @@ def run(self): fullPath = os.path.join(self.folder, "runBT.py") if isWindows(): self._procid = subprocess.Popen([sys.executable, fullPath], # 需要执行的文件路径 - cwd=self.folder, creationflags=subprocess.CREATE_NEW_CONSOLE).pid + cwd=self.folder, creationflags=subprocess.CREATE_NEW_CONSOLE).pid else: self._procid = subprocess.Popen([sys.executable, fullPath], # 需要执行的文件路径 - cwd=self.folder).pid + cwd=self.folder).pid self._cmd_line = sys.executable + " " + fullPath - except: - self.logger.info("回测%s启动异常" % (self.btid)) + except Exception as e: + print(e) + self.logger.info("回测%s启动异常" % self.btid) self._state = 1 @@ -139,7 +148,8 @@ def is_running(self, pids) -> bool: self.watcher = threading.Thread(target=self.__check__, name=self.btid, daemon=True) self.watcher.run() - except: + except Exception as e: + print(e) pass return False @@ -152,22 +162,23 @@ def on_begin(self): def on_finish(self): pass - def on_state(self, statInfo:dict): + def on_state(self, statInfo: dict): if self.sink is not None: self.sink.on_state(self.user, self.straid, self.btid, statInfo) print(statInfo) - def on_fund(self, fundInfo:dict): + def on_fund(self, fundInfo: dict): if self.sink is not None: self.sink.on_fund(self.user, self.straid, self.btid, fundInfo) print(fundInfo) class WtBtMon(BtTaskSink): - ''' + """ 回测管理器 - ''' - def __init__(self, deploy_folder:str, dtServo:WtDtServo = None, logger:WtLogger = None): + """ + + def __init__(self, deploy_folder: str, dtServo: WtDtServo = None, logger: WtLogger = None): self.path = deploy_folder self.user_stras = dict() self.user_bts = dict() @@ -179,7 +190,7 @@ def __init__(self, deploy_folder:str, dtServo:WtDtServo = None, logger:WtLogger self.__load_tasks__() - def __load_user_data__(self, user:str): + def __load_user_data__(self, user: str): folder = os.path.join(self.path, user) if not os.path.exists(folder): os.mkdir(folder) @@ -203,8 +214,8 @@ def __save_user_data__(self, user): os.mkdir(folder) obj = { - "strategies":{}, - "backtests":{} + "strategies": {}, + "backtests": {} } if user in self.user_stras: @@ -219,10 +230,10 @@ def __save_user_data__(self, user): f.close() return True - def get_strategies(self, user:str) -> list: + def get_strategies(self, user: str) -> list: if user not in self.user_stras: bSucc = self.__load_user_data__(user) - + if not bSucc: return None @@ -231,7 +242,7 @@ def get_strategies(self, user:str) -> list: ay.append(self.user_stras[user][straid]) return ay - def add_strategy(self, user:str, name:str) -> dict: + def add_strategy(self, user: str, name: str) -> dict: if user not in self.user_stras: self.__load_user_data__(user) @@ -240,9 +251,9 @@ def add_strategy(self, user:str, name:str) -> dict: straid = gen_straid(user) self.user_stras[user][straid] = { - "id":straid, - "name":name, - "perform":{ + "id": straid, + "name": name, + "perform": { "days": 0, "total_return": 0, "annual_return": 0, @@ -269,7 +280,7 @@ def add_strategy(self, user:str, name:str) -> dict: return self.user_stras[user][straid] - def del_strategy(self, user:str, straid:str): + def del_strategy(self, user: str, straid: str): if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -290,8 +301,8 @@ def del_strategy(self, user:str, straid:str): self.user_stras[user].pop(straid) self.__save_user_data__(user) return True - - def has_strategy(self, user:str, straid:str, btid:str = None) -> bool: + + def has_strategy(self, user: str, straid: str, btid: str = None) -> bool: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -303,7 +314,7 @@ def has_strategy(self, user:str, straid:str, btid:str = None) -> bool: else: return btid in self.user_bts[user] - def get_strategy_code(self, user:str, straid:str, btid:str = None) -> str: + def get_strategy_code(self, user: str, straid: str, btid: str = None) -> str: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -330,7 +341,7 @@ def get_strategy_code(self, user:str, straid:str, btid:str = None) -> str: f.close() return content - def set_strategy_code(self, user:str, straid:str, content:str) -> bool: + def set_strategy_code(self, user: str, straid: str, content: str) -> bool: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -346,7 +357,7 @@ def set_strategy_code(self, user:str, straid:str, content:str) -> bool: f.close() return True - def get_backtests(self, user:str, straid:str) -> list: + def get_backtests(self, user: str, straid: str) -> list: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -362,7 +373,7 @@ def get_backtests(self, user:str, straid:str) -> list: return ay - def del_backtest(self, user:str, btid:str): + def del_backtest(self, user: str, btid: str): if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -378,7 +389,7 @@ def del_backtest(self, user:str, btid:str): self.__save_user_data__(user) - def get_bt_funds(self, user:str, straid:str, btid:str) -> list: + def get_bt_funds(self, user: str, straid: str, btid: str) -> list: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -417,10 +428,10 @@ def get_bt_funds(self, user:str, straid:str, btid:str) -> list: tItem["fee"] = float(cells[4]) funds.append(tItem) - + return funds - def get_bt_trades(self, user:str, straid:str, btid:str) -> list: + def get_bt_trades(self, user: str, straid: str, btid: str) -> list: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -465,10 +476,10 @@ def get_bt_trades(self, user:str, straid:str, btid:str) -> list: item["fee"] = float(cells[4]) items.append(item) - + return items - def get_bt_rounds(self, user:str, straid:str, btid:str) -> list: + def get_bt_rounds(self, user: str, straid: str, btid: str) -> list: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -509,10 +520,10 @@ def get_bt_rounds(self, user:str, straid:str, btid:str) -> list: } items.append(item) - + return items - def get_bt_signals(self, user:str, straid:str, btid:str) -> list: + def get_bt_signals(self, user: str, straid: str, btid: str) -> list: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -548,10 +559,10 @@ def get_bt_signals(self, user:str, straid:str, btid:str) -> list: } items.append(item) - + return items - def get_bt_summary(self, user:str, straid:str, btid:str) -> list: + def get_bt_summary(self, user: str, straid: str, btid: str) -> list: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -574,7 +585,7 @@ def get_bt_summary(self, user:str, straid:str, btid:str) -> list: obj = json.loads(content) return obj - def get_bt_state(self, user:str, straid:str, btid:str) -> list: + def get_bt_state(self, user: str, straid: str, btid: str) -> list: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -597,7 +608,7 @@ def get_bt_state(self, user:str, straid:str, btid:str) -> list: obj = json.loads(content) return obj - def get_bt_state(self, user:str, straid:str, btid:str) -> dict: + def get_bt_state(self, user: str, straid: str, btid: str) -> dict: if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -621,7 +632,7 @@ def get_bt_state(self, user:str, straid:str, btid:str) -> dict: return thisBts[btid]["state"] - def update_bt_state(self, user:str, straid:str, btid:str, stateObj:dict): + def update_bt_state(self, user: str, straid: str, btid: str, stateObj: dict): if user not in self.user_bts: bSucc = self.__load_user_data__(user) @@ -634,7 +645,7 @@ def update_bt_state(self, user:str, straid:str, btid:str, stateObj:dict): thisBts[btid]["state"] = stateObj - def get_bt_kline(self, user:str, straid:str, btid:str) -> list: + def get_bt_kline(self, user: str, straid: str, btid: str) -> list: if self.dt_servo is None: return None @@ -643,7 +654,7 @@ def get_bt_kline(self, user:str, straid:str, btid:str) -> list: if not bSucc: return None - + btState = self.get_bt_state(user, straid, btid) if btState is None: return None @@ -664,7 +675,7 @@ def get_bt_kline(self, user:str, straid:str, btid:str) -> list: if period[0] == 'd': bar["time"] = realBar.date else: - bar["time"] = 1990*100000000 + realBar.time + bar["time"] = 1990 * 100000000 + realBar.time bar["bartime"] = bar["time"] bar["open"] = realBar.open bar["high"] = realBar.high @@ -676,20 +687,21 @@ def get_bt_kline(self, user:str, straid:str, btid:str) -> list: return thisBts[btid]["kline"] - def run_backtest(self, user:str, straid:str, fromTime:int, endTime:int, capital:float, slippage:int=0) -> dict: + def run_backtest(self, user: str, straid: str, fromTime: int, endTime: int, capital: float, + slippage: int = 0) -> dict: if user not in self.user_bts: self.__load_user_data__(user) if user not in self.user_bts: self.user_bts[user] = dict() - + btid = gen_btid(user, straid) # 生成回测目录 folder = os.path.join(self.path, user, straid, "backtests") if not os.path.exists(folder): os.mkdir(folder) - + folder = os.path.join(folder, btid) os.mkdir(folder) @@ -735,10 +747,10 @@ def run_backtest(self, user:str, straid:str, fromTime:int, endTime:int, capital: f.close() btInfo = { - "id":btid, - "capital":capital, - "runtime":datetime.datetime.now().strftime("%Y.%m.%d %H:%M:%S"), - "state":{ + "id": btid, + "capital": capital, + "runtime": datetime.datetime.now().strftime("%Y.%m.%d %H:%M:%S"), + "state": { "code": "", "period": "", "stime": fromTime, @@ -746,7 +758,7 @@ def run_backtest(self, user:str, straid:str, fromTime:int, endTime:int, capital: "progress": 0, "elapse": 0 }, - "perform":{ + "perform": { "days": 0, "total_return": 0, "annual_return": 0, @@ -773,17 +785,17 @@ def run_backtest(self, user:str, straid:str, fromTime:int, endTime:int, capital: # 这里还需要记录一下回测的任务,不然如果重启就恢复不了了 taskInfo = { - "user":user, - "straid":straid, - "btid":btid, - "folder":folder + "user": user, + "straid": straid, + "btid": btid, + "folder": folder } - self.task_infos[btid]= taskInfo + self.task_infos[btid] = taskInfo self.__save_tasks__() return btInfo - def __update_bt_result__(self, user:str, straid:str, btid:str): + def __update_bt_result__(self, user: str, straid: str, btid: str): if user not in self.user_bts: self.__load_user_data__(user) @@ -800,7 +812,7 @@ def __update_bt_result__(self, user:str, straid:str, btid:str): self.user_stras[user][straid]["perform"] = summaryObj self.__save_user_data__(user) - + def __save_tasks__(self): obj = self.task_infos @@ -832,18 +844,17 @@ def __load_tasks__(self): else: # 之前记录过测回测任务,执行完成了,要更新回测数据 self.__update_bt_result__(tInfo["user"], tInfo["straid"], btid) - + self.__save_tasks__() - - def on_start(self, user:str, straid:str, btid:str): + def on_start(self, user: str, straid: str, btid: str): pass - def on_stop(self, user:str, straid:str, btid:str): + def on_stop(self, user: str, straid: str, btid: str): self.__update_bt_result__(user, straid, btid) - def on_state(self, user:str, straid:str, btid:str, statInfo:dict): + def on_state(self, user: str, straid: str, btid: str, statInfo: dict): self.user_bts[user][btid]["state"] = statInfo - def on_fund(self, user:str, straid:str, btid:str, fundInfo:dict): - pass \ No newline at end of file + def on_fund(self, user: str, straid: str, btid: str, fundInfo: dict): + pass diff --git a/wtpy/monitor/WtBtSnooper.py b/wtpy/monitor/WtBtSnooper.py index 65a90928..d96bb039 100644 --- a/wtpy/monitor/WtBtSnooper.py +++ b/wtpy/monitor/WtBtSnooper.py @@ -44,17 +44,18 @@ def do_trading_analyze(df_closes, df_funds): # 单笔最大亏损交易 largest_loss = float(df_loses['profit'].min()) # 交易的平均持仓K线根数 - avgtrd_hold_bar = 0 if totaltimes==0 else ((df_closes['closebarno'] - df_closes['openbarno']).sum()) / totaltimes + avgtrd_hold_bar = 0 if totaltimes == 0 else ((df_closes['closebarno'] - df_closes['openbarno']).sum()) / totaltimes # 平均空仓K线根数 avb = (df_closes['openbarno'] - df_closes['closebarno'].shift(1).fillna(value=0)) - avgemphold_bar = 0 if len(df_closes)==0 else avb.sum() / len(df_closes) + avgemphold_bar = 0 if len(df_closes) == 0 else avb.sum() / len(df_closes) # 两笔盈利交易之间的平均空仓K线根数 win_holdbar_situ = (df_wins['openbarno'].shift(-1) - df_wins['closebarno']).dropna() - winempty_avgholdbar = 0 if len(df_wins)== 0 or len(df_wins) == 1 else win_holdbar_situ.sum() / (len(df_wins)-1) + winempty_avgholdbar = 0 if len(df_wins) == 0 or len(df_wins) == 1 else win_holdbar_situ.sum() / (len(df_wins) - 1) # 两笔亏损交易之间的平均空仓K线根数 loss_holdbar_situ = (df_loses['openbarno'].shift(-1) - df_loses['closebarno']).dropna() - lossempty_avgholdbar = 0 if len(df_loses)== 0 or len(df_loses) == 1 else loss_holdbar_situ.sum() / (len(df_loses)-1) + lossempty_avgholdbar = 0 if len(df_loses) == 0 or len(df_loses) == 1 else loss_holdbar_situ.sum() / ( + len(df_loses) - 1) max_consecutive_wins = 0 # 最大连续盈利次数 max_consecutive_loses = 0 # 最大连续亏损次数 @@ -100,15 +101,15 @@ def do_trading_analyze(df_closes, df_funds): summary["max_consecutive_wins"] = max_consecutive_wins summary["max_consecutive_loses"] = max_consecutive_loses - return summary class WtBtSnooper: - ''' + """ 回测管理器 - ''' - def __init__(self, dtServo:WtDtServo = None): + """ + + def __init__(self, dtServo: WtDtServo = None): self.path = "" self.dt_servo = dtServo self.workspaces = list() @@ -128,46 +129,47 @@ def load_data(self): if len(content) == 0: return - obj = json.loads(content) + obj = json.loads(content) if "workspace" in obj: self.workspaces = obj["workspace"] def save_data(self): - obj = { + obj = { "workspace": self.workspaces } content = json.dumps(obj, ensure_ascii=False, indent=4) f = open("data.json", "w") f.write(content) - f.close() + f.close() - def add_static_folder(self, folder:str, path:str = "/static", name:str = "static"): + def add_static_folder(self, folder: str, path: str = "/static", name: str = "static"): self.static_folders.append({ "path": path, "folder": folder, "name": name }) - def __server_impl__(self, port:int, host:str): - uvicorn.run(self.server_inst, port = port, host = host) + def __server_impl__(self, port: int, host: str): + uvicorn.run(self.server_inst, port=port, host=host) - def run_as_server(self, port:int = 8081, host="127.0.0.1", bSync:bool = True): + def run_as_server(self, port: int = 8081, host="127.0.0.1", bSync: bool = True): tags_info = [ - {"name":"Backtest APIs","description":"回测查探器接口"} + {"name": "Backtest APIs", "description": "回测查探器接口"} ] - app = FastAPI(title="WtBtSnooper", description="A simple http api of WtBtSnooper", openapi_tags=tags_info, redoc_url=None, version="1.0.0") + app = FastAPI(title="WtBtSnooper", description="A simple http api of WtBtSnooper", openapi_tags=tags_info, + redoc_url=None, version="1.0.0") app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(SessionMiddleware, secret_key='!@#$%^&*()', max_age=25200, session_cookie='WtBtSnooper_sid') if len(self.static_folders) > 0: for static_item in self.static_folders: - app.mount(static_item["path"], StaticFiles(directory = static_item["folder"]), name=static_item["name"]) + app.mount(static_item["path"], StaticFiles(directory=static_item["folder"]), name=static_item["name"]) else: paths = os.path.split(__file__) a = (paths[:-1] + ("static/console",)) path = os.path.join(*a) - app.mount("/backtest", StaticFiles(directory = path), name="static") + app.mount("/backtest", StaticFiles(directory=path), name="static") self.server_inst = app @@ -177,18 +179,18 @@ def run_as_server(self, port:int = 8081, host="127.0.0.1", bSync:bool = True): self.__server_impl__(port, host) else: import threading - self.worker = threading.Thread(target=self.__server_impl__, args=(port,host,)) + self.worker = threading.Thread(target=self.__server_impl__, args=(port, host,)) self.worker.setDaemon(True) self.worker.start() - def get_workspace_path(self, id:str) ->str: + def get_workspace_path(self, id: str) -> str: for wInfo in self.workspaces: if wInfo["id"] == id: return wInfo["path"] - + return "" - def init_bt_apis(self, app:FastAPI): + def init_bt_apis(self, app: FastAPI): @app.get("/") async def console_entry(): return RedirectResponse("/backtest/backtest.html") @@ -196,17 +198,17 @@ async def console_entry(): @app.post("/bt/qryws", tags=["Backtest APIs"], description="获取工作空间") async def qry_workspaces(): ret = { - "result":0, - "message":"Ok", + "result": 0, + "message": "Ok", "workspaces": self.workspaces - } + } return ret @app.post("/bt/addws", tags=["Backtest APIs"], description="添加工作空间") async def add_workspace( - path:str = Body(..., title="工作空间路径", embed=True), - name:str = Body(..., title="工作空间名称", embed=True) + path: str = Body(..., title="工作空间路径", embed=True), + name: str = Body(..., title="工作空间名称", embed=True) ): md5 = hashlib.md5() now = datetime.datetime.now().replace(tzinfo=pytz.timezone('UTC')).strftime("%Y.%m.%d %H:%M:%S") @@ -223,13 +225,13 @@ async def add_workspace( self.save_data() return { - "result":0, - "message":"Ok" + "result": 0, + "message": "Ok" } @app.post("/bt/delws", tags=["Backtest APIs"], description="删除工作空间") async def del_workspace( - wsid:str = Body(..., title="工作空间ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True) ): for wInfo in self.workspaces: if wInfo["id"] == wsid: @@ -238,25 +240,25 @@ async def del_workspace( break return { - "result":0, - "message":"Ok" + "result": 0, + "message": "Ok" } # 获取策略回测回合 @app.post("/bt/qrybtstras", tags=["Backtest APIs"], description="读取全部回测策略") def qry_stra_bt_strategies( - wsid:str = Body(..., title="工作空间ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True) ): path = self.get_workspace_path(wsid) if len(path) == 0: ret = { - "result":-1, - "message":"Invalid workspace" + "result": -1, + "message": "Invalid workspace" } ret = { - "result":0, - "message":"OK", + "result": 0, + "message": "OK", "strategies": self.get_all_strategy(path) } return ret @@ -264,27 +266,27 @@ def qry_stra_bt_strategies( # 拉取K线数据 @app.post("/bt/qrybars", tags=["Backtest APIs"], description="获取K线") async def qry_bt_bars( - wsid:str = Body(..., title="工作空间ID", embed=True), - straid:str = Body(..., title="策略ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True), + straid: str = Body(..., title="策略ID", embed=True) ): path = self.get_workspace_path(wsid) if len(path) == 0: ret = { - "result":-1, - "message":"Invalid workspace" + "result": -1, + "message": "Invalid workspace" } code, bars, index, marks = self.get_bt_kline(path, straid) if bars is None: ret = { - "result":-2, - "message":"Data not found" + "result": -2, + "message": "Data not found" } else: - + ret = { - "result":0, - "message":"Ok", + "result": 0, + "message": "Ok", "bars": bars, "code": code } @@ -297,128 +299,127 @@ async def qry_bt_bars( return ret - # 获取策略回测信号 @app.post("/bt/qrybtsigs", tags=["Backtest APIs"], description="读取信号明细") def qry_stra_bt_signals( - wsid:str = Body(..., title="工作空间ID", embed=True), - straid:str = Body(..., title="策略ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True), + straid: str = Body(..., title="策略ID", embed=True) ): path = self.get_workspace_path(wsid) if len(path) == 0: ret = { - "result":-1, - "message":"Invalid workspace" + "result": -1, + "message": "Invalid workspace" } ret = { - "result":0, - "message":"OK", - "signals":self.get_bt_signals(path, straid) + "result": 0, + "message": "OK", + "signals": self.get_bt_signals(path, straid) } - + return ret # 获取策略回测成交 @app.post("/bt/qrybttrds", tags=["Backtest APIs"], description="读取成交明细") def qry_stra_bt_trades( - wsid:str = Body(..., title="工作空间ID", embed=True), - straid:str = Body(..., title="策略ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True), + straid: str = Body(..., title="策略ID", embed=True) ): path = self.get_workspace_path(wsid) if len(path) == 0: ret = { - "result":-1, - "message":"Invalid workspace" + "result": -1, + "message": "Invalid workspace" } ret = { - "result":0, - "message":"OK", - "trades":self.get_bt_trades(path, straid) + "result": 0, + "message": "OK", + "trades": self.get_bt_trades(path, straid) } - + return ret # 获取策略回测资金 @app.post("/bt/qrybtfunds", tags=["Backtest APIs"], description="读取资金明细") def qry_stra_bt_funds( - wsid:str = Body(..., title="工作空间ID", embed=True), - straid:str = Body(..., title="策略ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True), + straid: str = Body(..., title="策略ID", embed=True) ): path = self.get_workspace_path(wsid) if len(path) == 0: ret = { - "result":-1, - "message":"Invalid workspace" + "result": -1, + "message": "Invalid workspace" } ret = { - "result":0, - "message":"OK", - "funds":self.get_bt_funds(path, straid) + "result": 0, + "message": "OK", + "funds": self.get_bt_funds(path, straid) } - + return ret # 获取策略回测回合 @app.post("/bt/qrybtrnds", tags=["Backtest APIs"], description="读取回合明细") def qry_stra_bt_rounds( - wsid:str = Body(..., title="工作空间ID", embed=True), - straid:str = Body(..., title="策略ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True), + straid: str = Body(..., title="策略ID", embed=True) ): path = self.get_workspace_path(wsid) if len(path) == 0: ret = { - "result":-1, - "message":"Invalid workspace" + "result": -1, + "message": "Invalid workspace" } ret = { - "result":0, - "message":"OK", - "rounds":self.get_bt_rounds(path, straid) + "result": 0, + "message": "OK", + "rounds": self.get_bt_rounds(path, straid) } return ret # 获取策略回测回合 @app.post("/bt/qrybtinfo", tags=["Backtest APIs"], description="读取回合明细") def qry_stra_bt_rounds( - wsid:str = Body(..., title="工作空间ID", embed=True), - straid:str = Body(..., title="策略ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True), + straid: str = Body(..., title="策略ID", embed=True) ): path = self.get_workspace_path(wsid) if len(path) == 0: ret = { - "result":-1, - "message":"Invalid workspace" + "result": -1, + "message": "Invalid workspace" } ret = { - "result":0, - "message":"OK", - "info":self.get_bt_info(path, straid) + "result": 0, + "message": "OK", + "info": self.get_bt_info(path, straid) } return ret @app.post("/bt/qrybtcloses", tags=["Backtest APIs"], description="读取成交数据") def qry_stra_bt_closes( - wsid:str = Body(..., title="工作空间ID", embed=True), - straid:str = Body(..., title="策略ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True), + straid: str = Body(..., title="策略ID", embed=True) ): path = self.get_workspace_path(wsid) if len(path) == 0: ret = { - "result":-1, - "message":"Invalid workspace" + "result": -1, + "message": "Invalid workspace" } ret = { - "result":0, - "message":"OK", - "closes_long":self.get_bt_closes(path, straid)[0], - "closes_short":self.get_bt_closes(path, straid)[1], - "closes_all":self.get_bt_closes(path, straid)[2], + "result": 0, + "message": "OK", + "closes_long": self.get_bt_closes(path, straid)[0], + "closes_short": self.get_bt_closes(path, straid)[1], + "closes_all": self.get_bt_closes(path, straid)[2], "closes_month": self.get_bt_closes(path, straid)[3], "closes_year": self.get_bt_closes(path, straid)[4] } @@ -426,20 +427,20 @@ def qry_stra_bt_closes( @app.post("/bt/qrybtanalysis", tags=["Backtest APIs"], description="读取策略分析") def qry_stra_bt_analysis( - wsid:str = Body(..., title="工作空间ID", embed=True), - straid:str = Body(..., title="策略ID", embed=True) + wsid: str = Body(..., title="工作空间ID", embed=True), + straid: str = Body(..., title="策略ID", embed=True) ): path = self.get_workspace_path(wsid) if len(path) == 0: ret = { - "result":-1, - "message":"Invalid workspace" + "result": -1, + "message": "Invalid workspace" } ret = { - "result":0, - "message":"OK", - "analysis":self.get_bt_analysis(path, straid) + "result": 0, + "message": "OK", + "analysis": self.get_bt_analysis(path, straid) } return ret @@ -452,7 +453,7 @@ def get_all_strategy(self, path) -> list: ret.append(filename) return ret - def get_bt_info(self, path:str, straid:str) -> dict: + def get_bt_info(self, path: str, straid: str) -> dict: filename = f"{straid}/summary.json" filename = os.path.join(path, filename) if not os.path.exists(filename): @@ -480,9 +481,9 @@ def get_bt_info(self, path:str, straid:str) -> dict: def get_bt_analysis(self, path: str, straid: str) -> dict: funds_filename = f"{straid}/funds.csv" - funds_filename = os.path.join(path,funds_filename) + funds_filename = os.path.join(path, funds_filename) closes_filename = f"{straid}/closes.csv" - closes_filename = os.path.join(path,closes_filename) + closes_filename = os.path.join(path, closes_filename) if not (os.path.exists(funds_filename) or os.path.exists(closes_filename)): return None @@ -504,7 +505,7 @@ def get_bt_analysis(self, path: str, straid: str) -> dict: 'summary_long': summary_long } - def get_bt_funds(self, path:str, straid:str) -> list: + def get_bt_funds(self, path: str, straid: str) -> list: filename = f"{straid}/funds.csv" filename = os.path.join(path, filename) if not os.path.exists(filename): @@ -533,10 +534,10 @@ def get_bt_funds(self, path:str, straid:str) -> list: tItem["fee"] = float(cells[4]) funds.append(tItem) - + return funds - def get_bt_closes(self, path:str, straid:str): + def get_bt_closes(self, path: str, straid: str): summary_file = f"{straid}/summary.json" summary_file = os.path.join(path, summary_file) closes_file = f"{straid}/closes.csv" @@ -566,29 +567,28 @@ def get_bt_closes(self, path:str, straid:str): closes_all = list() for item in np_trade: litem = { - "opentime":int(item[2]), - "closetime":int(item[4]), - "profit":float(item[7]), - "direct":str(item[1]), - "openprice":float(item[3]), - "closeprice":float(item[5]), - "maxprofit":float(item[8]), - "maxloss":float(item[9]), - "qty":int(item[6]), + "opentime": int(item[2]), + "closetime": int(item[4]), + "profit": float(item[7]), + "direct": str(item[1]), + "openprice": float(item[3]), + "closeprice": float(item[5]), + "maxprofit": float(item[8]), + "maxloss": float(item[9]), + "qty": int(item[6]), "capital": capital, - 'profit_sum':float(item[16]), - 'Withdrawal':float(item[17]), - 'profit_ratio':float(item[18]), - 'Withdrawal_ratio':float(item[19]) + 'profit_sum': float(item[16]), + 'Withdrawal': float(item[17]), + 'profit_ratio': float(item[18]), + 'Withdrawal_ratio': float(item[19]) } closes_all.append(litem) df_closes['time'] = df_closes['closetime'].apply(lambda x: datetime.datetime.strptime(str(x), '%Y%m%d%H%M')) df_c_m = df_closes.resample(rule='M', on='time', label='right', - closed='right').agg({ - 'profit': 'sum', - 'maxprofit': 'sum', - 'maxloss': 'sum', - }) + closed='right').agg({'profit': 'sum', + 'maxprofit': 'sum', + 'maxloss': 'sum', + }) df_c_m = df_c_m.reset_index() df_c_m['equity'] = df_c_m['profit'].expanding(1).sum() + capital df_c_m['monthly_profit'] = 100 * (df_c_m['equity'] / df_c_m['equity'].shift(1).fillna(value=capital) - 1) @@ -596,21 +596,20 @@ def get_bt_closes(self, path:str, straid:str): np_m = np.array(df_c_m).tolist() for item in np_m: litem = { - "time":int(item[0].strftime('%Y%m')), - "profit":float(item[1]), - 'maxprofit':float(item[2]), - 'maxloss':float(item[3]), - 'equity':float(item[4]), - 'monthly_profit':float(item[5]) + "time": int(item[0].strftime('%Y%m')), + "profit": float(item[1]), + 'maxprofit': float(item[2]), + 'maxloss': float(item[3]), + 'equity': float(item[4]), + 'monthly_profit': float(item[5]) } closes_month.append(litem) df_c_y = df_closes.resample(rule='Y', on='time', label='right', - closed='right').agg({ - 'profit': 'sum', - 'maxprofit': 'sum', - 'maxloss': 'sum', - }) + closed='right').agg({'profit': 'sum', + 'maxprofit': 'sum', + 'maxloss': 'sum', + }) df_c_y = df_c_y.reset_index() df_c_y['equity'] = df_c_y['profit'].expanding(1).sum() + capital df_c_y['monthly_profit'] = 100 * (df_c_y['equity'] / df_c_y['equity'].shift(1).fillna(value=capital) - 1) @@ -618,12 +617,12 @@ def get_bt_closes(self, path:str, straid:str): np_y = np.array(df_c_y).tolist() for item in np_y: litem = { - "time":int(item[0].strftime('%Y%m')), - "profit":float(item[1]), - 'maxprofit':float(item[2]), - 'maxloss':float(item[3]), - 'equity':float(item[4]), - 'annual_profit':float(item[5]) + "time": int(item[0].strftime('%Y%m')), + "profit": float(item[1]), + 'maxprofit': float(item[2]), + 'maxloss': float(item[3]), + 'equity': float(item[4]), + 'annual_profit': float(item[5]) } closes_year.append(litem) @@ -631,30 +630,30 @@ def get_bt_closes(self, path:str, straid:str): df_short = df_closes[df_closes['direct'].apply(lambda x: 'SHORT' in x)] df_long = df_long.copy() df_short = df_short.copy() - df_long["long_profit"] = df_long["profit"].expanding(1).sum()-df_long["fee"].expanding(1).sum() + df_long["long_profit"] = df_long["profit"].expanding(1).sum() - df_long["fee"].expanding(1).sum() closes_long = list() closes_short = list() np_long = np.array(df_long).tolist() for item in np_long: litem = { - "date":int(item[4]), - "long_profit":float(item[-1]), - "capital":capital + "date": int(item[4]), + "long_profit": float(item[-1]), + "capital": capital } closes_long.append(litem) - df_short["short_profit"] = df_short["profit"].expanding(1).sum()-df_short["fee"].expanding(1).sum() + df_short["short_profit"] = df_short["profit"].expanding(1).sum() - df_short["fee"].expanding(1).sum() np_short = np.array(df_short).tolist() for item in np_short: litem = { - "date":int(item[4]), - "short_profit":float(item[-1]), - "capital":capital + "date": int(item[4]), + "short_profit": float(item[-1]), + "capital": capital } closes_short.append(litem) return closes_long, closes_short, closes_all, closes_month, closes_year - def get_bt_trades(self, path:str, straid:str) -> list: + def get_bt_trades(self, path: str, straid: str) -> list: filename = f"{straid}/trades.csv" filename = os.path.join(path, filename) if not os.path.exists(filename): @@ -689,10 +688,10 @@ def get_bt_trades(self, path:str, straid:str) -> list: item["fee"] = float(cells[4]) items.append(item) - + return items - def get_bt_rounds(self, path:str, straid:str) -> list: + def get_bt_rounds(self, path: str, straid: str) -> list: filename = f"{straid}/closes.csv" filename = os.path.join(path, filename) if not os.path.exists(filename): @@ -723,10 +722,10 @@ def get_bt_rounds(self, path:str, straid:str) -> list: } items.append(item) - + return items - def get_bt_signals(self, path:str, straid:str) -> list: + def get_bt_signals(self, path: str, straid: str) -> list: filename = f"{straid}/signals.csv" filename = os.path.join(path, filename) if not os.path.exists(filename): @@ -752,10 +751,10 @@ def get_bt_signals(self, path:str, straid:str) -> list: } items.append(item) - + return items - def get_bt_kline(self, path:str, straid:str) -> list: + def get_bt_kline(self, path: str, straid: str) -> list: if self.dt_servo is None: return None @@ -778,7 +777,7 @@ def get_bt_kline(self, path:str, straid:str) -> list: index = None marks = None - #如果有btchart,就用btchart定义的K线 + # 如果有btchart,就用btchart定义的K线 filename = f"{straid}/btchart.json" filename = os.path.join(path, filename) if os.path.exists(filename): @@ -841,18 +840,18 @@ def get_bt_kline(self, path:str, straid:str) -> list: if barList is None: return None - isDay = period[0]=='d' + isDay = period[0] == 'd' bars = list() for realBar in barList: bars.append(dict( - bartime = int(realBar['date'] if isDay else 199000000000 + realBar['time']), - open = realBar['open'], - high = realBar['high'], - low = realBar['low'], - close = realBar['close'], - volume = realBar['volume'], - turnover = realBar['turnover'] + bartime=int(realBar['date'] if isDay else 199000000000 + realBar['time']), + open=realBar['open'], + high=realBar['high'], + low=realBar['low'], + close=realBar['close'], + volume=realBar['volume'], + turnover=realBar['turnover'] )) return code, bars, index, marks diff --git a/wtpy/monitor/WtLogger.py b/wtpy/monitor/WtLogger.py index 6e237eaf..baac3b86 100644 --- a/wtpy/monitor/WtLogger.py +++ b/wtpy/monitor/WtLogger.py @@ -11,50 +11,51 @@ 'CRITICAL': 'bold_red', } + class WtLogger: - def __init__(self, catName:str='', filename:str="out.log"): + def __init__(self, catName: str = '', filename: str = "out.log"): self.logger = logging.getLogger(catName) self.logger.setLevel(logging.DEBUG) - #创建一个handler,用于写入日志文件 - log_path = os.getcwd()+"/logs/" # 指定文件输出路径,注意logs是个文件夹,一定要加上/,不然会导致输出路径错误,把logs变成文件名的一部分了 + # 创建一个handler,用于写入日志文件 + log_path = os.getcwd() + "/logs/" # 指定文件输出路径,注意logs是个文件夹,一定要加上/,不然会导致输出路径错误,把logs变成文件名的一部分了 if not os.path.exists(log_path): os.mkdir(log_path) - logname = log_path + filename #指定输出的日志文件名 - fh = TimedRotatingFileHandler(logname, encoding = 'utf-8', when="d") # 指定utf-8格式编码,避免输出的日志文本乱码 + logname = log_path + filename # 指定输出的日志文件名 + fh = TimedRotatingFileHandler(logname, encoding='utf-8', when="d") # 指定utf-8格式编码,避免输出的日志文本乱码 fh.setLevel(logging.INFO) - #创建一个handler,用于将日志输出到控制台 + # 创建一个handler,用于将日志输出到控制台 ch = logging.StreamHandler() ch.setLevel(logging.INFO) formatter = logging.Formatter( - fmt='[%(asctime)s.%(msecs)03d - %(levelname)s] %(message)s', + fmt='[%(asctime)s.%(msecs)03d - %(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S' - ) + ) fh.setFormatter(formatter) # 定义handler的输出格式 formatter = colorlog.ColoredFormatter( - fmt='%(log_color)s[%(asctime)s.%(msecs)03d - %(levelname)s] %(message)s', + fmt='%(log_color)s[%(asctime)s.%(msecs)03d - %(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S', log_colors=log_colors_config - ) + ) ch.setFormatter(formatter) # 给logger添加handler self.logger.addHandler(fh) self.logger.addHandler(ch) - def info(self, message:str): + def info(self, message: str): self.logger.info(message) - def warn(self, message:str): + def warn(self, message: str): self.logger.warn(message) - def error(self, message:str): + def error(self, message: str): self.logger.error(message) - def fatal(self, message:str): + def fatal(self, message: str): self.logger.fatal(message) diff --git a/wtpy/monitor/WtMonSvr.py b/wtpy/monitor/WtMonSvr.py index 6151ee7f..4238027b 100644 --- a/wtpy/monitor/WtMonSvr.py +++ b/wtpy/monitor/WtMonSvr.py @@ -1,6 +1,7 @@ from fastapi import FastAPI, Body, Request from fastapi.staticfiles import StaticFiles from fastapi.middleware.gzip import GZipMiddleware +from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.sessions import SessionMiddleware from starlette.responses import RedirectResponse, FileResponse import uvicorn @@ -24,27 +25,32 @@ import signal import platform + def isWindows(): if "windows" in platform.system().lower(): return True return False + def get_session(request: Request, key: str): if key not in request["session"]: return None return request["session"][key] + def set_session(request: Request, key: str, val): request["session"][key] = val + def pop_session(request: Request, key: str): if key not in request["session"]: return request["session"].pop(key) -def AES_Encrypt(key:str, data:str): - from Crypto.Cipher import AES # pip install pycryptodome + +def AES_Encrypt(key: str, data: str): + from Crypto.Cipher import AES # pip install pycryptodome vi = '0102030405060708' pad = lambda s: s + (16 - len(s) % 16) * chr(16 - len(s) % 16) data = pad(data) @@ -58,8 +64,9 @@ def AES_Encrypt(key:str, data:str): # 对byte字符串按utf-8进行解码 return enctext -def AES_Decrypt(key:str, data:str): - from Crypto.Cipher import AES # pip install pycryptodome + +def AES_Decrypt(key: str, data: str): + from Crypto.Cipher import AES # pip install pycryptodome vi = '0102030405060708' data = data.encode('utf8') encodebytes = base64.decodebytes(data) @@ -91,7 +98,7 @@ def get_tail(filename, N: int = 100, encoding="GBK"): return ''.join(last_line), len(last_line) -def check_auth(request: Request, token:str = None, seckey:str = None): +def check_auth(request: Request, token: str = None, seckey: str = None): if token is None: tokeninfo = get_session(request, "tokeninfo") # session里没有用户信息 @@ -131,6 +138,7 @@ def check_auth(request: Request, token:str = None, seckey:str = None): return True, tokeninfo + def get_cfg_tree(root: str, name: str): if not os.path.exists(root): return { @@ -195,7 +203,7 @@ def get_cfg_tree(root: str, name: str): if "executers" in cfgObj: filename = cfgObj["executers"] - if type(filename) == str: + if isinstance(filename, str): filepath = os.path.join(root, filename) ret['children'].append({ "label": filename, @@ -207,7 +215,7 @@ def get_cfg_tree(root: str, name: str): if "parsers" in cfgObj: filename = cfgObj["parsers"] - if type(filename) == str: + if isinstance(filename, str): filepath = os.path.join(root, filename) ret['children'].append({ "label": filename, @@ -219,7 +227,7 @@ def get_cfg_tree(root: str, name: str): if "traders" in cfgObj: filename = cfgObj["traders"] - if type(filename) == str: + if isinstance(filename, str): filepath = os.path.join(root, filename) ret['children'].append({ "label": filename, @@ -300,18 +308,18 @@ def __init__(self): def notify(self, level: str, msg: str): return -from fastapi.middleware.cors import CORSMiddleware class WtMonSvr(WatcherSink): - def __init__(self, static_folder: str = "static/", deploy_dir="C:/", sink: WtMonSink = None, notifyTimeout:bool = True): - ''' + def __init__(self, static_folder: str = "static/", deploy_dir="C:/", sink: WtMonSink = None, + notifyTimeout: bool = True): + """ WtMonSvr构造函数 @static_folder 静态文件根目录 @static_url_path 静态文件访问路径 @deploy_dir 实盘部署目录 - ''' + """ self.logger = WtLogger(__name__, "WtMonSvr.log") self._sink_ = sink @@ -341,10 +349,10 @@ def __init__(self, static_folder: str = "static/", deploy_dir="C:/", sink: WtMon script_dir = os.path.dirname(__file__) static_folder = os.path.join(script_dir, static_folder) - target_dir = os.path.join(static_folder,"console") + target_dir = os.path.join(static_folder, "console") app.mount("/console", StaticFiles(directory=target_dir), name="console") - target_dir = os.path.join(static_folder,"mobile") + target_dir = os.path.join(static_folder, "mobile") app.mount("/mobile", StaticFiles(directory=target_dir), name="mobile") self.app = app @@ -360,29 +368,29 @@ def __init__(self, static_folder: str = "static/", deploy_dir="C:/", sink: WtMon self.init_comm_apis(app) def enable_token(self, seckey: str = "WtMonSvr@2021"): - ''' + """ 启用访问令牌, 默认通过session方式验证 注意: 这里如果启用令牌访问的话, 需要安装pycryptodome, 所以改成单独控制 - ''' - + """ + self.__sec_key__ = seckey self.__token_enabled__ = True def set_bt_mon(self, btMon: WtBtMon): - ''' + """ 设置回测管理器 @btMon 回测管理器WtBtMon实例 - ''' + """ self.__bt_mon__ = btMon self.init_bt_apis(self.app) def set_dt_servo(self, dtServo: WtDtServo): - ''' + """ 设置DtServo @dtServo 本地数据伺服WtDtServo实例 - ''' + """ self.__dt_servo__ = dtServo def init_bt_apis(self, app: FastAPI): @@ -390,13 +398,13 @@ def init_bt_apis(self, app: FastAPI): # 拉取K线数据 @app.post("/bt/qrybars", tags=["回测管理接口"]) async def qry_bt_bars( - request: Request, - token: str = Body(None, title="访问令牌", embed=True), - code: str = Body(..., title="合约代码", embed=True), - period: str = Body(..., title="K线周期", embed=True), - stime: int = Body(None, title="开始时间", embed=True), - etime: int = Body(..., title="结束时间", embed=True), - count: int = Body(None, title="数据条数", embed=True) + request: Request, + token: str = Body(None, title="访问令牌", embed=True), + code: str = Body(..., title="合约代码", embed=True), + period: str = Body(..., title="K线周期", embed=True), + stime: int = Body(None, title="开始时间", embed=True), + etime: int = Body(..., title="结束时间", embed=True), + count: int = Body(None, title="数据条数", embed=True) ): bSucc, userInfo = check_auth(request, token, self.__sec_key__) if not bSucc: @@ -1010,13 +1018,13 @@ def cmd_run_stra_bt( def init_mgr_apis(self, app: FastAPI): - '''下面是API接口的编写''' + """下面是API接口的编写""" @app.post("/mgr/login", tags=["用户接口"]) async def cmd_login( - request: Request, - loginid: str = Body(..., title="用户名", embed=True), - passwd: str = Body(..., title="用户密码", embed=True) + request: Request, + loginid: str = Body(..., title="用户名", embed=True), + passwd: str = Body(..., title="用户密码", embed=True) ): if True: user = loginid @@ -1138,7 +1146,7 @@ async def cmd_add_group( name: str = Body('', title="组合名称", embed=True), path: str = Body('', title="组合路径", embed=True), gtype: str = Body('', title="组合类型", embed=True), - info: str = Body('', title="组合信息",embed=True), + info: str = Body('', title="组合信息", embed=True), env: str = Body('', title="组合环境,实盘/回测", embed=True), datmod: str = Body('mannual', title="数据模式,mannal/auto", embed=True), mqurl: str = Body('', title="消息队列地址", embed=True), @@ -1198,7 +1206,8 @@ async def cmd_add_group( "result": -2, "message": "添加用户失败" } - except: + except Exception as e: + print(e) ret = { "result": -1, "message": "请求解析失败" @@ -1455,7 +1464,7 @@ async def qry_groups( return tokenInfo products = tokenInfo["products"] - + try: groups = self.__data_mgr__.get_groups() rets = list() @@ -1469,7 +1478,8 @@ async def qry_groups( "message": "Ok", "groups": rets } - except: + except Exception as e: + print(e) ret = { "result": -1, "message": "请求解析失败" @@ -1565,7 +1575,8 @@ async def cmd_commit_group_file( "result": 0, "message": "Ok" } - except: + except Exception as e: + print(e) ret = { "result": -1, "message": "文件保存失败" @@ -1666,7 +1677,8 @@ async def qry_logs( "content": content, "lines": lines } - except: + except Exception as e: + print(e) ret = { "result": -1, "message": "请求解析失败" @@ -2410,19 +2422,20 @@ async def cmd_commit_group_filters( "result": 0, "message": "Ok" } - except: + except Exception as e: + print(e) ret = { "result": -1, "message": "过滤器保存失败" } return ret - + @app.get("/mgr/auth", tags=["令牌认证"]) @app.post("/mgr/auth") async def authority( - request: Request, - token: str = Body(None, title="访问令牌", embed=True) + request: Request, + token: str = Body(None, title="访问令牌", embed=True) ): bSucc, userInfo = check_auth(request, token, self.__sec_key__) if not bSucc: @@ -2438,7 +2451,7 @@ def init_comm_apis(self, app: FastAPI): @app.get("/console") async def console_entry(): return RedirectResponse("/console/index.html") - + @app.get("/mobile") async def mobile_entry(): return RedirectResponse("/mobile/index.html") @@ -2446,7 +2459,7 @@ async def mobile_entry(): @app.get("/favicon.ico") async def favicon_entry(): return FileResponse(os.path.join(self.static_folder, "favicon.ico")) - + @app.get("/hasbt") @app.post("/hasbt") async def check_btmon(): @@ -2510,7 +2523,7 @@ def on_notify(self, grpid: str, chnl: str, message: str): def on_timeout(self, grpid: str): if not self.notifyTimeout: return - + if self._sink_: grpInfo = self.__data_mgr__.get_group(grpid) self._sink_.notify("fatal", f'检测到 {grpInfo["name"]}[{grpid}]的MQ消息超时,请及时检查并处理!!!') diff --git a/wtpy/monitor/__init__.py b/wtpy/monitor/__init__.py index 8f117dde..51a8c022 100644 --- a/wtpy/monitor/__init__.py +++ b/wtpy/monitor/__init__.py @@ -1,7 +1,6 @@ - from .WtMonSvr import WtMonSvr, WtMonSink from .WtBtMon import WtBtMon from .WtLogger import WtLogger from .WtBtSnooper import WtBtSnooper -__all__ = ["WtMonSvr","WtBtMon","WtLogger", "WtMonSink", "WtBtSnooper"] \ No newline at end of file +__all__ = ["WtMonSvr", "WtBtMon", "WtLogger", "WtMonSink", "WtBtSnooper"] diff --git a/wtpy/wrapper/ContractLoader.py b/wtpy/wrapper/ContractLoader.py index c776ffe3..5aea159e 100644 --- a/wtpy/wrapper/ContractLoader.py +++ b/wtpy/wrapper/ContractLoader.py @@ -1,17 +1,21 @@ +import json +import os +from ctypes import cdll, c_char_p, c_bool +from enum import Enum + from .PlatformHelper import PlatformHelper as ph -import os, json -from ctypes import cdll,c_char_p, c_bool, c_bool -from enum import Enum + class LoaderType(Enum): - ''' + """ 引擎类型 枚举变量 - ''' - LT_CTP = 1 - LT_CTPOpt = 2 + """ + LT_CTP = 1 + LT_CTPOpt = 2 + -def getModuleName(lType:LoaderType)->str: +def getModuleName(lType: LoaderType) -> str: if lType == LoaderType.LT_CTP: filename = "CTPLoader" elif lType == LoaderType.LT_CTPOpt: @@ -19,7 +23,7 @@ def getModuleName(lType:LoaderType)->str: else: raise Exception('Invalid loader type') return - + paths = os.path.split(__file__) exename = ph.getModule(filename) a = (paths[:-1] + (exename,)) @@ -28,23 +32,23 @@ def getModuleName(lType:LoaderType)->str: class ContractLoader: - def __init__(self, lType:LoaderType = LoaderType.LT_CTP): + def __init__(self, lType: LoaderType = LoaderType.LT_CTP): print(getModuleName(lType)) self.api = cdll.LoadLibrary(getModuleName(lType)) - self.api.run.argtypes = [ c_char_p, c_bool, c_bool] + self.api.run.argtypes = [c_char_p, c_bool, c_bool] - def start(self, cfgfile:str = 'config.ini', bAsync:bool = False): - ''' + def start(self, cfgfile: str = 'config.ini', bAsync: bool = False): + """ 启动合约加载器 @cfgfile 配置文件名 @bAsync 是否异步,异步则立即返回,默认False - ''' - self.api.run(bytes(cfgfile, encoding = "utf8"), bAsync, True) + """ + self.api.run(bytes(cfgfile, encoding="utf8"), bAsync, True) - def start_with_config(self, config:dict, bAsync:bool = False): - ''' + def start_with_config(self, config: dict, bAsync: bool = False): + """ 启动合约加载器 @cfgfile 配置文件名 @bAsync 是否异步,异步则立即返回,默认False - ''' - self.api.run(bytes(json.dumps(config), encoding = "utf8"), bAsync, False) \ No newline at end of file + """ + self.api.run(bytes(json.dumps(config), encoding="utf8"), bAsync, False) diff --git a/wtpy/wrapper/PlatformHelper.py b/wtpy/wrapper/PlatformHelper.py index 0fe21fb4..31227b3d 100644 --- a/wtpy/wrapper/PlatformHelper.py +++ b/wtpy/wrapper/PlatformHelper.py @@ -1,11 +1,12 @@ import platform + class PlatformHelper: @staticmethod def isPythonX64() -> bool: ret = platform.architecture() - return (ret[0] == "64bit") + return ret[0] == "64bit" @staticmethod def isWindows() -> bool: @@ -15,17 +16,17 @@ def isWindows() -> bool: return False @staticmethod - def getModule(moduleName:str, subdir:str="") -> str: + def getModule(moduleName: str, subdir: str = "") -> str: dllname = "" ext = "" prefix = "" - if PlatformHelper.isWindows(): #windows平台 + if PlatformHelper.isWindows(): # windows平台 ext = ".dll" if PlatformHelper.isPythonX64(): dllname = "x64/" else: dllname = "x86/" - else:#Linux平台 + else: # Linux平台 dllname = "linux/" prefix = "lib" ext = ".so" @@ -35,11 +36,10 @@ def getModule(moduleName:str, subdir:str="") -> str: dllname += prefix + moduleName + ext return dllname - + @staticmethod - def auto_encode(s:str) -> bytes: + def auto_encode(s: str) -> bytes: if PlatformHelper.isWindows(): - return bytes(s, encoding = "utf-8").decode('utf-8').encode('gbk') + return bytes(s, encoding="utf-8").decode('utf-8').encode('gbk') else: - return bytes(s, encoding = "utf-8") - \ No newline at end of file + return bytes(s, encoding="utf-8") diff --git a/wtpy/wrapper/TraderDumper.py b/wtpy/wrapper/TraderDumper.py index 6cba9318..f915c296 100644 --- a/wtpy/wrapper/TraderDumper.py +++ b/wtpy/wrapper/TraderDumper.py @@ -1,56 +1,61 @@ from .PlatformHelper import PlatformHelper as ph -from ctypes import cdll,CFUNCTYPE,c_void_p,c_char_p,c_uint32,c_double,c_bool,c_uint64 +from ctypes import cdll, CFUNCTYPE, c_void_p, c_char_p, c_uint32, c_double, c_bool, c_uint64 import os import chardet import yaml import json -CB_ACCOUNT = CFUNCTYPE(c_void_p, c_char_p, c_uint32, c_char_p, c_double, c_double, c_double, c_double, - c_double, c_double, c_double, c_double, c_double, c_double, c_bool) -CB_ORDER = CFUNCTYPE(c_void_p, c_char_p, c_char_p, c_char_p, c_uint32, c_char_p, c_uint32, c_uint32, - c_double, c_double, c_double, c_double, c_uint32, c_uint32, c_uint64, c_uint32, c_char_p, c_bool) -CB_TRADE = CFUNCTYPE(c_void_p, c_char_p, c_char_p, c_char_p, c_uint32, c_char_p, c_char_p, c_uint32, - c_uint32, c_double, c_double, c_double, c_uint32, c_uint32, c_uint64, c_bool) -CB_POSITION = CFUNCTYPE(c_void_p, c_char_p, c_char_p, c_char_p, c_uint32, c_uint32, c_double, c_double, c_double, - c_double, c_double, c_double, c_uint32, c_bool) +CB_ACCOUNT = CFUNCTYPE(c_void_p, c_char_p, c_uint32, c_char_p, c_double, c_double, c_double, c_double, + c_double, c_double, c_double, c_double, c_double, c_double, c_bool) +CB_ORDER = CFUNCTYPE(c_void_p, c_char_p, c_char_p, c_char_p, c_uint32, c_char_p, c_uint32, c_uint32, + c_double, c_double, c_double, c_double, c_uint32, c_uint32, c_uint64, c_uint32, c_char_p, c_bool) +CB_TRADE = CFUNCTYPE(c_void_p, c_char_p, c_char_p, c_char_p, c_uint32, c_char_p, c_char_p, c_uint32, + c_uint32, c_double, c_double, c_double, c_uint32, c_uint32, c_uint64, c_bool) +CB_POSITION = CFUNCTYPE(c_void_p, c_char_p, c_char_p, c_char_p, c_uint32, c_uint32, c_double, c_double, c_double, + c_double, c_double, c_double, c_uint32, c_bool) + class DumperSink: - def on_account(self, channelid, curTDate:int, currency, prebalance:float, balance:float, dynbalance:float, - closeprofit:float, dynprofit:float, fee:float, margin:float, deposit:float, withdraw:float, isLast:bool): + def on_account(self, channelid, curTDate: int, currency, prebalance: float, balance: float, dynbalance: float, + closeprofit: float, dynprofit: float, fee: float, margin: float, deposit: float, withdraw: float, + isLast: bool): pass - def on_order(self, channelid, exchg, code, curTDate:int, orderid, direct:int, offset:int, - volume:float, leftover:float, traded:float, price:float, ordertype:int, pricetype:int, ordertime:int, state:int, statemsg, isLast:bool): + def on_order(self, channelid, exchg, code, curTDate: int, orderid, direct: int, offset: int, + volume: float, leftover: float, traded: float, price: float, ordertype: int, pricetype: int, + ordertime: int, state: int, statemsg, isLast: bool): pass - def on_trade(self, channelid, exchg, code, curTDate:int, tradeid, orderid, direct:int, - offset:int, volume:float, price:float, amount:float, ordertype:int, tradetype:int, tradetime:int, isLast:bool): + def on_trade(self, channelid, exchg, code, curTDate: int, tradeid, orderid, direct: int, + offset: int, volume: float, price: float, amount: float, ordertype: int, tradetype: int, + tradetime: int, isLast: bool): pass - def on_position(self, channelid, exchg, code, curTDate:int, direct:int, volume:float, newvol:float, - cost:float, margin:float, avgpx:float, dynprofit:float, volscale:int, isLast:bool): + def on_position(self, channelid, exchg, code, curTDate: int, direct: int, volume: float, newvol: float, + cost: float, margin: float, avgpx: float, dynprofit: float, volscale: int, isLast: bool): pass + class TraderDumper: - def __init__(self, sink:DumperSink, logCfg:str = 'logCfg.yaml'): + def __init__(self, sink: DumperSink, logCfg: str = 'logCfg.yaml'): paths = os.path.split(__file__) dllname = ph.getModule("TraderDumper") a = (paths[:-1] + (dllname,)) _path = os.path.join(*a) self.api = cdll.LoadLibrary(_path) - self.sink:DumperSink = sink + self.sink: DumperSink = sink self.__config__ = None - self.api.init(bytes(logCfg, encoding = "utf8")) + self.api.init(bytes(logCfg, encoding="utf8")) - #注册回调函数 - self.cb_account = CB_ACCOUNT(self.sink.on_account) - self.cb_order = CB_ORDER(self.sink.on_order) - self.cb_trade = CB_TRADE(self.sink.on_trade) - self.cb_position = CB_POSITION(self.sink.on_position) + # 注册回调函数 + self.cb_account = CB_ACCOUNT(self.sink.on_account) + self.cb_order = CB_ORDER(self.sink.on_order) + self.cb_trade = CB_TRADE(self.sink.on_trade) + self.cb_position = CB_POSITION(self.sink.on_position) self.api.register_callbacks(self.cb_account, self.cb_order, self.cb_trade, self.cb_position) def __check_config__(self): @@ -66,15 +71,15 @@ def __check_config__(self): def clear_traders(self): self.__config__['traders'] = [] - def add_trader(self, params:dict): + def add_trader(self, params: dict): self.__config__['traders'].append(params) - def init(self, folder:str, - cfgfile:str = 'config.yaml', - commfile:str= None, - contractfile:str = None, - sessionfile:str = None): - + def init(self, folder: str, + cfgfile: str = 'config.yaml', + commfile: str = None, + contractfile: str = None, + sessionfile: str = None): + if os.path.exists(cfgfile): f = open(cfgfile, "rb") content = f.read() @@ -93,7 +98,7 @@ def init(self, folder:str, if contractfile is not None: self.__config__["basefiles"]["contract"] = folder + contractfile - + if sessionfile is not None: self.__config__["basefiles"]["session"] = folder + sessionfile @@ -105,11 +110,11 @@ def __commit__(self): f = open("config.json", "w") f.write(content) f.close() - self.api.config(bytes(content, encoding = "utf8"), False) + self.api.config(bytes(content, encoding="utf8"), False) - def run(self, bOnce:bool = False): + def run(self, bOnce: bool = False): self.__commit__() self.api.run(bOnce) def release(self): - self.api.release() \ No newline at end of file + self.api.release() diff --git a/wtpy/wrapper/WtBtWrapper.py b/wtpy/wrapper/WtBtWrapper.py index 6de345d8..90e29730 100644 --- a/wtpy/wrapper/WtBtWrapper.py +++ b/wtpy/wrapper/WtBtWrapper.py @@ -1,29 +1,33 @@ from ctypes import c_uint32, cdll, c_char_p, c_bool, c_ulong, c_uint64, c_double, c_int, POINTER -from wtpy.WtCoreDefs import CB_STRATEGY_INIT, CB_STRATEGY_TICK, CB_STRATEGY_CALC, CB_STRATEGY_BAR, CB_STRATEGY_GET_BAR, CB_STRATEGY_GET_TICK, CB_STRATEGY_GET_POSITION, CB_STRATEGY_COND_TRIGGER +from wtpy.WtCoreDefs import CB_STRATEGY_INIT, CB_STRATEGY_TICK, CB_STRATEGY_CALC, CB_STRATEGY_BAR, CB_STRATEGY_GET_BAR, \ + CB_STRATEGY_GET_TICK, CB_STRATEGY_GET_POSITION, CB_STRATEGY_COND_TRIGGER from wtpy.WtCoreDefs import CB_HFTSTRA_CHNL_EVT, CB_HFTSTRA_ENTRUST, CB_HFTSTRA_ORD, CB_HFTSTRA_TRD, CB_SESSION_EVENT -from wtpy.WtCoreDefs import CB_HFTSTRA_ORDQUE, CB_HFTSTRA_ORDDTL, CB_HFTSTRA_TRANS, CB_HFTSTRA_GET_ORDQUE, CB_HFTSTRA_GET_ORDDTL, CB_HFTSTRA_GET_TRANS +from wtpy.WtCoreDefs import CB_HFTSTRA_ORDQUE, CB_HFTSTRA_ORDDTL, CB_HFTSTRA_TRANS, CB_HFTSTRA_GET_ORDQUE, \ + CB_HFTSTRA_GET_ORDDTL, CB_HFTSTRA_GET_TRANS from wtpy.WtCoreDefs import CHNL_EVENT_READY, CHNL_EVENT_LOST, CB_ENGINE_EVENT from wtpy.WtCoreDefs import FUNC_LOAD_HISBARS, FUNC_LOAD_HISTICKS, FUNC_LOAD_ADJFACTS -from wtpy.WtCoreDefs import EVENT_ENGINE_INIT, EVENT_SESSION_BEGIN, EVENT_SESSION_END, EVENT_ENGINE_SCHDL, EVENT_BACKTEST_END +from wtpy.WtCoreDefs import EVENT_ENGINE_INIT, EVENT_SESSION_BEGIN, EVENT_SESSION_END, EVENT_ENGINE_SCHDL, \ + EVENT_BACKTEST_END from wtpy.WtCoreDefs import WTSTickStruct, WTSBarStruct, WTSOrdQueStruct, WTSOrdDtlStruct, WTSTransStruct from .PlatformHelper import PlatformHelper as ph from wtpy.WtUtilDefs import singleton from wtpy.WtDataDefs import WtNpKline, WtNpOrdDetails, WtNpOrdQueues, WtNpTicks, WtNpTransactions import os + # Python对接C接口的库 @singleton class WtBtWrapper: - ''' + """ Wt平台C接口底层对接模块 - ''' + """ # api可以作为公共变量 api = None ver = "Unknown" _engine = None - + # 构造函数, 传入动态库名 def __init__(self, engine): self._engine = engine @@ -32,7 +36,7 @@ def __init__(self, engine): a = (paths[:-1] + (dllname,)) _path = os.path.join(*a) self.api = cdll.LoadLibrary(_path) - + self.api.get_version.restype = c_char_p self.api.cta_get_last_entertime.restype = c_uint64 self.api.cta_get_first_entertime.restype = c_uint64 @@ -85,7 +89,7 @@ def __init__(self, engine): self.api.hft_get_position_avgpx.restype = c_double self.api.hft_get_undone.restype = c_double self.api.hft_get_price.restype = c_double - + self.api.hft_buy.restype = c_char_p self.api.hft_buy.argtypes = [c_ulong, c_char_p, c_double, c_double, c_char_p, c_int] self.api.hft_sell.restype = c_char_p @@ -104,7 +108,7 @@ def __init__(self, engine): self.api.get_raw_stdcode.restype = c_char_p - def on_engine_event(self, evtid:int, evtDate:int, evtTime:int): + def on_engine_event(self, evtid: int, evtDate: int, evtTime: int): engine = self._engine if evtid == EVENT_ENGINE_INIT: engine.on_init() @@ -118,14 +122,14 @@ def on_engine_event(self, evtid:int, evtDate:int, evtTime:int): engine.on_backtest_end() return - def on_stra_init(self, id:int): + def on_stra_init(self, id: int): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_init() return - def on_session_event(self, id:int, udate:int, isBegin:bool): + def on_session_event(self, id: int, udate: int, isBegin: bool): engine = self._engine ctx = engine.get_context(id) if ctx is not None: @@ -135,48 +139,49 @@ def on_session_event(self, id:int, udate:int, isBegin:bool): ctx.on_session_end(udate) return - def on_stra_calc(self, id:int, curDate:int, curTime:int): + def on_stra_calc(self, id: int, curDate: int, curTime: int): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_calculate() return - def on_stra_calc_done(self, id:int, curDate:int, curTime:int): + def on_stra_calc_done(self, id: int, curDate: int, curTime: int): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_calculate_done() return - def on_stra_tick(self, id:int, stdCode:str, newTick:POINTER(WTSTickStruct)): + def on_stra_tick(self, id: int, stdCode: str, newTick: POINTER(WTSTickStruct)): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_tick(bytes.decode(stdCode), newTick) return - - def on_stra_bar(self, id:int, stdCode:str, period:str, newBar:POINTER(WTSBarStruct)): + + def on_stra_bar(self, id: int, stdCode: str, period: str, newBar: POINTER(WTSBarStruct)): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_bar(bytes.decode(stdCode), bytes.decode(period), newBar) return - def on_stra_get_bar(self, id:int, stdCode:str, period:str, curBar:POINTER(WTSBarStruct), count:int, isLast:bool): - ''' + def on_stra_get_bar(self, id: int, stdCode: str, period: str, curBar: POINTER(WTSBarStruct), count: int, + isLast: bool): + """ 获取K线回调, 该回调函数因为是python主动发起的, 需要同步执行, 所以不走事件推送 @id 策略id @stdCode 合约代码 @period K线周期 @curBar 最新一条K线 @isLast 是否是最后一条 - ''' + """ engine = self._engine ctx = engine.get_context(id) period = bytes.decode(period) - isDay = period[0]=='d' + isDay = period[0] == 'd' npBars = WtNpKline(isDay) npBars.set_data(curBar, count) @@ -184,14 +189,14 @@ def on_stra_get_bar(self, id:int, stdCode:str, period:str, curBar:POINTER(WTSBar if ctx is not None: ctx.on_getbars(bytes.decode(stdCode), period, npBars) - def on_stra_get_tick(self, id:int, stdCode:str, curTick:POINTER(WTSTickStruct), count:int, isLast:bool): - ''' + def on_stra_get_tick(self, id: int, stdCode: str, curTick: POINTER(WTSTickStruct), count: int, isLast: bool): + """ 获取Tick回调, 该回调函数因为是python主动发起的, 需要同步执行, 所以不走事件推送 @id 策略id @stdCode 合约代码 @curTick 最新一笔Tick @isLast 是否是最后一条 - ''' + """ engine = self._engine ctx = engine.get_context(id) @@ -202,19 +207,19 @@ def on_stra_get_tick(self, id:int, stdCode:str, curTick:POINTER(WTSTickStruct), ctx.on_getticks(bytes.decode(stdCode), npTicks) return - def on_stra_get_position(self, id:int, stdCode:str, qty:float, isLast:bool): + def on_stra_get_position(self, id: int, stdCode: str, qty: float, isLast: bool): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_getpositions(bytes.decode(stdCode), qty, isLast) - def on_stra_cond_triggerd(self, id:int, stdCode:str, target:float, price:float, usertag:str): + def on_stra_cond_triggerd(self, id: int, stdCode: str, target: float, price: float, usertag: str): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_condition_triggered(bytes.decode(stdCode), target, price, bytes.decode(usertag)) - def on_hftstra_channel_evt(self, id:int, trader:str, evtid:int): + def on_hftstra_channel_evt(self, id: int, trader: str, evtid: int): engine = self._engine ctx = engine.get_context(id) if ctx is None: @@ -224,23 +229,25 @@ def on_hftstra_channel_evt(self, id:int, trader:str, evtid:int): elif evtid == CHNL_EVENT_LOST: ctx.on_channel_lost() - def on_hftstra_order(self, id:int, localid:int, stdCode:str, isBuy:bool, totalQty:float, leftQty:float, price:float, isCanceled:bool, userTag:str): + def on_hftstra_order(self, id: int, localid: int, stdCode: str, isBuy: bool, totalQty: float, leftQty: float, + price: float, isCanceled: bool, userTag: str): stdCode = bytes.decode(stdCode) - userTag = bytes.decode(userTag,"gbk") + userTag = bytes.decode(userTag, "gbk") engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_order(localid, stdCode, isBuy, totalQty, leftQty, price, isCanceled, userTag) - def on_hftstra_trade(self, id:int, localid:int, stdCode:str, isBuy:bool, qty:float, price:float, userTag:str): + def on_hftstra_trade(self, id: int, localid: int, stdCode: str, isBuy: bool, qty: float, price: float, + userTag: str): stdCode = bytes.decode(stdCode) - userTag = bytes.decode(userTag,"gbk") + userTag = bytes.decode(userTag, "gbk") engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_trade(localid, stdCode, isBuy, qty, price, userTag) - def on_hftstra_entrust(self, id:int, localid:int, stdCode:str, bSucc:bool, message:str, userTag:str): + def on_hftstra_entrust(self, id: int, localid: int, stdCode: str, bSucc: bool, message: str, userTag: str): stdCode = bytes.decode(stdCode) message = bytes.decode(message, "gbk") userTag = bytes.decode(userTag, "gbk") @@ -249,61 +256,64 @@ def on_hftstra_entrust(self, id:int, localid:int, stdCode:str, bSucc:bool, messa if ctx is not None: ctx.on_entrust(localid, stdCode, bSucc, message, userTag) - def on_hftstra_order_queue(self, id:int, stdCode:str, newOrdQue:POINTER(WTSOrdQueStruct)): + def on_hftstra_order_queue(self, id: int, stdCode: str, newOrdQue: POINTER(WTSOrdQueStruct)): stdCode = bytes.decode(stdCode) engine = self._engine ctx = engine.get_context(id) - + if ctx is not None: ctx.on_order_queue(stdCode, newOrdQue.to_tuple()) - def on_hftstra_get_order_queue(self, id:int, stdCode:str, newOrdQue:POINTER(WTSOrdQueStruct), count:int, isLast:bool): + def on_hftstra_get_order_queue(self, id: int, stdCode: str, newOrdQue: POINTER(WTSOrdQueStruct), count: int, + isLast: bool): engine = self._engine ctx = engine.get_context(id) - + npHftData = WtNpOrdQueues(forceCopy=False) npHftData.set_data(newOrdQue, count) if ctx is not None: ctx.on_get_order_queue(bytes.decode(stdCode), npHftData) - def on_hftstra_order_detail(self, id:int, stdCode:str, newOrdDtl:POINTER(WTSOrdDtlStruct)): + def on_hftstra_order_detail(self, id: int, stdCode: str, newOrdDtl: POINTER(WTSOrdDtlStruct)): stdCode = bytes.decode(stdCode) engine = self._engine ctx = engine.get_context(id) - + if ctx is not None: ctx.on_order_detail(stdCode, newOrdDtl.to_tuple()) - def on_hftstra_get_order_detail(self, id:int, stdCode:str, newOrdDtl:POINTER(WTSOrdDtlStruct), count:int, isLast:bool): + def on_hftstra_get_order_detail(self, id: int, stdCode: str, newOrdDtl: POINTER(WTSOrdDtlStruct), count: int, + isLast: bool): engine = self._engine ctx = engine.get_context(id) - + npHftData = WtNpOrdDetails(forceCopy=False) npHftData.set_data(newOrdDtl, count) - + if ctx is not None: ctx.on_get_order_detail(bytes.decode(stdCode), npHftData) - def on_hftstra_transaction(self, id:int, stdCode:str, newTrans:POINTER(WTSTransStruct)): + def on_hftstra_transaction(self, id: int, stdCode: str, newTrans: POINTER(WTSTransStruct)): stdCode = bytes.decode(stdCode) engine = self._engine ctx = engine.get_context(id) - + if ctx is not None: ctx.on_transaction(stdCode, newTrans.to_tuple()) - - def on_hftstra_get_transaction(self, id:int, stdCode:str, newTrans:POINTER(WTSTransStruct), count:int, isLast:bool): + + def on_hftstra_get_transaction(self, id: int, stdCode: str, newTrans: POINTER(WTSTransStruct), count: int, + isLast: bool): engine = self._engine ctx = engine.get_context(id) npHftData = WtNpTransactions(forceCopy=False) npHftData.set_data(newTrans, count) - + if ctx is not None: ctx.on_get_transaction(bytes.decode(stdCode), npHftData) - def on_load_fnl_his_bars(self, stdCode:str, period:str) -> bool: + def on_load_fnl_his_bars(self, stdCode: str, period: str) -> bool: engine = self._engine loader = engine.get_extended_data_loader() if loader is None: @@ -311,7 +321,7 @@ def on_load_fnl_his_bars(self, stdCode:str, period:str) -> bool: return loader.load_final_his_bars(bytes.decode(stdCode), bytes.decode(period), self.api.feed_raw_bars) - def on_load_raw_his_bars(self, stdCode:str, period:str) -> bool: + def on_load_raw_his_bars(self, stdCode: str, period: str) -> bool: engine = self._engine loader = engine.get_extended_data_loader() if loader is None: @@ -319,15 +329,15 @@ def on_load_raw_his_bars(self, stdCode:str, period:str) -> bool: return loader.load_raw_his_bars(bytes.decode(stdCode), bytes.decode(period), self.api.feed_raw_bars) - def feed_adj_factors(self, stdCode:str, dates:list, factors:list): + def feed_adj_factors(self, stdCode: str, dates: list, factors: list): stdCode = bytes(stdCode, encoding="utf8") - ''' + """ TODO 这里类型要转一下! 底层接口是传数组的 feed_adj_factors(WtString stdCode, WtUInt32* dates, double* factors, WtUInt32 count) - ''' + """ self.api.feed_adj_factors(stdCode, dates, factors, len(dates)) - def on_load_adj_factors(self, stdCode:str) -> bool: + def on_load_adj_factors(self, stdCode: str) -> bool: engine = self._engine loader = engine.get_extended_data_loader() if loader is None: @@ -336,35 +346,35 @@ def on_load_adj_factors(self, stdCode:str) -> bool: stdCode = bytes.decode(stdCode) return loader.load_adj_factors(stdCode, self.feed_adj_factors) - def on_load_his_ticks(self, stdCode:str, uDate:int) -> bool: + def on_load_his_ticks(self, stdCode: str, uDate: int) -> bool: engine = self._engine loader = engine.get_extended_data_loader() if loader is None: return False - + # feed_raw_ticks(WTSTickStruct* ticks, WtUInt32 count); return loader.load_his_ticks(bytes.decode(stdCode), uDate, self.api.feed_raw_ticks) - def write_log(self, level, message:str, catName:str = ""): - self.api.write_log(level, ph.auto_encode(message), bytes(catName, encoding = "utf8")) + def write_log(self, level, message: str, catName: str = ""): + self.api.write_log(level, ph.auto_encode(message), bytes(catName, encoding="utf8")) - def set_time_range(self, beginTime:int, endTime:int): - ''' + def set_time_range(self, beginTime: int, endTime: int): + """ 设置回测时间区间 @beginTime 开始时间, 格式如yyyymmddHHMM @endTime 结束时间, 格式如yyyymmddHHMM - ''' + """ self.api.set_time_range(beginTime, endTime) - def enable_tick(self, bEnabled:bool = True): - ''' + def enable_tick(self, bEnabled: bool = True): + """ 启用tick回测 @bEnabled 是否启用 - ''' + """ self.api.enable_tick(bEnabled) - ### 实盘和回测有差异 ### - def run_backtest(self, bNeedDump:bool = False, bAsync:bool = False): + # ## 实盘和回测有差异 ### + def run_backtest(self, bNeedDump: bool = False, bAsync: bool = False): self.api.run_backtest(bNeedDump, bAsync) def stop_backtest(self): @@ -376,17 +386,18 @@ def release_backtest(self): def clear_cache(self): self.api.clear_cache() - def get_raw_stdcode(self, stdCode:str): - return bytes.decode(self.api.get_raw_stdcode(bytes(stdCode, encoding = "utf8"))) + def get_raw_stdcode(self, stdCode: str): + return bytes.decode(self.api.get_raw_stdcode(bytes(stdCode, encoding="utf8"))) + + def config_backtest(self, cfgfile: str = 'config.yaml', isFile: bool = True): + self.api.config_backtest(bytes(cfgfile, encoding="utf8"), isFile) - def config_backtest(self, cfgfile:str = 'config.yaml', isFile:bool = True): - self.api.config_backtest(bytes(cfgfile, encoding = "utf8"), isFile) - ### 实盘和回测有差异 ### + # ## 实盘和回测有差异 ### - def initialize_cta(self, logCfg:str = "logcfgbt.yaml", isFile:bool = True, outDir:str = "./outputs_bt"): - ''' + def initialize_cta(self, logCfg: str = "logcfgbt.yaml", isFile: bool = True, outDir: str = "./outputs_bt"): + """ C接口初始化 - ''' + """ self.cb_stra_init = CB_STRATEGY_INIT(self.on_stra_init) self.cb_stra_tick = CB_STRATEGY_TICK(self.on_stra_tick) self.cb_stra_calc = CB_STRATEGY_CALC(self.on_stra_calc) @@ -398,18 +409,19 @@ def initialize_cta(self, logCfg:str = "logcfgbt.yaml", isFile:bool = True, outDi self.cb_engine_event = CB_ENGINE_EVENT(self.on_engine_event) try: self.api.register_evt_callback(self.cb_engine_event) - self.api.register_cta_callbacks(self.cb_stra_init, self.cb_stra_tick, - self.cb_stra_calc, self.cb_stra_bar, self.cb_session_event, self.cb_stra_calc_done, self.cb_stra_cond_trigger) - self.api.init_backtest(bytes(logCfg, encoding = "utf8"), isFile, bytes(outDir, encoding = "utf8")) + self.api.register_cta_callbacks(self.cb_stra_init, self.cb_stra_tick, + self.cb_stra_calc, self.cb_stra_bar, self.cb_session_event, + self.cb_stra_calc_done, self.cb_stra_cond_trigger) + self.api.init_backtest(bytes(logCfg, encoding="utf8"), isFile, bytes(outDir, encoding="utf8")) except OSError as oe: print(oe) - self.write_log(102, "WonderTrader CTA backtest framework initialzied, version: %s" % (self.ver)) + self.write_log(102, "WonderTrader CTA backtest framework initialzied, version: %s" % self.ver) - def initialize_hft(self, logCfg:str = "logcfgbt.yaml", isFile:bool = True, outDir:str = "./outputs_bt"): - ''' + def initialize_hft(self, logCfg: str = "logcfgbt.yaml", isFile: bool = True, outDir: str = "./outputs_bt"): + """ C接口初始化 - ''' + """ self.cb_stra_init = CB_STRATEGY_INIT(self.on_stra_init) self.cb_stra_tick = CB_STRATEGY_TICK(self.on_stra_tick) self.cb_stra_bar = CB_STRATEGY_BAR(self.on_stra_bar) @@ -427,20 +439,21 @@ def initialize_hft(self, logCfg:str = "logcfgbt.yaml", isFile:bool = True, outDi try: self.api.register_evt_callback(self.cb_engine_event) - self.api.register_hft_callbacks(self.cb_stra_init, self.cb_stra_tick, self.cb_stra_bar, - self.cb_hftstra_channel_evt, self.cb_hftstra_order, self.cb_hftstra_trade, - self.cb_hftstra_entrust, self.cb_hftstra_order_detail, self.cb_hftstra_order_queue, - self.cb_hftstra_transaction, self.cb_session_event) - self.api.init_backtest(bytes(logCfg, encoding = "utf8"), isFile, bytes(outDir, encoding = "utf8")) + self.api.register_hft_callbacks(self.cb_stra_init, self.cb_stra_tick, self.cb_stra_bar, + self.cb_hftstra_channel_evt, self.cb_hftstra_order, self.cb_hftstra_trade, + self.cb_hftstra_entrust, self.cb_hftstra_order_detail, + self.cb_hftstra_order_queue, + self.cb_hftstra_transaction, self.cb_session_event) + self.api.init_backtest(bytes(logCfg, encoding="utf8"), isFile, bytes(outDir, encoding="utf8")) except OSError as oe: print(oe) - self.write_log(102, "WonderTrader HFT backtest framework initialzied, version: %s" % (self.ver)) + self.write_log(102, "WonderTrader HFT backtest framework initialzied, version: %s" % self.ver) - def initialize_sel(self, logCfg:str = "logcfgbt.yaml", isFile:bool = True, outDir:str = "./outputs_bt"): - ''' + def initialize_sel(self, logCfg: str = "logcfgbt.yaml", isFile: bool = True, outDir: str = "./outputs_bt"): + """ C接口初始化 - ''' + """ self.cb_stra_init = CB_STRATEGY_INIT(self.on_stra_init) self.cb_stra_tick = CB_STRATEGY_TICK(self.on_stra_tick) self.cb_stra_calc = CB_STRATEGY_CALC(self.on_stra_calc) @@ -452,773 +465,801 @@ def initialize_sel(self, logCfg:str = "logcfgbt.yaml", isFile:bool = True, outDi try: self.api.register_evt_callback(self.cb_engine_event) - self.api.register_sel_callbacks(self.cb_stra_init, self.cb_stra_tick, - self.cb_stra_calc, self.cb_stra_bar, self.cb_session_event, self.cb_stra_calc_done) - self.api.init_backtest(bytes(logCfg, encoding = "utf8"), isFile, bytes(outDir, encoding = "utf8")) + self.api.register_sel_callbacks(self.cb_stra_init, self.cb_stra_tick, + self.cb_stra_calc, self.cb_stra_bar, self.cb_session_event, + self.cb_stra_calc_done) + self.api.init_backtest(bytes(logCfg, encoding="utf8"), isFile, bytes(outDir, encoding="utf8")) except OSError as oe: print(oe) - self.write_log(102, "WonderTrader SEL backtest framework initialzied, version: %s" % (self.ver)) + self.write_log(102, "WonderTrader SEL backtest framework initialzied, version: %s" % self.ver) - def register_extended_data_loader(self, bAutoTrans:bool = True): - ''' + def register_extended_data_loader(self, bAutoTrans: bool = True): + """ 注册扩展历史数据加载器 @bAutoTrans 是否自动转储 - ''' + """ self.cb_load_fnlbars = FUNC_LOAD_HISBARS(self.on_load_fnl_his_bars) self.cb_load_rawbars = FUNC_LOAD_HISBARS(self.on_load_raw_his_bars) self.cb_load_histicks = FUNC_LOAD_HISTICKS(self.on_load_his_ticks) self.cb_load_adjfacts = FUNC_LOAD_ADJFACTS(self.on_load_adj_factors) - self.api.register_ext_data_loader(self.cb_load_fnlbars, self.cb_load_rawbars, self.cb_load_adjfacts, self.cb_load_histicks, bAutoTrans) + self.api.register_ext_data_loader(self.cb_load_fnlbars, self.cb_load_rawbars, self.cb_load_adjfacts, + self.cb_load_histicks, bAutoTrans) - def cta_enter_long(self, id:int, stdCode:str, qty:float, usertag:str, limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_enter_long(self, id: int, stdCode: str, qty: float, usertag: str, limitprice: float = 0.0, + stopprice: float = 0.0): + """ 开多 @id 策略id @stdCode 合约代码 @qty 手数, 大于等于0 - ''' - self.api.cta_enter_long(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) + """ + self.api.cta_enter_long(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) - def cta_exit_long(self, id:int, stdCode:str, qty:float, usertag:str, limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_exit_long(self, id: int, stdCode: str, qty: float, usertag: str, limitprice: float = 0.0, + stopprice: float = 0.0): + """ 平多 @id 策略id @stdCode 合约代码 @qty 手数, 大于等于0 - ''' - self.api.cta_exit_long(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) + """ + self.api.cta_exit_long(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) - def cta_enter_short(self, id:int, stdCode:str, qty:float, usertag:str, limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_enter_short(self, id: int, stdCode: str, qty: float, usertag: str, limitprice: float = 0.0, + stopprice: float = 0.0): + """ 开空 @id 策略id @stdCode 合约代码 @qty 手数, 大于等于0 - ''' - self.api.cta_enter_short(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) + """ + self.api.cta_enter_short(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) - def cta_exit_short(self, id:int, stdCode:str, qty:float, usertag:str, limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_exit_short(self, id: int, stdCode: str, qty: float, usertag: str, limitprice: float = 0.0, + stopprice: float = 0.0): + """ 平空 @id 策略id @stdCode 合约代码 @qty 手数, 大于等于0 - ''' - self.api.cta_exit_short(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) - - def cta_get_bars(self, id:int, stdCode:str, period:str, count:int, isMain:bool): - ''' + """ + self.api.cta_exit_short(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) + + def cta_get_bars(self, id: int, stdCode: str, period: str, count: int, isMain: bool): + """ 读取K线 @id 策略id @stdCode 合约代码 @period 周期, 如m1/m3/d1等 @count 条数 @isMain 是否主K线 - ''' - return self.api.cta_get_bars(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8"), count, isMain, CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) - - def cta_get_ticks(self, id:int, stdCode:str, count:int): - ''' + """ + return self.api.cta_get_bars(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8"), count, isMain, + CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) + + def cta_get_ticks(self, id: int, stdCode: str, count: int): + """ 读取Tick @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.cta_get_ticks(id, bytes(stdCode, encoding = "utf8"), count, CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) + """ + return self.api.cta_get_ticks(id, bytes(stdCode, encoding="utf8"), count, + CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) - def cta_get_position_profit(self, id:int, stdCode:str): - ''' + def cta_get_position_profit(self, id: int, stdCode: str): + """ 获取浮动盈亏 @id 策略id @stdCode 合约代码 @return 指定合约的浮动盈亏 - ''' - return self.api.cta_get_position_profit(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_position_profit(id, bytes(stdCode, encoding="utf8")) - def cta_get_position_avgpx(self, id:int, stdCode:str): - ''' + def cta_get_position_avgpx(self, id: int, stdCode: str): + """ 获取持仓均价 @id 策略id @stdCode 合约代码 @return 指定合约的持仓均价 - ''' - return self.api.cta_get_position_avgpx(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_position_avgpx(id, bytes(stdCode, encoding="utf8")) - def cta_get_all_position(self, id:int): - ''' + def cta_get_all_position(self, id: int): + """ 获取全部持仓 @id 策略id - ''' + """ return self.api.cta_get_all_position(id, CB_STRATEGY_GET_POSITION(self.on_stra_get_position)) - - def cta_get_position(self, id:int, stdCode:str, bonlyvalid:bool = False, usertag:str = ""): - ''' + + def cta_get_position(self, id: int, stdCode: str, bonlyvalid: bool = False, usertag: str = ""): + """ 获取持仓 @id 策略id @stdCode 合约代码 @usertag 进场标记, 如果为空则获取该合约全部持仓 @return 指定合约的持仓手数, 正为多, 负为空 - ''' - return self.api.cta_get_position(id, bytes(stdCode, encoding = "utf8"), bonlyvalid, bytes(usertag, encoding = "utf8")) + """ + return self.api.cta_get_position(id, bytes(stdCode, encoding="utf8"), bonlyvalid, + bytes(usertag, encoding="utf8")) - def cta_get_fund_data(self, id:int, flag:int) -> float: - ''' + def cta_get_fund_data(self, id: int, flag: int) -> float: + """ 获取资金数据 @id 策略id @flag 0-动态权益, 1-总平仓盈亏, 2-总浮动盈亏, 3-总手续费 @return 资金数据 - ''' + """ return self.api.cta_get_fund_data(id, flag) - def cta_get_price(self, stdCode:str) -> float: - ''' + def cta_get_price(self, stdCode: str) -> float: + """ @stdCode 合约代码 @return 指定合约的最新价格 - ''' - return self.api.cta_get_price(bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_price(bytes(stdCode, encoding="utf8")) - def cta_get_day_price(self, stdCode:str, flag:int = 0) -> float: - ''' + def cta_get_day_price(self, stdCode: str, flag: int = 0) -> float: + """ 获取当日价格 @stdCode 合约代码 @flag 价格标记, 0-开盘价, 1-最高价, 2-最低价, 3-最新价 @return 指定合约的价格 - ''' - return self.api.cta_get_day_price(bytes(stdCode, encoding = "utf8"), flag) + """ + return self.api.cta_get_day_price(bytes(stdCode, encoding="utf8"), flag) - def cta_set_position(self, id:int, stdCode:str, qty:float, usertag:str = "", limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_set_position(self, id: int, stdCode: str, qty: float, usertag: str = "", limitprice: float = 0.0, + stopprice: float = 0.0): + """ 设置目标仓位 @id 策略id @stdCode 合约代码 @qty 目标仓位, 正为多, 负为空 - ''' - self.api.cta_set_position(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) + """ + self.api.cta_set_position(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) def cta_get_tdate(self) -> int: - ''' + """ 获取当前交易日 @return 当前交易日 - ''' + """ return self.api.cta_get_tdate() def cta_get_date(self) -> int: - ''' + """ 获取当前日期 @return 当前日期 - ''' + """ return self.api.cta_get_date() def cta_get_time(self) -> int: - ''' + """ 获取当前时间 @return 当前时间 - ''' + """ return self.api.cta_get_time() - def cta_get_first_entertime(self, id:int, stdCode:str) -> int: - ''' + def cta_get_first_entertime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的首次进场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.cta_get_first_entertime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_first_entertime(id, bytes(stdCode, encoding="utf8")) - def cta_get_last_entertime(self, id:int, stdCode:str) -> int: - ''' + def cta_get_last_entertime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的最后进场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.cta_get_last_entertime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_last_entertime(id, bytes(stdCode, encoding="utf8")) - def cta_get_last_entertag(self, id:int, stdCode:str) -> str: - ''' + def cta_get_last_entertag(self, id: int, stdCode: str) -> str: + """ 获取当前持仓的最后进场标记 @stdCode 合约代码 @return 进场标记 - ''' - return bytes.decode(self.api.cta_get_last_entertag(id, bytes(stdCode, encoding = "utf8"))) + """ + return bytes.decode(self.api.cta_get_last_entertag(id, bytes(stdCode, encoding="utf8"))) - def cta_get_last_exittime(self, id:int, stdCode:str) -> int: - ''' + def cta_get_last_exittime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的最后出场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.cta_get_last_exittime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_last_exittime(id, bytes(stdCode, encoding="utf8")) - def cta_log_text(self, id:int, level:int, message:str): - ''' + def cta_log_text(self, id: int, level: int, message: str): + """ 日志输出 @id 策略ID @level 日志级别 @message 日志内容 - ''' + """ self.api.cta_log_text(id, level, ph.auto_encode(message)) - def cta_get_detail_entertime(self, id:int, stdCode:str, usertag:str) -> int: - ''' + def cta_get_detail_entertime(self, id: int, stdCode: str, usertag: str) -> int: + """ 获取指定标记的持仓的进场时间 @id 策略id @stdCode 合约代码 @usertag 进场标记 @return 进场时间, 格式如201907260932 - ''' - return self.api.cta_get_detail_entertime(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8")) + """ + return self.api.cta_get_detail_entertime(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8")) - def cta_get_detail_cost(self, id:int, stdCode:str, usertag:str) -> float: - ''' + def cta_get_detail_cost(self, id: int, stdCode: str, usertag: str) -> float: + """ 获取指定标记的持仓的开仓价 @id 策略id @stdCode 合约代码 @usertag 进场标记 @return 开仓价 - ''' - return self.api.cta_get_detail_cost(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8")) + """ + return self.api.cta_get_detail_cost(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8")) - def cta_get_detail_profit(self, id:int, stdCode:str, usertag:str, flag:int): - ''' + def cta_get_detail_profit(self, id: int, stdCode: str, usertag: str, flag: int): + """ 获取指定标记的持仓的盈亏 @id 策略id @stdCode 合约代码 @usertag 进场标记 @flag 盈亏记号, 0-浮动盈亏, 1-最大浮盈, -1-最大亏损(负数), 2-最大浮盈价格, -2-最大浮亏价格 @return 盈亏 - ''' - return self.api.cta_get_detail_profit(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8"), flag) + """ + return self.api.cta_get_detail_profit(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8"), + flag) - def cta_save_user_data(self, id:int, key:str, val:str): - ''' + def cta_save_user_data(self, id: int, key: str, val: str): + """ 保存用户数据 @id 策略id @key 数据名 @val 数据值 - ''' - self.api.cta_save_userdata(id, bytes(key, encoding = "utf8"), bytes(val, encoding = "utf8")) + """ + self.api.cta_save_userdata(id, bytes(key, encoding="utf8"), bytes(val, encoding="utf8")) - def cta_load_user_data(self, id:int, key:str, defVal:str = ""): - ''' + def cta_load_user_data(self, id: int, key: str, defVal: str = ""): + """ 加载用户数据 @id 策略id @key 数据名 @defVal 默认值 - ''' - ret = self.api.cta_load_userdata(id, bytes(key, encoding = "utf8"), bytes(defVal, encoding = "utf8")) + """ + ret = self.api.cta_load_userdata(id, bytes(key, encoding="utf8"), bytes(defVal, encoding="utf8")) return bytes.decode(ret) - def cta_sub_ticks(self, id:int, stdCode:str): - ''' + def cta_sub_ticks(self, id: int, stdCode: str): + """ 订阅行情 @id 策略id @stdCode 品种代码 - ''' - self.api.cta_sub_ticks(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.cta_sub_ticks(id, bytes(stdCode, encoding="utf8")) - def cta_sub_bar_events(self, id:int, stdCode:str, period:str): - ''' + def cta_sub_bar_events(self, id: int, stdCode: str, period: str): + """ 订阅K线事件 @id 策略id @stdCode 品种代码 @period 周期 - ''' - self.api.cta_sub_bar_events(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8")) + """ + self.api.cta_sub_bar_events(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8")) - def cta_step(self, id:int) -> bool: - ''' + def cta_step(self, id: int) -> bool: + """ 单步执行 @id 策略id - ''' + """ return self.api.cta_step(id) - def cta_set_chart_kline(self, id:int, stdCode:str, period:str): - ''' + def cta_set_chart_kline(self, id: int, stdCode: str, period: str): + """ 设置图表K线 @stdCode 合约代码 @period K线周期 - ''' - self.api.cta_set_chart_kline(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8")) + """ + self.api.cta_set_chart_kline(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8")) - def cta_add_chart_mark(self, id:int, price:float, icon:str, tag:str = 'Notag'): - ''' + def cta_add_chart_mark(self, id: int, price: float, icon: str, tag: str = 'Notag'): + """ 添加图表标记 @price 价格, 决定图标出现的位置 @icon 图标, 系统一定的图标ID @tag 标签, 自定义的 - ''' - self.api.cta_add_chart_mark(id, price, bytes(icon, encoding = "utf8"), bytes(tag, encoding = "utf8")) + """ + self.api.cta_add_chart_mark(id, price, bytes(icon, encoding="utf8"), bytes(tag, encoding="utf8")) - def cta_register_index(self, id:int, idxName:str, idxType:int = 1): - ''' + def cta_register_index(self, id: int, idxName: str, idxType: int = 1): + """ 注册指标, on_init调用 @idxName 指标名 @idxType 指标类型, 0-主图指标, 1-副图指标 - ''' - self.api.cta_register_index(id, bytes(idxName, encoding = "utf8"), idxType) + """ + self.api.cta_register_index(id, bytes(idxName, encoding="utf8"), idxType) - def cta_register_index_line(self, id:int, idxName:str, lineName:str, lineType:int = 0) -> bool: - ''' + def cta_register_index_line(self, id: int, idxName: str, lineName: str, lineType: int = 0) -> bool: + """ 注册指标线, on_init调用 @idxName 指标名称 @lineName 线名称 @lineType 线型, 0-曲线, 1-柱子 - ''' - return self.api.cta_register_index_line(id, bytes(idxName, encoding = "utf8"), bytes(lineName, encoding = "utf8"), lineType) + """ + return self.api.cta_register_index_line(id, bytes(idxName, encoding="utf8"), bytes(lineName, encoding="utf8"), + lineType) - def cta_add_index_baseline(self, id:int, idxName:str, lineName:str, value:float) -> bool: - ''' + def cta_add_index_baseline(self, id: int, idxName: str, lineName: str, value: float) -> bool: + """ 添加基准线, on_init调用 @idxName 指标名称 @lineName 线名称 @value 数值 - ''' - return self.api.cta_add_index_baseline(id, bytes(idxName, encoding = "utf8"), bytes(lineName, encoding = "utf8"), value) + """ + return self.api.cta_add_index_baseline(id, bytes(idxName, encoding="utf8"), bytes(lineName, encoding="utf8"), + value) - def cta_set_index_value(self, id:int, idxName:str, lineName:str, val:float) -> bool: - ''' + def cta_set_index_value(self, id: int, idxName: str, lineName: str, val: float) -> bool: + """ 设置指标值, 只有在oncalc的时候才生效 @idxName 指标名称 @lineName 线名称 - ''' - return self.api.cta_set_index_value(id, bytes(idxName, encoding = "utf8"), bytes(lineName, encoding = "utf8"), val) - - - '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - '''SEL接口''' - def sel_get_bars(self, id:int, stdCode:str, period:str, count:int): - ''' + """ + return self.api.cta_set_index_value(id, bytes(idxName, encoding="utf8"), bytes(lineName, encoding="utf8"), val) + + """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + """SEL接口""" + + def sel_get_bars(self, id: int, stdCode: str, period: str, count: int): + """ 读取K线 @id 策略id @stdCode 合约代码 @period 周期, 如m1/m3/d1等 @count 条数 - ''' - return self.api.sel_get_bars(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8"), count, CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) - - def sel_get_ticks(self, id:int, stdCode:str, count:int): - ''' + """ + return self.api.sel_get_bars(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8"), count, + CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) + + def sel_get_ticks(self, id: int, stdCode: str, count: int): + """ 读取Tick @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.sel_get_ticks(id, bytes(stdCode, encoding = "utf8"), count, CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) + """ + return self.api.sel_get_ticks(id, bytes(stdCode, encoding="utf8"), count, + CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) - def sel_save_user_data(self, id:int, key:str, val:str): - ''' + def sel_save_user_data(self, id: int, key: str, val: str): + """ 保存用户数据 @id 策略id @key 数据名 @val 数据值 - ''' - self.api.sel_save_userdata(id, bytes(key, encoding = "utf8"), bytes(val, encoding = "utf8")) + """ + self.api.sel_save_userdata(id, bytes(key, encoding="utf8"), bytes(val, encoding="utf8")) - def sel_load_user_data(self, id:int, key:str, defVal:str = ""): - ''' + def sel_load_user_data(self, id: int, key: str, defVal: str = ""): + """ 加载用户数据 @id 策略id @key 数据名 @defVal 默认值 - ''' - ret = self.api.sel_load_userdata(id, bytes(key, encoding = "utf8"), bytes(defVal, encoding = "utf8")) + """ + ret = self.api.sel_load_userdata(id, bytes(key, encoding="utf8"), bytes(defVal, encoding="utf8")) return bytes.decode(ret) - def sel_get_all_position(self, id:int): - ''' + def sel_get_all_position(self, id: int): + """ 获取全部持仓 @id 策略id - ''' + """ return self.api.sel_get_all_position(id, CB_STRATEGY_GET_POSITION(self.on_stra_get_position)) - def sel_get_position(self, id:int, stdCode:str, bonlyvalid:bool = False, usertag:str = ""): - ''' + def sel_get_position(self, id: int, stdCode: str, bonlyvalid: bool = False, usertag: str = ""): + """ 获取持仓 @id 策略id @stdCode 合约代码 @usertag 进场标记, 如果为空则获取该合约全部持仓 @return 指定合约的持仓手数, 正为多, 负为空 - ''' - return self.api.sel_get_position(id, bytes(stdCode, encoding = "utf8"), bonlyvalid, bytes(usertag, encoding = "utf8")) + """ + return self.api.sel_get_position(id, bytes(stdCode, encoding="utf8"), bonlyvalid, + bytes(usertag, encoding="utf8")) - def sel_get_price(self, stdCode:str): - ''' + def sel_get_price(self, stdCode: str): + """ @stdCode 合约代码 @return 指定合约的最新价格 - ''' - return self.api.sel_get_price(bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_price(bytes(stdCode, encoding="utf8")) - def sel_set_position(self, id:int, stdCode:str, qty:float, usertag:str = ""): - ''' + def sel_set_position(self, id: int, stdCode: str, qty: float, usertag: str = ""): + """ 设置目标仓位 @id 策略id @stdCode 合约代码 @qty 目标仓位, 正为多, 负为空 - ''' - self.api.sel_set_position(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8")) - + """ + self.api.sel_set_position(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8")) + def sel_get_tdate(self) -> int: - ''' + """ 获取当前交易日 @return 当前交易日 - ''' + """ return self.api.sel_get_tdate() - + def sel_get_date(self): - ''' + """ 获取当前日期 @return 当前日期 - ''' + """ return self.api.sel_get_date() def sel_get_time(self): - ''' + """ 获取当前时间 @return 当前时间 - ''' + """ return self.api.sel_get_time() - def sel_log_text(self, id:int, level:int, message:str): - ''' + def sel_log_text(self, id: int, level: int, message: str): + """ 日志输出 @id 策略ID @level 日志级别 @message 日志内容 - ''' + """ self.api.sel_log_text(id, level, ph.auto_encode(message)) - def sel_sub_ticks(self, id:int, stdCode:str): - ''' + def sel_sub_ticks(self, id: int, stdCode: str): + """ 订阅行情 @id 策略id @stdCode 品种代码 - ''' - self.api.sel_sub_ticks(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.sel_sub_ticks(id, bytes(stdCode, encoding="utf8")) - def sel_get_day_price(self, stdCode:str, flag:int = 0) -> float: - ''' + def sel_get_day_price(self, stdCode: str, flag: int = 0) -> float: + """ 获取当日价格 @stdCode 合约代码 @flag 价格标记, 0-开盘价, 1-最高价, 2-最低价, 3-最新价 @return 指定合约的价格 - ''' - return self.api.sel_get_day_price(bytes(stdCode, encoding = "utf8"), flag) + """ + return self.api.sel_get_day_price(bytes(stdCode, encoding="utf8"), flag) - def sel_get_fund_data(self, id:int, flag:int) -> float: - ''' + def sel_get_fund_data(self, id: int, flag: int) -> float: + """ 获取资金数据 @id 策略id @flag 0-动态权益, 1-总平仓盈亏, 2-总浮动盈亏, 3-总手续费 @return 资金数据 - ''' + """ return self.api.sel_get_fund_data(id, flag) - def sel_get_position_profit(self, id:int, stdCode:str): - ''' + def sel_get_position_profit(self, id: int, stdCode: str): + """ 获取浮动盈亏 @id 策略id @stdCode 合约代码 @return 指定合约的浮动盈亏 - ''' - return self.api.sel_get_position_profit(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_position_profit(id, bytes(stdCode, encoding="utf8")) - def sel_get_position_avgpx(self, id:int, stdCode:str): - ''' + def sel_get_position_avgpx(self, id: int, stdCode: str): + """ 获取持仓均价 @id 策略id @stdCode 合约代码 @return 指定合约的持仓均价 - ''' - return self.api.sel_get_position_avgpx(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_position_avgpx(id, bytes(stdCode, encoding="utf8")) - def sel_get_first_entertime(self, id:int, stdCode:str) -> int: - ''' + def sel_get_first_entertime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的首次进场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.sel_get_first_entertime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_first_entertime(id, bytes(stdCode, encoding="utf8")) - def sel_get_last_entertime(self, id:int, stdCode:str) -> int: - ''' + def sel_get_last_entertime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的最后进场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.sel_get_last_entertime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_last_entertime(id, bytes(stdCode, encoding="utf8")) - def sel_get_last_entertag(self, id:int, stdCode:str) -> str: - ''' + def sel_get_last_entertag(self, id: int, stdCode: str) -> str: + """ 获取当前持仓的最后进场标记 @stdCode 合约代码 @return 进场标记 - ''' - return bytes.decode(self.api.sel_get_last_entertag(id, bytes(stdCode, encoding = "utf8"))) + """ + return bytes.decode(self.api.sel_get_last_entertag(id, bytes(stdCode, encoding="utf8"))) - def sel_get_last_exittime(self, id:int, stdCode:str) -> int: - ''' + def sel_get_last_exittime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的最后出场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.sel_get_last_exittime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_last_exittime(id, bytes(stdCode, encoding="utf8")) - def sel_get_detail_entertime(self, id:int, stdCode:str, usertag:str) -> int: - ''' + def sel_get_detail_entertime(self, id: int, stdCode: str, usertag: str) -> int: + """ 获取指定标记的持仓的进场时间 @id 策略id @stdCode 合约代码 @usertag 进场标记 @return 进场时间, 格式如201907260932 - ''' - return self.api.sel_get_detail_entertime(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8")) + """ + return self.api.sel_get_detail_entertime(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8")) - def sel_get_detail_cost(self, id:int, stdCode:str, usertag:str) -> float: - ''' + def sel_get_detail_cost(self, id: int, stdCode: str, usertag: str) -> float: + """ 获取指定标记的持仓的开仓价 @id 策略id @stdCode 合约代码 @usertag 进场标记 @return 开仓价 - ''' - return self.api.sel_get_detail_cost(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8")) + """ + return self.api.sel_get_detail_cost(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8")) - def sel_get_detail_profit(self, id:int, stdCode:str, usertag:str, flag:int): - ''' + def sel_get_detail_profit(self, id: int, stdCode: str, usertag: str, flag: int): + """ 获取指定标记的持仓的盈亏 @id 策略id @stdCode 合约代码 @usertag 进场标记 @flag 盈亏记号, 0-浮动盈亏, 1-最大浮盈, -1-最大亏损(负数), 2-最大浮盈价格, -2-最大浮亏价格 @return 盈亏 - ''' - return self.api.sel_get_detail_profit(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8"), flag) + """ + return self.api.sel_get_detail_profit(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8"), + flag) + """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + """HFT接口""" - '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - '''HFT接口''' - def hft_get_bars(self, id:int, stdCode:str, period:str, count:int): - ''' + def hft_get_bars(self, id: int, stdCode: str, period: str, count: int): + """ 读取K线 @id 策略id @stdCode 合约代码 @period 周期, 如m1/m3/d1等 @count 条数 - ''' - return self.api.hft_get_bars(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8"), count, CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) - - def hft_get_ticks(self, id:int, stdCode:str, count:int): - ''' + """ + return self.api.hft_get_bars(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8"), count, + CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) + + def hft_get_ticks(self, id: int, stdCode: str, count: int): + """ 读取Tick @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.hft_get_ticks(id, bytes(stdCode, encoding = "utf8"), count, CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) + """ + return self.api.hft_get_ticks(id, bytes(stdCode, encoding="utf8"), count, + CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) - def hft_get_ordque(self, id:int, stdCode:str, count:int): - ''' + def hft_get_ordque(self, id: int, stdCode: str, count: int): + """ 读取委托队列 @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.hft_get_ordque(id, bytes(stdCode, encoding = "utf8"), count, CB_HFTSTRA_GET_ORDQUE(self.on_hftstra_get_order_queue)) + """ + return self.api.hft_get_ordque(id, bytes(stdCode, encoding="utf8"), count, + CB_HFTSTRA_GET_ORDQUE(self.on_hftstra_get_order_queue)) - def hft_get_orddtl(self, id:int, stdCode:str, count:int): - ''' + def hft_get_orddtl(self, id: int, stdCode: str, count: int): + """ 读取逐笔委托 @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.hft_get_orddtl(id, bytes(stdCode, encoding = "utf8"), count, CB_HFTSTRA_GET_ORDDTL(self.on_hftstra_get_order_detail)) + """ + return self.api.hft_get_orddtl(id, bytes(stdCode, encoding="utf8"), count, + CB_HFTSTRA_GET_ORDDTL(self.on_hftstra_get_order_detail)) - def hft_get_trans(self, id:int, stdCode:str, count:int): - ''' + def hft_get_trans(self, id: int, stdCode: str, count: int): + """ 读取逐笔成交 @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.hft_get_trans(id, bytes(stdCode, encoding = "utf8"), count, CB_HFTSTRA_GET_TRANS(self.on_hftstra_get_transaction)) + """ + return self.api.hft_get_trans(id, bytes(stdCode, encoding="utf8"), count, + CB_HFTSTRA_GET_TRANS(self.on_hftstra_get_transaction)) - def hft_save_user_data(self, id:int, key:str, val:str): - ''' + def hft_save_user_data(self, id: int, key: str, val: str): + """ 保存用户数据 @id 策略id @key 数据名 @val 数据值 - ''' - self.api.hft_save_userdata(id, bytes(key, encoding = "utf8"), bytes(val, encoding = "utf8")) + """ + self.api.hft_save_userdata(id, bytes(key, encoding="utf8"), bytes(val, encoding="utf8")) - def hft_load_user_data(self, id:int, key:str, defVal:str = ""): - ''' + def hft_load_user_data(self, id: int, key: str, defVal: str = ""): + """ 加载用户数据 @id 策略id @key 数据名 @defVal 默认值 - ''' - ret = self.api.hft_load_userdata(id, bytes(key, encoding = "utf8"), bytes(defVal, encoding = "utf8")) + """ + ret = self.api.hft_load_userdata(id, bytes(key, encoding="utf8"), bytes(defVal, encoding="utf8")) return bytes.decode(ret) - def hft_get_position(self, id:int, stdCode:str, bonlyvalid:bool = False): - ''' + def hft_get_position(self, id: int, stdCode: str, bonlyvalid: bool = False): + """ 获取持仓 @id 策略id @stdCode 合约代码 @return 指定合约的持仓手数, 正为多, 负为空 - ''' - return self.api.hft_get_position(id, bytes(stdCode, encoding = "utf8"), bonlyvalid) + """ + return self.api.hft_get_position(id, bytes(stdCode, encoding="utf8"), bonlyvalid) - def hft_get_position_profit(self, id:int, stdCode:str): - ''' + def hft_get_position_profit(self, id: int, stdCode: str): + """ 获取持仓盈亏 @id 策略id @stdCode 合约代码 @return 指定持仓的浮动盈亏 - ''' - return self.api.hft_get_position_profit(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.hft_get_position_profit(id, bytes(stdCode, encoding="utf8")) - def hft_get_position_avgpx(self, id:int, stdCode:str): - ''' + def hft_get_position_avgpx(self, id: int, stdCode: str): + """ 获取持仓均价 @id 策略id @stdCode 合约代码 @return 指定持仓的浮动盈亏 - ''' - return self.api.hft_get_position_avgpx(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.hft_get_position_avgpx(id, bytes(stdCode, encoding="utf8")) - def hft_get_undone(self, id:int, stdCode:str): - ''' + def hft_get_undone(self, id: int, stdCode: str): + """ 获取持仓 @id 策略id @stdCode 合约代码 @return 指定合约的持仓手数, 正为多, 负为空 - ''' - return self.api.hft_get_undone(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.hft_get_undone(id, bytes(stdCode, encoding="utf8")) - def hft_get_price(self, stdCode:str): - ''' + def hft_get_price(self, stdCode: str): + """ @stdCode 合约代码 @return 指定合约的最新价格 - ''' - return self.api.hft_get_price(bytes(stdCode, encoding = "utf8")) + """ + return self.api.hft_get_price(bytes(stdCode, encoding="utf8")) def hft_get_date(self): - ''' + """ 获取当前日期 @return 当前日期 - ''' + """ return self.api.hft_get_date() def hft_get_time(self): - ''' + """ 获取当前时间 @return 当前时间 - ''' + """ return self.api.hft_get_time() def hft_get_secs(self): - ''' + """ 获取当前时间 @return 当前时间 - ''' + """ return self.api.hft_get_secs() - def hft_log_text(self, id:int, level:int, message:str): - ''' + def hft_log_text(self, id: int, level: int, message: str): + """ 日志输出 @id 策略ID @level 日志级别 @message 日志内容 - ''' + """ self.api.hft_log_text(id, level, ph.auto_encode(message)) - def hft_sub_ticks(self, id:int, stdCode:str): - ''' + def hft_sub_ticks(self, id: int, stdCode: str): + """ 订阅实时行情数据 @id 策略ID @stdCode 品种代码 - ''' - self.api.hft_sub_ticks(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.hft_sub_ticks(id, bytes(stdCode, encoding="utf8")) - def hft_sub_order_queue(self, id:int, stdCode:str): - ''' + def hft_sub_order_queue(self, id: int, stdCode: str): + """ 订阅实时委托队列数据 @id 策略ID @stdCode 品种代码 - ''' - self.api.hft_sub_order_queue(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.hft_sub_order_queue(id, bytes(stdCode, encoding="utf8")) - def hft_sub_order_detail(self, id:int, stdCode:str): - ''' + def hft_sub_order_detail(self, id: int, stdCode: str): + """ 订阅逐笔委托数据 @id 策略ID @stdCode 品种代码 - ''' - self.api.hft_sub_order_detail(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.hft_sub_order_detail(id, bytes(stdCode, encoding="utf8")) - def hft_sub_transaction(self, id:int, stdCode:str): - ''' + def hft_sub_transaction(self, id: int, stdCode: str): + """ 订阅逐笔成交数据 @id 策略ID @stdCode 品种代码 - ''' - self.api.hft_sub_transaction(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.hft_sub_transaction(id, bytes(stdCode, encoding="utf8")) - def hft_cancel(self, id:int, localid:int): - ''' + def hft_cancel(self, id: int, localid: int): + """ 撤销指定订单 @id 策略ID @localid 下单时返回的本地订单号 - ''' + """ return self.api.hft_cancel(id, localid) - def hft_cancel_all(self, id:int, stdCode:str, isBuy:bool): - ''' + def hft_cancel_all(self, id: int, stdCode: str, isBuy: bool): + """ 撤销指定品种的全部买入订单or卖出订单 @id 策略ID @stdCode 品种代码 @isBuy 买入or卖出 - ''' - ret = self.api.hft_cancel_all(id, bytes(stdCode, encoding = "utf8"), isBuy) + """ + ret = self.api.hft_cancel_all(id, bytes(stdCode, encoding="utf8"), isBuy) return bytes.decode(ret) - def hft_buy(self, id:int, stdCode:str, price:float, qty:float, userTag:str, flag:int): - ''' + def hft_buy(self, id: int, stdCode: str, price: float, qty: float, userTag: str, flag: int): + """ 买入指令 @id 策略ID @stdCode 品种代码 @price 买入价格, 0为市价 @qty 买入数量 - ''' - ret = self.api.hft_buy(id, bytes(stdCode, encoding = "utf8"), price, qty, bytes(userTag, encoding = "utf8"), flag) + """ + ret = self.api.hft_buy(id, bytes(stdCode, encoding="utf8"), price, qty, bytes(userTag, encoding="utf8"), flag) return bytes.decode(ret) - def hft_sell(self, id:int, stdCode:str, price:float, qty:float, userTag:str, flag:int): - ''' + def hft_sell(self, id: int, stdCode: str, price: float, qty: float, userTag: str, flag: int): + """ 卖出指令 @id 策略ID @stdCode 品种代码 @price 卖出价格, 0为市价 @qty 卖出数量 - ''' - ret = self.api.hft_sell(id, bytes(stdCode, encoding = "utf8"), price, qty, bytes(userTag, encoding = "utf8"), flag) + """ + ret = self.api.hft_sell(id, bytes(stdCode, encoding="utf8"), price, qty, bytes(userTag, encoding="utf8"), flag) return bytes.decode(ret) - def hft_step(self, id:int): - ''' + def hft_step(self, id: int): + """ 单步执行 @id 策略id - ''' + """ self.api.hft_step(id) + """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + """本地撮合接口""" - '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - '''本地撮合接口''' - def init_cta_mocker(self, name:str, slippage:int = 0, hook:bool = False, persistData:bool = True, incremental:bool = False, isRatioSlp:bool = False) -> int: - ''' + def init_cta_mocker(self, name: str, slippage: int = 0, hook: bool = False, persistData: bool = True, + incremental: bool = False, isRatioSlp: bool = False) -> int: + """ 创建策略环境 @name 策略名称 @slippage 滑点大小 @@ -1227,22 +1268,25 @@ def init_cta_mocker(self, name:str, slippage:int = 0, hook:bool = False, persist @incremental 是否增量回测, 默认为False @isRatioSlp 滑点是否是比例, 默认为False, 如果为True, 则slippage为万分比 @return 系统内策略ID - ''' - return self.api.init_cta_mocker(bytes(name, encoding = "utf8"), slippage, hook, persistData, incremental, isRatioSlp) + """ + return self.api.init_cta_mocker(bytes(name, encoding="utf8"), slippage, hook, persistData, incremental, + isRatioSlp) - def init_hft_mocker(self, name:str, hook:bool = False) -> int: - ''' + def init_hft_mocker(self, name: str, hook: bool = False) -> int: + """ 创建策略环境 @name 策略名称 @return 系统内策略ID - ''' - return self.api.init_hft_mocker(bytes(name, encoding = "utf8"), hook) + """ + return self.api.init_hft_mocker(bytes(name, encoding="utf8"), hook) - def init_sel_mocker(self, name:str, date:int, time:int, period:str, trdtpl:str = "CHINA", session:str = "TRADING", slippage:int = 0, isRatioSlp:bool = False) -> int: - ''' + def init_sel_mocker(self, name: str, date: int, time: int, period: str, trdtpl: str = "CHINA", + session: str = "TRADING", slippage: int = 0, isRatioSlp: bool = False) -> int: + """ 创建策略环境 @name 策略名称 @return 系统内策略ID - ''' - return self.api.init_sel_mocker(bytes(name, encoding = "utf8"), date, time, - bytes(period, encoding = "utf8"), bytes(trdtpl, encoding = "utf8"), bytes(session, encoding = "utf8"), slippage, isRatioSlp) + """ + return self.api.init_sel_mocker(bytes(name, encoding="utf8"), date, time, + bytes(period, encoding="utf8"), bytes(trdtpl, encoding="utf8"), + bytes(session, encoding="utf8"), slippage, isRatioSlp) diff --git a/wtpy/wrapper/WtDtHelper.py b/wtpy/wrapper/WtDtHelper.py index 1f05ab4c..8ae40b75 100644 --- a/wtpy/wrapper/WtDtHelper.py +++ b/wtpy/wrapper/WtDtHelper.py @@ -1,25 +1,29 @@ +import logging +import os from ctypes import cdll, CFUNCTYPE, c_char_p, c_void_p, c_bool, POINTER, c_uint32, c_uint64 -from wtpy.WtCoreDefs import WTSTickStruct, WTSBarStruct, WTSOrdDtlStruct, WTSOrdQueStruct, WTSTransStruct -from wtpy.WtDataDefs import WtTickCache, WtNpOrdDetails, WtNpOrdQueues, WtNpTransactions, WtNpKline, WtNpTicks, WtBarCache + from wtpy.SessionMgr import SessionInfo -from wtpy.wrapper.PlatformHelper import PlatformHelper as ph +from wtpy.WtCoreDefs import WTSTickStruct, WTSBarStruct, WTSOrdDtlStruct, WTSOrdQueStruct, WTSTransStruct +from wtpy.WtDataDefs import WtTickCache, WtNpOrdDetails, WtNpOrdQueues, WtNpTransactions, WtNpKline, WtNpTicks, \ + WtBarCache from wtpy.WtUtilDefs import singleton -import os,logging +from wtpy.wrapper.PlatformHelper import PlatformHelper as ph + +CB_DTHELPER_LOG = CFUNCTYPE(c_void_p, c_char_p) +CB_DTHELPER_TICK = CFUNCTYPE(c_void_p, POINTER(WTSTickStruct), c_uint32, c_bool) +CB_DTHELPER_ORDQUE = CFUNCTYPE(c_void_p, POINTER(WTSOrdDtlStruct), c_uint32, c_bool) +CB_DTHELPER_ORDDTL = CFUNCTYPE(c_void_p, POINTER(WTSOrdQueStruct), c_uint32, c_bool) +CB_DTHELPER_TRANS = CFUNCTYPE(c_void_p, POINTER(WTSTransStruct), c_uint32, c_bool) +CB_DTHELPER_BAR = CFUNCTYPE(c_void_p, POINTER(WTSBarStruct), c_uint32, c_bool) -CB_DTHELPER_LOG = CFUNCTYPE(c_void_p, c_char_p) -CB_DTHELPER_TICK = CFUNCTYPE(c_void_p, POINTER(WTSTickStruct), c_uint32, c_bool) -CB_DTHELPER_ORDQUE = CFUNCTYPE(c_void_p, POINTER(WTSOrdDtlStruct), c_uint32, c_bool) -CB_DTHELPER_ORDDTL = CFUNCTYPE(c_void_p, POINTER(WTSOrdQueStruct), c_uint32, c_bool) -CB_DTHELPER_TRANS = CFUNCTYPE(c_void_p, POINTER(WTSTransStruct), c_uint32, c_bool) -CB_DTHELPER_BAR = CFUNCTYPE(c_void_p, POINTER(WTSBarStruct), c_uint32, c_bool) +CB_DTHELPER_COUNT = CFUNCTYPE(c_void_p, c_uint32) -CB_DTHELPER_COUNT = CFUNCTYPE(c_void_p, c_uint32) @singleton class WtDataHelper: - ''' + """ Wt平台数据组件C接口底层对接模块 - ''' + """ # api可以作为公共变量 api = None @@ -32,230 +36,249 @@ def __init__(self): a = (paths[:-1] + (dllname,)) _path = os.path.join(*a) self.api = cdll.LoadLibrary(_path) - + self.cb_dthelper_log = CB_DTHELPER_LOG(self.on_log_output) - self.api.resample_bars.argtypes = [c_char_p, CB_DTHELPER_BAR, CB_DTHELPER_COUNT, c_uint64, c_uint64, c_char_p, c_uint32, c_char_p, CB_DTHELPER_LOG] + self.api.resample_bars.argtypes = [c_char_p, CB_DTHELPER_BAR, CB_DTHELPER_COUNT, c_uint64, c_uint64, c_char_p, + c_uint32, c_char_p, CB_DTHELPER_LOG] - def on_log_output(self, message:str): + def on_log_output(self, message: str): message = bytes.decode(message, 'utf-8') logging.info(message) - def dump_bars(self, binFolder:str, csvFolder:str, strFilter:str=""): - ''' + def dump_bars(self, binFolder: str, csvFolder: str, strFilter: str = ""): + """ 将目录下的.dsb格式的历史K线数据导出为.csv格式 @binFolder .dsb文件存储目录 @csvFolder .csv文件的输出目录 @strFilter 代码过滤器(暂未启用) - ''' - self.api.dump_bars(bytes(binFolder, encoding="utf8"), bytes(csvFolder, encoding="utf8"), bytes(strFilter, encoding="utf8"), self.cb_dthelper_log) + """ + self.api.dump_bars(bytes(binFolder, encoding="utf8"), bytes(csvFolder, encoding="utf8"), + bytes(strFilter, encoding="utf8"), self.cb_dthelper_log) - def dump_ticks(self, binFolder: str, csvFolder: str, strFilter: str=""): - ''' + def dump_ticks(self, binFolder: str, csvFolder: str, strFilter: str = ""): + """ 将目录下的.dsb格式的历史Tik数据导出为.csv格式 @binFolder .dsb文件存储目录 @csvFolder .csv文件的输出目录 @strFilter 代码过滤器(暂未启用) - ''' - self.api.dump_ticks(bytes(binFolder, encoding="utf8"), bytes(csvFolder, encoding="utf8"), bytes(strFilter, encoding="utf8"), self.cb_dthelper_log) + """ + self.api.dump_ticks(bytes(binFolder, encoding="utf8"), bytes(csvFolder, encoding="utf8"), + bytes(strFilter, encoding="utf8"), self.cb_dthelper_log) def trans_csv_bars(self, csvFolder: str, binFolder: str, period: str): - ''' + """ 将目录下的.csv格式的历史K线数据转成.dsb格式 @csvFolder .csv文件的输出目录 @binFolder .dsb文件存储目录 @period K线周期,m1-1分钟线,m5-5分钟线,d-日线 - ''' - self.api.trans_csv_bars(bytes(csvFolder, encoding="utf8"), bytes(binFolder, encoding="utf8"), bytes(period, encoding="utf8"), self.cb_dthelper_log) - - def trans_bars(self, barFile:str, getter, count:int, period:str) -> bool: - ''' + """ + self.api.trans_csv_bars(bytes(csvFolder, encoding="utf8"), bytes(binFolder, encoding="utf8"), + bytes(period, encoding="utf8"), self.cb_dthelper_log) + + def trans_bars(self, barFile: str, getter, count: int, period: str) -> bool: + """ 将K线转储到dsb文件中 @barFile 要存储的文件路径 @getter 获取bar的回调函数 @count 一共要写入的数据条数 @period 周期,m1/m5/d - ''' + """ raise Exception("Method trans_bars is removed from core, use store_bars instead") # cb = CB_DTHELPER_BAR_GETTER(getter) # return self.api.trans_bars(bytes(barFile, encoding="utf8"), cb, count, bytes(period, encoding="utf8"), self.cb_dthelper_log) - def trans_ticks(self, tickFile:str, getter, count:int) -> bool: - ''' + def trans_ticks(self, tickFile: str, getter, count: int) -> bool: + """ 将Tick数据转储到dsb文件中 @tickFile 要存储的文件路径 @getter 获取tick的回调函数 @count 一共要写入的数据条数 - ''' + """ raise Exception("Method trans_ticks is removed from core, use store_ticks instead") # cb = CB_DTHELPER_TICK_GETTER(getter) # return self.api.trans_ticks(bytes(tickFile, encoding="utf8"), cb, count, self.cb_dthelper_log) - def store_bars(self, barFile:str, firstBar:POINTER(WTSBarStruct), count:int, period:str) -> bool: - ''' + def store_bars(self, barFile: str, firstBar: POINTER(WTSBarStruct), count: int, period: str) -> bool: + """ 将K线转储到dsb文件中 @barFile 要存储的文件路径 @firstBar 第一条bar的指针 @count 一共要写入的数据条数 @period 周期,m1/m5/d - ''' - return self.api.store_bars(bytes(barFile, encoding="utf8"), firstBar, count, bytes(period, encoding="utf8"), self.cb_dthelper_log) + """ + return self.api.store_bars(bytes(barFile, encoding="utf8"), firstBar, count, bytes(period, encoding="utf8"), + self.cb_dthelper_log) - def store_ticks(self, tickFile:str, firstTick:POINTER(WTSTickStruct), count:int) -> bool: - ''' + def store_ticks(self, tickFile: str, firstTick: POINTER(WTSTickStruct), count: int) -> bool: + """ 将Tick数据转储到dsb文件中 @tickFile 要存储的文件路径 @firstTick 第一条tick的指针 @count 一共要写入的数据条数 - ''' + """ # cb = CB_DTHELPER_TICK_GETTER(getter) return self.api.store_ticks(bytes(tickFile, encoding="utf8"), firstTick, count, self.cb_dthelper_log) - - def store_order_details(self, targetFile:str, firstItem:POINTER(WTSOrdDtlStruct), count:int) -> bool: - ''' + + def store_order_details(self, targetFile: str, firstItem: POINTER(WTSOrdDtlStruct), count: int) -> bool: + """ 将委托明细数据转储到dsb文件中 @tickFile 要存储的文件路径 @firstItem 第一条数据的指针 @count 一共要写入的数据条数 - ''' + """ return self.api.store_order_details(bytes(targetFile, encoding="utf8"), firstItem, count, self.cb_dthelper_log) - - def store_order_queues(self, targetFile:str, firstItem:POINTER(WTSOrdQueStruct), count:int) -> bool: - ''' + + def store_order_queues(self, targetFile: str, firstItem: POINTER(WTSOrdQueStruct), count: int) -> bool: + """ 将委托队列数据转储到dsb文件中 @tickFile 要存储的文件路径 @firstItem 第一条数据的指针 @count 一共要写入的数据条数 - ''' + """ return self.api.store_order_queues(bytes(targetFile, encoding="utf8"), firstItem, count, self.cb_dthelper_log) - - def store_transactions(self, targetFile:str, firstItem:POINTER(WTSTransStruct), count:int) -> bool: - ''' + + def store_transactions(self, targetFile: str, firstItem: POINTER(WTSTransStruct), count: int) -> bool: + """ 将逐笔成交数据转储到dsb文件中 @tickFile 要存储的文件路径 @firstItem 第一条数据的指针 @count 一共要写入的数据条数 - ''' + """ return self.api.store_transactions(bytes(targetFile, encoding="utf8"), firstItem, count, self.cb_dthelper_log) - - def read_dsb_bars(self, barFile: str, isDay:bool = False) -> WtNpKline: - ''' + + def read_dsb_bars(self, barFile: str, isDay: bool = False) -> WtNpKline: + """ 读取.dsb格式的K线数据 @tickFile .dsb的K线数据文件 @return WtNpKline,可以通过WtNpKline.ndarray获取numpy的ndarray对象 - ''' + """ bar_cache = WtBarCache(isDay, forceCopy=True) - if 0 == self.api.read_dsb_bars(bytes(barFile, encoding="utf8"), CB_DTHELPER_BAR(bar_cache.on_read_bar), CB_DTHELPER_COUNT(bar_cache.on_data_count), self.cb_dthelper_log): + if 0 == self.api.read_dsb_bars(bytes(barFile, encoding="utf8"), CB_DTHELPER_BAR(bar_cache.on_read_bar), + CB_DTHELPER_COUNT(bar_cache.on_data_count), self.cb_dthelper_log): return None else: return bar_cache.records def read_dmb_ticks(self, tickFile: str) -> WtNpTicks: - ''' + """ 读取.dmb格式的tick数据 @tickFile .dmb的tick数据文件 @return WtNpTicks - ''' + """ tick_cache = WtTickCache(forceCopy=True) - if 0 == self.api.read_dmb_ticks(bytes(tickFile, encoding="utf8"), CB_DTHELPER_TICK(tick_cache.on_read_tick), CB_DTHELPER_COUNT(tick_cache.on_data_count), self.cb_dthelper_log): + if 0 == self.api.read_dmb_ticks(bytes(tickFile, encoding="utf8"), CB_DTHELPER_TICK(tick_cache.on_read_tick), + CB_DTHELPER_COUNT(tick_cache.on_data_count), self.cb_dthelper_log): return None else: return tick_cache.records def read_dmb_bars(self, barFile: str) -> WtNpKline: - ''' + """ 读取.dmb格式的K线数据 @tickFile .dmb的K线数据文件 @return WtNpKline - ''' + """ bar_cache = WtBarCache(forceCopy=True) - if 0 == self.api.read_dmb_bars(bytes(barFile, encoding="utf8"), CB_DTHELPER_BAR(bar_cache.on_read_bar), CB_DTHELPER_COUNT(bar_cache.on_data_count), self.cb_dthelper_log): + if 0 == self.api.read_dmb_bars(bytes(barFile, encoding="utf8"), CB_DTHELPER_BAR(bar_cache.on_read_bar), + CB_DTHELPER_COUNT(bar_cache.on_data_count), self.cb_dthelper_log): return None else: return bar_cache.records def read_dsb_ticks(self, tickFile: str) -> WtNpTicks: - ''' + """ 读取.dsb格式的tick数据 @tickFile .dsb的tick数据文件 @return WtNpTicks,可以通过WtNpTicks.ndarray获取numpy的ndarray对象 - ''' + """ tick_cache = WtTickCache(forceCopy=True) - if 0 == self.api.read_dsb_ticks(bytes(tickFile, encoding="utf8"), CB_DTHELPER_TICK(tick_cache.on_read_tick), CB_DTHELPER_COUNT(tick_cache.on_data_count), self.cb_dthelper_log): + if 0 == self.api.read_dsb_ticks(bytes(tickFile, encoding="utf8"), CB_DTHELPER_TICK(tick_cache.on_read_tick), + CB_DTHELPER_COUNT(tick_cache.on_data_count), self.cb_dthelper_log): return None else: return tick_cache.records def read_dsb_order_details(self, dataFile: str) -> WtNpOrdDetails: - ''' + """ 读取.dsb格式的委托明细数据 @dataFile .dsb的数据文件 @return WtNpOrdDetails - ''' + """ + class DataCache: def __init__(self): - self.records:WtNpOrdDetails = None + self.records: WtNpOrdDetails = None - def on_read_data(self, firstItem:POINTER(WTSOrdDtlStruct), count:int, isLast:bool): + def on_read_data(self, firstItem: POINTER(WTSOrdDtlStruct), count: int, isLast: bool): self.records = WtNpOrdDetails(forceCopy=True) self.records.set_data(firstItem, count) - def on_data_count(self, count:int): + def on_data_count(self, count: int): pass - + data_cache = DataCache() - if 0 == self.api.read_dsb_order_details(bytes(dataFile, encoding="utf8"), CB_DTHELPER_ORDDTL(data_cache.on_read_data), CB_DTHELPER_COUNT(data_cache.on_data_count), self.cb_dthelper_log): + if 0 == self.api.read_dsb_order_details(bytes(dataFile, encoding="utf8"), + CB_DTHELPER_ORDDTL(data_cache.on_read_data), + CB_DTHELPER_COUNT(data_cache.on_data_count), self.cb_dthelper_log): return None else: return data_cache.records - + def read_dsb_order_queues(self, dataFile: str) -> WtNpOrdQueues: - ''' + """ 读取.dsb格式的委托队列数据 @dataFile .dsb的数据文件 @return WtNpOrdQueues - ''' + """ + class DataCache: def __init__(self): - self.records:WtNpOrdQueues = None + self.records: WtNpOrdQueues = None - def on_read_data(self, firstItem:POINTER(WTSOrdQueStruct), count:int, isLast:bool): + def on_read_data(self, firstItem: POINTER(WTSOrdQueStruct), count: int, isLast: bool): self.records = WtNpOrdQueues(forceCopy=True) self.records.set_data(firstItem, count) - def on_data_count(self, count:int): + def on_data_count(self, count: int): pass - + data_cache = DataCache() - if 0 == self.api.read_dsb_order_queues(bytes(dataFile, encoding="utf8"), CB_DTHELPER_ORDQUE(data_cache.on_read_data), CB_DTHELPER_COUNT(data_cache.on_data_count), self.cb_dthelper_log): + if 0 == self.api.read_dsb_order_queues(bytes(dataFile, encoding="utf8"), + CB_DTHELPER_ORDQUE(data_cache.on_read_data), + CB_DTHELPER_COUNT(data_cache.on_data_count), self.cb_dthelper_log): return None else: return data_cache.records - + def read_dsb_transactions(self, dataFile: str) -> WtNpTransactions: - ''' + """ 读取.dsb格式的逐笔成交数据 @dataFile .dsb的数据文件 @return WtNpTransactions - ''' + """ + class DataCache: def __init__(self): - self.records:WtNpTransactions = None + self.records: WtNpTransactions = None - def on_read_data(self, firstItem:POINTER(WTSTransStruct), count:int, isLast:bool): + def on_read_data(self, firstItem: POINTER(WTSTransStruct), count: int, isLast: bool): self.records = WtNpTransactions(forceCopy=True) self.records.set_data(firstItem, count) - def on_data_count(self, count:int): + def on_data_count(self, count: int): pass - + data_cache = DataCache() - if 0 == self.api.read_dsb_transactions(bytes(dataFile, encoding="utf8"), CB_DTHELPER_TRANS(data_cache.on_read_data), CB_DTHELPER_COUNT(data_cache.on_data_count), self.cb_dthelper_log): + if 0 == self.api.read_dsb_transactions(bytes(dataFile, encoding="utf8"), + CB_DTHELPER_TRANS(data_cache.on_read_data), + CB_DTHELPER_COUNT(data_cache.on_data_count), self.cb_dthelper_log): return None else: return data_cache.records - - def resample_bars(self, barFile:str, period:str, times:int, fromTime:int, endTime:int, sessInfo:SessionInfo, alignSection:bool = False) -> WtNpKline: - ''' + + def resample_bars(self, barFile: str, period: str, times: int, fromTime: int, endTime: int, sessInfo: SessionInfo, + alignSection: bool = False) -> WtNpKline: + """ 重采样K线 @barFile dsb格式的K线数据文件 @period 基础K线周期,m1/m5/d @@ -263,10 +286,12 @@ def resample_bars(self, barFile:str, period:str, times:int, fromTime:int, endTim @fromTime 开始时间,日线数据格式yyyymmdd,分钟线数据为格式为yyyymmddHHMMSS @endTime 结束时间,日线数据格式yyyymmdd,分钟线数据为格式为yyyymmddHHMMSS @sessInfo 交易时间模板 - ''' + """ bar_cache = WtBarCache(forceCopy=True) - if 0 == self.api.resample_bars(bytes(barFile, encoding="utf8"), CB_DTHELPER_BAR(bar_cache.on_read_bar), CB_DTHELPER_COUNT(bar_cache.on_data_count), - fromTime, endTime, bytes(period,'utf8'), times, bytes(sessInfo.toString(),'utf8'), self.cb_dthelper_log, alignSection): + if 0 == self.api.resample_bars(bytes(barFile, encoding="utf8"), CB_DTHELPER_BAR(bar_cache.on_read_bar), + CB_DTHELPER_COUNT(bar_cache.on_data_count), + fromTime, endTime, bytes(period, 'utf8'), times, + bytes(sessInfo.toString(), 'utf8'), self.cb_dthelper_log, alignSection): return None else: - return bar_cache.records \ No newline at end of file + return bar_cache.records diff --git a/wtpy/wrapper/WtDtServoApi.py b/wtpy/wrapper/WtDtServoApi.py index faf98695..0481b786 100644 --- a/wtpy/wrapper/WtDtServoApi.py +++ b/wtpy/wrapper/WtDtServoApi.py @@ -5,15 +5,16 @@ from wtpy.WtUtilDefs import singleton import os -CB_GET_BAR = CFUNCTYPE(c_void_p, POINTER(WTSBarStruct), c_uint32, c_bool) -CB_GET_TICK = CFUNCTYPE(c_void_p, POINTER(WTSTickStruct), c_uint32, c_bool) -CB_DATA_COUNT = CFUNCTYPE(c_void_p, c_uint32) +CB_GET_BAR = CFUNCTYPE(c_void_p, POINTER(WTSBarStruct), c_uint32, c_bool) +CB_GET_TICK = CFUNCTYPE(c_void_p, POINTER(WTSTickStruct), c_uint32, c_bool) +CB_DATA_COUNT = CFUNCTYPE(c_void_p, c_uint32) + @singleton class WtDtServoApi: - ''' + """ Wt平台数据组件C接口底层对接模块 - ''' + """ # api可以作为公共变量 api = None @@ -40,92 +41,102 @@ def __init__(self): self.api.get_sbars_by_date.argtypes = [c_char_p, c_uint32, c_uint32, CB_GET_BAR, CB_DATA_COUNT] self.api.get_bars_by_date.argtypes = [c_char_p, c_char_p, c_uint32, CB_GET_BAR, CB_DATA_COUNT] - def initialize(self, cfgfile:str, isFile:bool, logcfg:str = 'logcfg.yaml'): - self.api.initialize(bytes(cfgfile, encoding = "utf8"), isFile, bytes(logcfg, encoding = "utf8")) + def initialize(self, cfgfile: str, isFile: bool, logcfg: str = 'logcfg.yaml'): + self.api.initialize(bytes(cfgfile, encoding="utf8"), isFile, bytes(logcfg, encoding="utf8")) def clear_cache(self): self.api.clear_cache() - def get_bars(self, stdCode:str, period:str, fromTime:int = None, dataCount:int = None, endTime:int = 0) -> WtNpKline: - ''' + def get_bars(self, stdCode: str, period: str, fromTime: int = None, dataCount: int = None, + endTime: int = 0) -> WtNpKline: + """ 获取K线数据 @stdCode 标准合约代码 @period 基础K线周期, m1/m5/d @fromTime 开始时间, 日线数据格式yyyymmdd, 分钟线数据为格式为yyyymmddHHMM @endTime 结束时间, 日线数据格式yyyymmdd, 分钟线数据为格式为yyyymmddHHMM, 为0则读取到最后一条 - ''' + """ bar_cache = WtBarCache() if fromTime is not None: - ret = self.api.get_bars_by_range(bytes(stdCode, encoding="utf8"), bytes(period,'utf8'), fromTime, endTime, CB_GET_BAR(bar_cache.on_read_bar), CB_DATA_COUNT(bar_cache.on_data_count)) + ret = self.api.get_bars_by_range(bytes(stdCode, encoding="utf8"), bytes(period, 'utf8'), fromTime, endTime, + CB_GET_BAR(bar_cache.on_read_bar), CB_DATA_COUNT(bar_cache.on_data_count)) else: - ret = self.api.get_bars_by_count(bytes(stdCode, encoding="utf8"), bytes(period,'utf8'), dataCount, endTime, CB_GET_BAR(bar_cache.on_read_bar), CB_DATA_COUNT(bar_cache.on_data_count)) + ret = self.api.get_bars_by_count(bytes(stdCode, encoding="utf8"), bytes(period, 'utf8'), dataCount, endTime, + CB_GET_BAR(bar_cache.on_read_bar), CB_DATA_COUNT(bar_cache.on_data_count)) if ret == 0: return None else: return bar_cache.records - def get_ticks(self, stdCode:str, fromTime:int = None, dataCount:int = None, endTime:int = 0) -> WtNpTicks: - ''' + def get_ticks(self, stdCode: str, fromTime: int = None, dataCount: int = None, endTime: int = 0) -> WtNpTicks: + """ 获取tick数据 @stdCode 标准合约代码 @fromTime 开始时间, 格式为yyyymmddHHMM @endTime 结束时间, 格式为yyyymmddHHMM, 为0则读取到最后一条 - ''' + """ tick_cache = WtTickCache() if fromTime is not None: - ret = self.api.get_ticks_by_range(bytes(stdCode, encoding="utf8"), fromTime, endTime, CB_GET_TICK(tick_cache.on_read_tick), CB_DATA_COUNT(tick_cache.on_data_count)) + ret = self.api.get_ticks_by_range(bytes(stdCode, encoding="utf8"), fromTime, endTime, + CB_GET_TICK(tick_cache.on_read_tick), + CB_DATA_COUNT(tick_cache.on_data_count)) else: - ret = self.api.get_ticks_by_count(bytes(stdCode, encoding="utf8"), dataCount, endTime, CB_GET_TICK(tick_cache.on_read_tick), CB_DATA_COUNT(tick_cache.on_data_count)) + ret = self.api.get_ticks_by_count(bytes(stdCode, encoding="utf8"), dataCount, endTime, + CB_GET_TICK(tick_cache.on_read_tick), + CB_DATA_COUNT(tick_cache.on_data_count)) if ret == 0: return None else: return tick_cache.records - def get_ticks_by_date(self, stdCode:str, iDate:int) -> WtNpTicks: - ''' + def get_ticks_by_date(self, stdCode: str, iDate: int) -> WtNpTicks: + """ 按天读取tick数据 @stdCode 标准合约代码 @iDate 数据日期, 格式为yyyymmdd - ''' + """ tick_cache = WtTickCache() - ret = self.api.get_ticks_by_date(bytes(stdCode, encoding="utf8"), iDate, CB_GET_TICK(tick_cache.on_read_tick), CB_DATA_COUNT(tick_cache.on_data_count)) + ret = self.api.get_ticks_by_date(bytes(stdCode, encoding="utf8"), iDate, CB_GET_TICK(tick_cache.on_read_tick), + CB_DATA_COUNT(tick_cache.on_data_count)) if ret == 0: return None else: return tick_cache.records - def get_sbars_by_date(self, stdCode:str, iSec:int, iDate:int) -> WtNpKline: - ''' + def get_sbars_by_date(self, stdCode: str, iSec: int, iDate: int) -> WtNpKline: + """ 按天读取秒线 @stdCode 标准合约代码 @iSec 周期, 单位s @iDate 数据日期, 格式为yyyymmdd - ''' + """ bar_cache = WtBarCache() - ret = self.api.get_sbars_by_date(bytes(stdCode, encoding="utf8"), iSec, iDate, CB_GET_BAR(bar_cache.on_read_bar), CB_DATA_COUNT(bar_cache.on_data_count)) + ret = self.api.get_sbars_by_date(bytes(stdCode, encoding="utf8"), iSec, iDate, + CB_GET_BAR(bar_cache.on_read_bar), CB_DATA_COUNT(bar_cache.on_data_count)) if ret == 0: return None else: return bar_cache.records - def get_bars_by_date(self, stdCode:str, period:str, iDate:int) -> WtNpKline: - ''' + def get_bars_by_date(self, stdCode: str, period: str, iDate: int) -> WtNpKline: + """ 按天读取分钟线 @stdCode 标准合约代码 @period 周期, 分钟线 @iDate 数据日期, 格式为yyyymmdd - ''' + """ if period[0] != 'm': return None bar_cache = WtBarCache() - ret = self.api.get_bars_by_date(bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8"), iDate, CB_GET_BAR(bar_cache.on_read_bar), CB_DATA_COUNT(bar_cache.on_data_count)) + ret = self.api.get_bars_by_date(bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8"), iDate, + CB_GET_BAR(bar_cache.on_read_bar), CB_DATA_COUNT(bar_cache.on_data_count)) if ret == 0: return None else: - return bar_cache.records \ No newline at end of file + return bar_cache.records diff --git a/wtpy/wrapper/WtDtWrapper.py b/wtpy/wrapper/WtDtWrapper.py index 2fe57b99..04f4d2ca 100644 --- a/wtpy/wrapper/WtDtWrapper.py +++ b/wtpy/wrapper/WtDtWrapper.py @@ -1,21 +1,23 @@ from ctypes import cdll, c_char_p, c_bool, POINTER from .PlatformHelper import PlatformHelper as ph from wtpy.WtUtilDefs import singleton -from wtpy.WtCoreDefs import WTSTickStruct, WTSBarStruct, CB_PARSER_EVENT, CB_PARSER_SUBCMD, FUNC_DUMP_HISBARS, FUNC_DUMP_HISTICKS +from wtpy.WtCoreDefs import WTSTickStruct, WTSBarStruct, CB_PARSER_EVENT, CB_PARSER_SUBCMD, FUNC_DUMP_HISBARS, \ + FUNC_DUMP_HISTICKS from wtpy.WtCoreDefs import EVENT_PARSER_CONNECT, EVENT_PARSER_DISCONNECT, EVENT_PARSER_INIT, EVENT_PARSER_RELEASE import os + # Python对接C接口的库 @singleton class WtDtWrapper: - ''' + """ Wt平台数据组件C接口底层对接模块 - ''' + """ # api可以作为公共变量 api = None ver = "Unknown" - + # 构造函数,传入动态库名 def __init__(self, engine): self._engine = engine @@ -33,44 +35,47 @@ def __init__(self, engine): self.api.create_ext_dumper.restype = c_bool self.api.create_ext_dumper.argtypes = [c_char_p] - def run_datakit(self, bAsync:bool = False): - ''' + def run_datakit(self, bAsync: bool = False): + """ 启动数据组件 - ''' + """ self.api.start(bAsync) - def write_log(self, level, message:str, catName:str = ""): - ''' + def write_log(self, level, message: str, catName: str = ""): + """ 向组件输出日志 - ''' - self.api.write_log(level, bytes(message, encoding = "utf8").decode('utf-8').encode('gbk'), bytes(catName, encoding = "utf8")) + """ + self.api.write_log(level, bytes(message, encoding="utf8").decode('utf-8').encode('gbk'), + bytes(catName, encoding="utf8")) - def initialize(self, cfgfile:str = "dtcfg.yaml", logprofile:str = "logcfgdt.jsyamlon", bCfgFile:bool = True, bLogCfgFile:bool = True): - ''' + def initialize(self, cfgfile: str = "dtcfg.yaml", logprofile: str = "logcfgdt.jsyamlon", bCfgFile: bool = True, + bLogCfgFile: bool = True): + """ C接口初始化 - ''' + """ try: - self.api.initialize(bytes(cfgfile, encoding = "utf8"), bytes(logprofile, encoding = "utf8"), bCfgFile, bLogCfgFile) + self.api.initialize(bytes(cfgfile, encoding="utf8"), bytes(logprofile, encoding="utf8"), bCfgFile, + bLogCfgFile) self.register_extended_module_callbacks() except OSError as oe: print(oe) - self.write_log(102, "WonderTrader datakit initialzied,version: %s" % (self.ver)) + self.write_log(102, "WonderTrader datakit initialzied,version: %s" % self.ver) - def create_extended_parser(self, id:str) -> bool: - return self.api.create_ext_parser(bytes(id, encoding = "utf8")) + def create_extended_parser(self, id: str) -> bool: + return self.api.create_ext_parser(bytes(id, encoding="utf8")) - def push_quote_from_exetended_parser(self, id:str, newTick:POINTER(WTSTickStruct), uProcFlag:int = 1): - return self.api.parser_push_quote(bytes(id, encoding = "utf8"), newTick, uProcFlag) + def push_quote_from_exetended_parser(self, id: str, newTick: POINTER(WTSTickStruct), uProcFlag: int = 1): + return self.api.parser_push_quote(bytes(id, encoding="utf8"), newTick, uProcFlag) - def register_extended_module_callbacks(self,): + def register_extended_module_callbacks(self, ): self.cb_parser_event = CB_PARSER_EVENT(self.on_parser_event) self.cb_parser_subcmd = CB_PARSER_SUBCMD(self.on_parser_sub) self.api.register_parser_callbacks(self.cb_parser_event, self.cb_parser_subcmd) - def create_extended_dumper(self, id:str) -> bool: - return self.api.create_ext_dumper(bytes(id, encoding = "utf8")) + def create_extended_dumper(self, id: str) -> bool: + return self.api.create_ext_dumper(bytes(id, encoding="utf8")) def register_extended_data_dumper(self): self.cb_bars_dumper = FUNC_DUMP_HISBARS(self.dump_his_bars) @@ -78,13 +83,13 @@ def register_extended_data_dumper(self): self.api.register_extended_dumper(self.cb_bars_dumper, self.cb_ticks_dumper) - def on_parser_event(self, evtId:int, id:str): + def on_parser_event(self, evtId: int, id: str): id = bytes.decode(id) engine = self._engine parser = engine.get_extended_parser(id) if parser is None: return - + if evtId == EVENT_PARSER_INIT: parser.init(engine) elif evtId == EVENT_PARSER_CONNECT: @@ -94,7 +99,7 @@ def on_parser_event(self, evtId:int, id:str): elif evtId == EVENT_PARSER_RELEASE: parser.release() - def on_parser_sub(self, id:str, fullCode:str, isForSub:bool): + def on_parser_sub(self, id: str, fullCode: str, isForSub: bool): id = bytes.decode(id) engine = self._engine parser = engine.get_extended_parser(id) @@ -106,7 +111,7 @@ def on_parser_sub(self, id:str, fullCode:str, isForSub:bool): else: parser.unsubscribe(fullCode) - def dump_his_bars(self, id:str, fullCode:str, period:str, bars:POINTER(WTSBarStruct), count:int) -> bool: + def dump_his_bars(self, id: str, fullCode: str, period: str, bars: POINTER(WTSBarStruct), count: int) -> bool: id = bytes.decode(id) engine = self._engine dumper = engine.get_extended_data_dumper(id) @@ -118,7 +123,7 @@ def dump_his_bars(self, id:str, fullCode:str, period:str, bars:POINTER(WTSBarStr return dumper.dump_his_bars(fullCode, period, bars, count) - def dump_his_ticks(self, id:str, fullCode:str, uDate:int, ticks:POINTER(WTSTickStruct), count:int) -> bool: + def dump_his_ticks(self, id: str, fullCode: str, uDate: int, ticks: POINTER(WTSTickStruct), count: int) -> bool: id = bytes.decode(id) engine = self._engine dumper = engine.get_extended_data_dumper(id) diff --git a/wtpy/wrapper/WtExecApi.py b/wtpy/wrapper/WtExecApi.py index 9a40b988..3b17be9e 100644 --- a/wtpy/wrapper/WtExecApi.py +++ b/wtpy/wrapper/WtExecApi.py @@ -3,9 +3,9 @@ from wtpy.WtUtilDefs import singleton import os + @singleton class WtExecApi: - # api可以作为公共变量 api = None ver = "Unknown" @@ -19,7 +19,7 @@ def __init__(self): self.api.get_version.restype = c_char_p self.ver = bytes.decode(self.api.get_version()) - + self.api.write_log.argtypes = [c_int, c_char_p, c_char_p] self.api.config_exec.argtypes = [c_char_p, c_bool] self.api.init_exec.argtypes = [c_char_p, c_bool] @@ -31,18 +31,19 @@ def run(self): def release(self): self.api.release_exec() - def write_log(self, level:int, message:str, catName:str = ""): - self.api.write_log(level, bytes(message, encoding = "utf8").decode('utf-8').encode('gbk'), bytes(catName, encoding = "utf8")) + def write_log(self, level: int, message: str, catName: str = ""): + self.api.write_log(level, bytes(message, encoding="utf8").decode('utf-8').encode('gbk'), + bytes(catName, encoding="utf8")) - def config(self, cfgfile:str = 'cfgexec.yaml', isFile:bool = True): - self.api.config_exec(bytes(cfgfile, encoding = "utf8"), isFile) + def config(self, cfgfile: str = 'cfgexec.yaml', isFile: bool = True): + self.api.config_exec(bytes(cfgfile, encoding="utf8"), isFile) - def initialize(self, logCfg:str = "logcfgexec.yaml", isFile:bool = True): - ''' + def initialize(self, logCfg: str = "logcfgexec.yaml", isFile: bool = True): + """ C接口初始化 - ''' - self.api.init_exec(bytes(logCfg, encoding = "utf8"), isFile) - self.write_log(102, "WonderTrader independent execution framework initialzied,version: %s" % (self.ver)) + """ + self.api.init_exec(bytes(logCfg, encoding="utf8"), isFile) + self.write_log(102, "WonderTrader independent execution framework initialzied,version: %s" % self.ver) - def set_position(self, stdCode:str, target:float): - self.api.set_position(bytes(stdCode, encoding = "utf8"), target) + def set_position(self, stdCode: str, target: float): + self.api.set_position(bytes(stdCode, encoding="utf8"), target) diff --git a/wtpy/wrapper/WtMQWrapper.py b/wtpy/wrapper/WtMQWrapper.py index 7232779e..f693ab08 100644 --- a/wtpy/wrapper/WtMQWrapper.py +++ b/wtpy/wrapper/WtMQWrapper.py @@ -3,22 +3,23 @@ from wtpy.WtUtilDefs import singleton import os -CB_ON_MSG = CFUNCTYPE(c_void_p, c_uint32, c_char_p, POINTER(c_char), c_uint32) -CB_ON_LOG = CFUNCTYPE(c_void_p, c_uint32, c_char_p, c_bool) +CB_ON_MSG = CFUNCTYPE(c_void_p, c_uint32, c_char_p, POINTER(c_char), c_uint32) +CB_ON_LOG = CFUNCTYPE(c_void_p, c_uint32, c_char_p, c_bool) + # Python对接C接口的库 @singleton class WtMQWrapper: - ''' + """ Wt平台数据组件C接口底层对接模块 - ''' + """ # api可以作为公共变量 api = None ver = "Unknown" - + # 构造函数,传入动态库名 - def __init__(self, logger = None): + def __init__(self, logger=None): self._logger = logger dllname = ph.getModule("WtMsgQue") paths = os.path.split(__file__) @@ -32,32 +33,31 @@ def __init__(self, logger = None): self.api.create_server.argtypes = [c_char_p, c_bool] self.api.create_server.restype = c_ulong - def on_mq_log(self, id:int, message:str, bServer:bool): + def on_mq_log(self, id: int, message: str, bServer: bool): message = bytes.decode(message) if self._logger is not None: self._logger.info(message) else: print(message) - def create_server(self, url:str): + def create_server(self, url: str): return self.api.create_server(bytes(url, 'utf-8'), True) - def destroy_server(self, id:int): + def destroy_server(self, id: int): self.api.destroy_server(id) - def publish_message(self, id:int, topic:str, message:str): + def publish_message(self, id: int, topic: str, message: str): message = bytes(message, 'utf-8') self.api.publish_message(id, bytes(topic, 'utf-8'), message, len(message)) - def create_client(self, url:str, cbMsg:CB_ON_MSG): + def create_client(self, url: str, cbMsg: CB_ON_MSG): return self.api.create_client(bytes(url, 'utf-8'), cbMsg) - def destroy_client(self, id:int): + def destroy_client(self, id: int): self.api.destroy_client(id) - def subcribe_topic(self, id:int, topic:str): + def subcribe_topic(self, id: int, topic: str): self.api.subscribe_topic(id, bytes(topic, 'utf-8')) - def start_client(self, id:int): + def start_client(self, id: int): self.api.start_client(id) - diff --git a/wtpy/wrapper/WtWrapper.py b/wtpy/wrapper/WtWrapper.py index fc5243c7..03488dff 100644 --- a/wtpy/wrapper/WtWrapper.py +++ b/wtpy/wrapper/WtWrapper.py @@ -1,9 +1,12 @@ from ctypes import c_int32, cdll, c_char_p, c_bool, c_ulong, c_uint64, c_uint32, c_double, POINTER from wtpy.WtCoreDefs import CB_EXECUTER_CMD, CB_EXECUTER_INIT, CB_PARSER_EVENT, CB_PARSER_SUBCMD -from wtpy.WtCoreDefs import CB_STRATEGY_INIT, CB_STRATEGY_TICK, CB_STRATEGY_CALC, CB_STRATEGY_BAR, CB_STRATEGY_GET_BAR, CB_STRATEGY_GET_TICK, CB_STRATEGY_GET_POSITION, CB_STRATEGY_COND_TRIGGER +from wtpy.WtCoreDefs import CB_STRATEGY_INIT, CB_STRATEGY_TICK, CB_STRATEGY_CALC, CB_STRATEGY_BAR, CB_STRATEGY_GET_BAR, \ + CB_STRATEGY_GET_TICK, CB_STRATEGY_GET_POSITION, CB_STRATEGY_COND_TRIGGER from wtpy.WtCoreDefs import EVENT_PARSER_CONNECT, EVENT_PARSER_DISCONNECT, EVENT_PARSER_INIT, EVENT_PARSER_RELEASE -from wtpy.WtCoreDefs import CB_HFTSTRA_CHNL_EVT, CB_HFTSTRA_ENTRUST, CB_HFTSTRA_ORD, CB_HFTSTRA_TRD, CB_SESSION_EVENT, CB_HFTSTRA_POSITION -from wtpy.WtCoreDefs import CB_HFTSTRA_ORDQUE, CB_HFTSTRA_ORDDTL, CB_HFTSTRA_TRANS, CB_HFTSTRA_GET_ORDQUE, CB_HFTSTRA_GET_ORDDTL, CB_HFTSTRA_GET_TRANS +from wtpy.WtCoreDefs import CB_HFTSTRA_CHNL_EVT, CB_HFTSTRA_ENTRUST, CB_HFTSTRA_ORD, CB_HFTSTRA_TRD, CB_SESSION_EVENT, \ + CB_HFTSTRA_POSITION +from wtpy.WtCoreDefs import CB_HFTSTRA_ORDQUE, CB_HFTSTRA_ORDDTL, CB_HFTSTRA_TRANS, CB_HFTSTRA_GET_ORDQUE, \ + CB_HFTSTRA_GET_ORDDTL, CB_HFTSTRA_GET_TRANS from wtpy.WtCoreDefs import CHNL_EVENT_READY, CHNL_EVENT_LOST, CB_ENGINE_EVENT from wtpy.WtCoreDefs import FUNC_LOAD_HISBARS, FUNC_LOAD_HISTICKS, FUNC_LOAD_ADJFACTS from wtpy.WtCoreDefs import EVENT_ENGINE_INIT, EVENT_SESSION_BEGIN, EVENT_SESSION_END, EVENT_ENGINE_SCHDL @@ -13,17 +16,18 @@ from .PlatformHelper import PlatformHelper as ph import os + # Python对接C接口的库 @singleton class WtWrapper: - ''' + """ Wt平台C接口底层对接模块 - ''' + """ # api可以作为公共变量 api = None ver = "Unknown" - + # 构造函数, 传入动态库名 def __init__(self, engine): self._engine = engine @@ -32,7 +36,7 @@ def __init__(self, engine): a = (paths[:-1] + (dllname,)) _path = os.path.join(*a) self.api = cdll.LoadLibrary(_path) - + self.api.get_version.restype = c_char_p self.api.cta_get_last_entertime.restype = c_uint64 self.api.cta_get_first_entertime.restype = c_uint64 @@ -104,7 +108,7 @@ def __init__(self, engine): self.api.get_raw_stdcode.restype = c_char_p - def on_engine_event(self, evtid:int, evtDate:int, evtTime:int): + def on_engine_event(self, evtid: int, evtDate: int, evtTime: int): engine = self._engine if evtid == EVENT_ENGINE_INIT: engine.on_init() @@ -116,15 +120,15 @@ def on_engine_event(self, evtid:int, evtDate:int, evtTime:int): engine.on_session_end(evtDate) return - #回调函数 - def on_stra_init(self, id:int): + # 回调函数 + def on_stra_init(self, id: int): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_init() return - def on_session_event(self, id:int, udate:int, isBegin:bool): + def on_session_event(self, id: int, udate: int, isBegin: bool): engine = self._engine ctx = engine.get_context(id) if ctx is not None: @@ -134,41 +138,42 @@ def on_session_event(self, id:int, udate:int, isBegin:bool): ctx.on_session_end(udate) return - def on_stra_calc(self, id:int, curDate:int, curTime:int): + def on_stra_calc(self, id: int, curDate: int, curTime: int): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_calculate() return - - def on_stra_tick(self, id:int, stdCode:str, newTick:POINTER(WTSTickStruct)): + + def on_stra_tick(self, id: int, stdCode: str, newTick: POINTER(WTSTickStruct)): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_tick(bytes.decode(stdCode), newTick) return - - def on_stra_bar(self, id:int, stdCode:str, period:str, newBar:POINTER(WTSBarStruct)): + + def on_stra_bar(self, id: int, stdCode: str, period: str, newBar: POINTER(WTSBarStruct)): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_bar(bytes.decode(stdCode), bytes.decode(period), newBar) return - - def on_stra_get_bar(self, id:int, stdCode:str, period:str, curBar:POINTER(WTSBarStruct), count:int, isLast:bool): - ''' + + def on_stra_get_bar(self, id: int, stdCode: str, period: str, curBar: POINTER(WTSBarStruct), count: int, + isLast: bool): + """ 获取K线回调, 该回调函数因为是python主动发起的, 需要同步执行, 所以不走事件推送 @id 策略id @stdCode 合约代码 @period K线周期 @curBar 最新一条K线 @isLast 是否是最后一条 - ''' + """ engine = self._engine ctx = engine.get_context(id) period = bytes.decode(period) - isDay = period[0]=='d' + isDay = period[0] == 'd' npBars = WtNpKline(isDay, forceCopy=False) npBars.set_data(curBar, count) @@ -176,14 +181,14 @@ def on_stra_get_bar(self, id:int, stdCode:str, period:str, curBar:POINTER(WTSBar if ctx is not None: ctx.on_getbars(bytes.decode(stdCode), period, npBars) - def on_stra_get_tick(self, id:int, stdCode:str, curTick:POINTER(WTSTickStruct), count:int, isLast:bool): - ''' + def on_stra_get_tick(self, id: int, stdCode: str, curTick: POINTER(WTSTickStruct), count: int, isLast: bool): + """ 获取Tick回调, 该回调函数因为是python主动发起的, 需要同步执行, 所以不走事件推送 @id 策略id @stdCode 合约代码 @curTick 最新一笔Tick @isLast 是否是最后一条 - ''' + """ engine = self._engine ctx = engine.get_context(id) @@ -194,42 +199,44 @@ def on_stra_get_tick(self, id:int, stdCode:str, curTick:POINTER(WTSTickStruct), ctx.on_getticks(bytes.decode(stdCode), npTicks) return - def on_stra_get_position(self, id:int, stdCode:str, qty:float, frozen:float): + def on_stra_get_position(self, id: int, stdCode: str, qty: float, frozen: float): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_getpositions(bytes.decode(stdCode), qty, frozen) - def on_stra_cond_triggerd(self, id:int, stdCode:str, target:float, price:float, usertag:str): + def on_stra_cond_triggerd(self, id: int, stdCode: str, target: float, price: float, usertag: str): engine = self._engine ctx = engine.get_context(id) if ctx is not None: ctx.on_condition_triggered(bytes.decode(stdCode), target, price, bytes.decode(usertag)) - def on_hftstra_channel_evt(self, id:int, trader:str, evtid:int): + def on_hftstra_channel_evt(self, id: int, trader: str, evtid: int): engine = self._engine ctx = engine.get_context(id) - + if evtid == CHNL_EVENT_READY: ctx.on_channel_ready() elif evtid == CHNL_EVENT_LOST: ctx.on_channel_lost() - def on_hftstra_order(self, id:int, localid:int, stdCode:str, isBuy:bool, totalQty:float, leftQty:float, price:float, isCanceled:bool, userTag:str): + def on_hftstra_order(self, id: int, localid: int, stdCode: str, isBuy: bool, totalQty: float, leftQty: float, + price: float, isCanceled: bool, userTag: str): stdCode = bytes.decode(stdCode) userTag = bytes.decode(userTag) engine = self._engine ctx = engine.get_context(id) ctx.on_order(localid, stdCode, isBuy, totalQty, leftQty, price, isCanceled, userTag) - def on_hftstra_trade(self, id:int, localid:int, stdCode:str, isBuy:bool, qty:float, price:float, userTag:str): + def on_hftstra_trade(self, id: int, localid: int, stdCode: str, isBuy: bool, qty: float, price: float, + userTag: str): stdCode = bytes.decode(stdCode) userTag = bytes.decode(userTag) engine = self._engine ctx = engine.get_context(id) ctx.on_trade(localid, stdCode, isBuy, qty, price, userTag) - def on_hftstra_entrust(self, id:int, localid:int, stdCode:str, bSucc:bool, message:str, userTag:str): + def on_hftstra_entrust(self, id: int, localid: int, stdCode: str, bSucc: bool, message: str, userTag: str): stdCode = bytes.decode(stdCode) message = bytes.decode(message, "gbk") userTag = bytes.decode(userTag) @@ -237,13 +244,14 @@ def on_hftstra_entrust(self, id:int, localid:int, stdCode:str, bSucc:bool, messa ctx = engine.get_context(id) ctx.on_entrust(localid, stdCode, bSucc, message, userTag) - def on_hftstra_position(self, id:int, stdCode:str, isLong:bool, prevol:float, preavail:float, newvol:float, newavail:float): + def on_hftstra_position(self, id: int, stdCode: str, isLong: bool, prevol: float, preavail: float, newvol: float, + newavail: float): stdCode = bytes.decode(stdCode) engine = self._engine ctx = engine.get_context(id) ctx.on_position(stdCode, isLong, prevol, preavail, newvol, newavail) - def on_hftstra_order_queue(self, id:int, stdCode:str, newOrdQue:POINTER(WTSOrdQueStruct)): + def on_hftstra_order_queue(self, id: int, stdCode: str, newOrdQue: POINTER(WTSOrdQueStruct)): stdCode = bytes.decode(stdCode) engine = self._engine ctx = engine.get_context(id) @@ -251,7 +259,8 @@ def on_hftstra_order_queue(self, id:int, stdCode:str, newOrdQue:POINTER(WTSOrdQu if ctx is not None: ctx.on_order_queue(stdCode, newOrdQue) - def on_hftstra_get_order_queue(self, id:int, stdCode:str, newOrdQue:POINTER(WTSOrdQueStruct), count:int, isLast:bool): + def on_hftstra_get_order_queue(self, id: int, stdCode: str, newOrdQue: POINTER(WTSOrdQueStruct), count: int, + isLast: bool): engine = self._engine ctx = engine.get_context(id) npHftData = WtNpOrdQueues(forceCopy=False) @@ -260,47 +269,49 @@ def on_hftstra_get_order_queue(self, id:int, stdCode:str, newOrdQue:POINTER(WTSO if ctx is not None: ctx.on_get_order_queue(bytes.decode(stdCode), npHftData) - def on_hftstra_order_detail(self, id:int, stdCode:str, newOrdDtl:POINTER(WTSOrdDtlStruct)): + def on_hftstra_order_detail(self, id: int, stdCode: str, newOrdDtl: POINTER(WTSOrdDtlStruct)): engine = self._engine ctx = engine.get_context(id) - + if ctx is not None: ctx.on_order_detail(stdCode, newOrdDtl) - def on_hftstra_get_order_detail(self, id:int, stdCode:str, newOrdDtl:POINTER(WTSOrdDtlStruct), count:int, isLast:bool): + def on_hftstra_get_order_detail(self, id: int, stdCode: str, newOrdDtl: POINTER(WTSOrdDtlStruct), count: int, + isLast: bool): engine = self._engine ctx = engine.get_context(id) - + npHftData = WtNpOrdDetails(forceCopy=False) npHftData.set_data(newOrdDtl, count) - + if ctx is not None: ctx.on_get_order_detail(bytes.decode(stdCode), npHftData) - def on_hftstra_transaction(self, id:int, stdCode:str, newTrans:POINTER(WTSTransStruct)): + def on_hftstra_transaction(self, id: int, stdCode: str, newTrans: POINTER(WTSTransStruct)): engine = self._engine ctx = engine.get_context(id) - + if ctx is not None: ctx.on_transaction(stdCode, newTrans) - - def on_hftstra_get_transaction(self, d:int, stdCode:str, newTrans:POINTER(WTSTransStruct), count:int, isLast:bool): + + def on_hftstra_get_transaction(self, d: int, stdCode: str, newTrans: POINTER(WTSTransStruct), count: int, + isLast: bool): engine = self._engine ctx = engine.get_context(id) - + npHftData = WtNpTransactions(forceCopy=False) npHftData.set_data(newTrans, count) - + if ctx is not None: ctx.on_get_transaction(bytes.decode(stdCode), npHftData) - def on_parser_event(self, evtId:int, id:str): + def on_parser_event(self, evtId: int, id: str): id = bytes.decode(id) engine = self._engine parser = engine.get_extended_parser(id) if parser is None: return - + if evtId == EVENT_PARSER_INIT: parser.init(engine) elif evtId == EVENT_PARSER_CONNECT: @@ -310,7 +321,7 @@ def on_parser_event(self, evtId:int, id:str): elif evtId == EVENT_PARSER_RELEASE: parser.release() - def on_parser_sub(self, id:str, fullCode:str, isForSub:bool): + def on_parser_sub(self, id: str, fullCode: str, isForSub: bool): id = bytes.decode(id) engine = self._engine parser = engine.get_extended_parser(id) @@ -322,7 +333,7 @@ def on_parser_sub(self, id:str, fullCode:str, isForSub:bool): else: parser.unsubscribe(fullCode) - def on_executer_init(self, id:str): + def on_executer_init(self, id: str): id = bytes.decode(id) engine = self._engine executer = engine.get_extended_executer(id) @@ -331,7 +342,7 @@ def on_executer_init(self, id:str): executer.init() - def on_executer_cmd(self, id:str, stdCode:str, targetPos:float): + def on_executer_cmd(self, id: str, stdCode: str, targetPos: float): id = bytes.decode(id) engine = self._engine executer = engine.get_extended_executer(id) @@ -340,7 +351,7 @@ def on_executer_cmd(self, id:str, stdCode:str, targetPos:float): executer.set_position(bytes.decode(stdCode), targetPos) - def on_load_fnl_his_bars(self, stdCode:str, period:str): + def on_load_fnl_his_bars(self, stdCode: str, period: str): engine = self._engine loader = engine.get_extended_data_loader() if loader is None: @@ -349,7 +360,7 @@ def on_load_fnl_his_bars(self, stdCode:str, period:str): # feed_raw_bars(WTSBarStruct* bars, WtUInt32 count); loader.load_final_his_bars(bytes.decode(stdCode), bytes.decode(period), self.api.feed_raw_bars) - def on_load_raw_his_bars(self, stdCode:str, period:str): + def on_load_raw_his_bars(self, stdCode: str, period: str): engine = self._engine loader = engine.get_extended_data_loader() if loader is None: @@ -358,15 +369,15 @@ def on_load_raw_his_bars(self, stdCode:str, period:str): # feed_raw_bars(WTSBarStruct* bars, WtUInt32 count); loader.load_raw_his_bars(bytes.decode(stdCode), bytes.decode(period), self.api.feed_raw_bars) - def feed_adj_factors(self, stdCode:str, dates:list, factors:list): + def feed_adj_factors(self, stdCode: str, dates: list, factors: list): stdCode = bytes(stdCode, encoding="utf8") - ''' + """ TODO 这里类型要转一下! 底层接口是传数组的 feed_adj_factors(WtString stdCode, WtUInt32* dates, double* factors, WtUInt32 count) - ''' + """ self.api.feed_adj_factors(stdCode, dates, factors, len(dates)) - def on_load_adj_factors(self, stdCode:str) -> bool: + def on_load_adj_factors(self, stdCode: str) -> bool: engine = self._engine loader = engine.get_extended_data_loader() if loader is None: @@ -375,7 +386,7 @@ def on_load_adj_factors(self, stdCode:str) -> bool: stdCode = bytes.decode(stdCode) return loader.load_adj_factors(stdCode, self.feed_adj_factors) - def on_load_his_ticks(self, stdCode:str, uDate:int): + def on_load_his_ticks(self, stdCode: str, uDate: int): engine = self._engine loader = engine.get_extended_data_loader() if loader is None: @@ -384,32 +395,32 @@ def on_load_his_ticks(self, stdCode:str, uDate:int): # feed_raw_ticks(WTSTickStruct* ticks, WtUInt32 count); loader.load_his_ticks(bytes.decode(stdCode), uDate, self.api.feed_raw_ticks) - def write_log(self, level, message:str, catName:str = ""): - self.api.write_log(level, bytes(message, encoding = "utf8"), bytes(catName, encoding = "utf8")) + def write_log(self, level, message: str, catName: str = ""): + self.api.write_log(level, bytes(message, encoding="utf8"), bytes(catName, encoding="utf8")) - ### 实盘和回测有差异 ### - def run(self, bAsync:bool = True): + # ## 实盘和回测有差异 ### + def run(self, bAsync: bool = True): self.api.run_porter(bAsync) def release(self): self.api.release_porter() - def config(self, cfgfile:str = 'config.yaml', isFile:bool = True): - self.api.config_porter(bytes(cfgfile, encoding = "utf8"), isFile) + def config(self, cfgfile: str = 'config.yaml', isFile: bool = True): + self.api.config_porter(bytes(cfgfile, encoding="utf8"), isFile) - def get_raw_stdcode(self, stdCode:str): - return bytes.decode(self.api.get_raw_stdcode(bytes(stdCode, encoding = "utf8"))) + def get_raw_stdcode(self, stdCode: str): + return bytes.decode(self.api.get_raw_stdcode(bytes(stdCode, encoding="utf8"))) - def create_extended_parser(self, id:str) -> bool: - return self.api.create_ext_parser(bytes(id, encoding = "utf8")) + def create_extended_parser(self, id: str) -> bool: + return self.api.create_ext_parser(bytes(id, encoding="utf8")) - def create_extended_executer(self, id:str) -> bool: - return self.api.create_ext_executer(bytes(id, encoding = "utf8")) + def create_extended_executer(self, id: str) -> bool: + return self.api.create_ext_executer(bytes(id, encoding="utf8")) - def push_quote_from_exetended_parser(self, id:str, newTick:POINTER(WTSTickStruct), uProcFlag:int = 1): - return self.api.parser_push_quote(bytes(id, encoding = "utf8"), newTick, uProcFlag) + def push_quote_from_exetended_parser(self, id: str, newTick: POINTER(WTSTickStruct), uProcFlag: int = 1): + return self.api.parser_push_quote(bytes(id, encoding="utf8"), newTick, uProcFlag) - def register_extended_module_callbacks(self,): + def register_extended_module_callbacks(self, ): self.cb_parser_event = CB_PARSER_EVENT(self.on_parser_event) self.cb_parser_subcmd = CB_PARSER_SUBCMD(self.on_parser_sub) self.cb_executer_init = CB_EXECUTER_INIT(self.on_executer_init) @@ -419,21 +430,22 @@ def register_extended_module_callbacks(self,): self.api.register_exec_callbacks(self.cb_executer_init, self.cb_executer_cmd) def register_extended_data_loader(self): - ''' + """ 注册扩展历史数据加载器 - ''' + """ self.cb_load_fnlbars = FUNC_LOAD_HISBARS(self.on_load_fnl_his_bars) self.cb_load_rawbars = FUNC_LOAD_HISBARS(self.on_load_raw_his_bars) self.cb_load_histicks = FUNC_LOAD_HISTICKS(self.on_load_his_ticks) self.cb_load_adjfacts = FUNC_LOAD_ADJFACTS(self.on_load_adj_factors) - self.api.register_ext_data_loader(self.cb_load_fnlbars, self.cb_load_rawbars, self.cb_load_adjfacts, self.cb_load_histicks) + self.api.register_ext_data_loader(self.cb_load_fnlbars, self.cb_load_rawbars, self.cb_load_adjfacts, + self.cb_load_histicks) - ### 实盘和回测有差异 ### - def initialize_cta(self, logCfg:str = "logcfg.yaml", isFile:bool = True, genDir:str = 'generated'): - ''' + # ## 实盘和回测有差异 ### + def initialize_cta(self, logCfg: str = "logcfg.yaml", isFile: bool = True, genDir: str = 'generated'): + """ C接口初始化 - ''' + """ self.cb_stra_init = CB_STRATEGY_INIT(self.on_stra_init) self.cb_stra_tick = CB_STRATEGY_TICK(self.on_stra_tick) self.cb_stra_calc = CB_STRATEGY_CALC(self.on_stra_calc) @@ -444,18 +456,19 @@ def initialize_cta(self, logCfg:str = "logcfg.yaml", isFile:bool = True, genDir: self.cb_engine_event = CB_ENGINE_EVENT(self.on_engine_event) try: self.api.register_evt_callback(self.cb_engine_event) - self.api.register_cta_callbacks(self.cb_stra_init, self.cb_stra_tick, self.cb_stra_calc, self.cb_stra_bar, self.cb_session_event, self.cb_stra_cond_trigger) - self.api.init_porter(bytes(logCfg, encoding = "utf8"), isFile, bytes(genDir, encoding = "utf8")) + self.api.register_cta_callbacks(self.cb_stra_init, self.cb_stra_tick, self.cb_stra_calc, self.cb_stra_bar, + self.cb_session_event, self.cb_stra_cond_trigger) + self.api.init_porter(bytes(logCfg, encoding="utf8"), isFile, bytes(genDir, encoding="utf8")) self.register_extended_module_callbacks() except OSError as oe: print(oe) - self.write_log(102, "WonderTrader CTA production framework initialzied, version: %s" % (self.ver)) + self.write_log(102, "WonderTrader CTA production framework initialzied, version: %s" % self.ver) - def initialize_hft(self, logCfg:str = "logcfg.yaml", isFile:bool = True, genDir:str = 'generated'): - ''' + def initialize_hft(self, logCfg: str = "logcfg.yaml", isFile: bool = True, genDir: str = 'generated'): + """ C接口初始化 - ''' + """ self.cb_stra_init = CB_STRATEGY_INIT(self.on_stra_init) self.cb_stra_tick = CB_STRATEGY_TICK(self.on_stra_tick) self.cb_stra_bar = CB_STRATEGY_BAR(self.on_stra_bar) @@ -473,19 +486,21 @@ def initialize_hft(self, logCfg:str = "logcfg.yaml", isFile:bool = True, genDir: self.cb_engine_event = CB_ENGINE_EVENT(self.on_engine_event) try: self.api.register_evt_callback(self.cb_engine_event) - self.api.register_hft_callbacks(self.cb_stra_init, self.cb_stra_tick, self.cb_stra_bar, - self.cb_hftstra_chnl_evt, self.cb_hftstra_order, self.cb_hftstra_trade, self.cb_hftstra_entrust, - self.cb_hftstra_orddtl, self.cb_hftstra_ordque, self.cb_hftstra_trans, self.cb_session_event, self.cb_hftstra_position) - self.api.init_porter(bytes(logCfg, encoding = "utf8"), isFile, bytes(genDir, encoding = "utf8")) + self.api.register_hft_callbacks(self.cb_stra_init, self.cb_stra_tick, self.cb_stra_bar, + self.cb_hftstra_chnl_evt, self.cb_hftstra_order, self.cb_hftstra_trade, + self.cb_hftstra_entrust, + self.cb_hftstra_orddtl, self.cb_hftstra_ordque, self.cb_hftstra_trans, + self.cb_session_event, self.cb_hftstra_position) + self.api.init_porter(bytes(logCfg, encoding="utf8"), isFile, bytes(genDir, encoding="utf8")) except OSError as oe: print(oe) - self.write_log(102, "WonderTrader HFT production framework initialzied, version: %s" % (self.ver)) + self.write_log(102, "WonderTrader HFT production framework initialzied, version: %s" % self.ver) - def initialize_sel(self, logCfg:str = "logcfg.yaml", isFile:bool = True, genDir:str = 'generated'): - ''' + def initialize_sel(self, logCfg: str = "logcfg.yaml", isFile: bool = True, genDir: str = 'generated'): + """ C接口初始化 - ''' + """ self.cb_stra_init = CB_STRATEGY_INIT(self.on_stra_init) self.cb_stra_tick = CB_STRATEGY_TICK(self.on_stra_tick) self.cb_stra_calc = CB_STRATEGY_CALC(self.on_stra_calc) @@ -496,788 +511,818 @@ def initialize_sel(self, logCfg:str = "logcfg.yaml", isFile:bool = True, genDir: try: self.api.register_evt_callback(self.cb_engine_event) - self.api.register_sel_callbacks(self.cb_stra_init, self.cb_stra_tick, self.cb_stra_calc, self.cb_stra_bar, self.cb_session_event) - self.api.init_porter(bytes(logCfg, encoding = "utf8"), isFile, bytes(genDir, encoding = "utf8")) + self.api.register_sel_callbacks(self.cb_stra_init, self.cb_stra_tick, self.cb_stra_calc, self.cb_stra_bar, + self.cb_session_event) + self.api.init_porter(bytes(logCfg, encoding="utf8"), isFile, bytes(genDir, encoding="utf8")) self.register_extended_module_callbacks() except OSError as oe: print(oe) - self.write_log(102, "WonderTrader SEL production framework initialzied, version: %s" % (self.ver)) + self.write_log(102, "WonderTrader SEL production framework initialzied, version: %s" % self.ver) - def cta_enter_long(self, id:int, stdCode:str, qty:float, usertag:str, limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_enter_long(self, id: int, stdCode: str, qty: float, usertag: str, limitprice: float = 0.0, + stopprice: float = 0.0): + """ 开多 @id 策略id @stdCode 合约代码 @qty 手数, 大于等于0 - ''' - self.api.cta_enter_long(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) + """ + self.api.cta_enter_long(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) - def cta_exit_long(self, id:int, stdCode:str, qty:float, usertag:str, limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_exit_long(self, id: int, stdCode: str, qty: float, usertag: str, limitprice: float = 0.0, + stopprice: float = 0.0): + """ 平多 @id 策略id @stdCode 合约代码 @qty 手数, 大于等于0 - ''' - self.api.cta_exit_long(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) + """ + self.api.cta_exit_long(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) - def cta_enter_short(self, id:int, stdCode:str, qty:float, usertag:str, limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_enter_short(self, id: int, stdCode: str, qty: float, usertag: str, limitprice: float = 0.0, + stopprice: float = 0.0): + """ 开空 @id 策略id @stdCode 合约代码 @qty 手数, 大于等于0 - ''' - self.api.cta_enter_short(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) + """ + self.api.cta_enter_short(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) - def cta_exit_short(self, id:int, stdCode:str, qty:float, usertag:str, limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_exit_short(self, id: int, stdCode: str, qty: float, usertag: str, limitprice: float = 0.0, + stopprice: float = 0.0): + """ 平空 @id 策略id @stdCode 合约代码 @qty 手数, 大于等于0 - ''' - self.api.cta_exit_short(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) - - def cta_get_bars(self, id:int, stdCode:str, period:str, count:int, isMain:bool): - ''' + """ + self.api.cta_exit_short(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) + + def cta_get_bars(self, id: int, stdCode: str, period: str, count: int, isMain: bool): + """ 读取K线 @id 策略id @stdCode 合约代码 @period 周期, 如m1/m3/d1等 @count 条数 @isMain 是否主K线 - ''' - return self.api.cta_get_bars(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8"), count, isMain, CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) - - def cta_get_ticks(self, id:int, stdCode:str, count:int): - ''' + """ + return self.api.cta_get_bars(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8"), count, isMain, + CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) + + def cta_get_ticks(self, id: int, stdCode: str, count: int): + """ 读取Tick @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.cta_get_ticks(id, bytes(stdCode, encoding = "utf8"), count, CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) + """ + return self.api.cta_get_ticks(id, bytes(stdCode, encoding="utf8"), count, + CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) - def cta_get_position_profit(self, id:int, stdCode:str): - ''' + def cta_get_position_profit(self, id: int, stdCode: str): + """ 获取浮动盈亏 @id 策略id @stdCode 合约代码 @return 指定合约的浮动盈亏 - ''' - return self.api.cta_get_position_profit(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_position_profit(id, bytes(stdCode, encoding="utf8")) - def cta_get_position_avgpx(self, id:int, stdCode:str): - ''' + def cta_get_position_avgpx(self, id: int, stdCode: str): + """ 获取持仓均价 @id 策略id @stdCode 合约代码 @return 指定合约的持仓均价 - ''' - return self.api.cta_get_position_avgpx(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_position_avgpx(id, bytes(stdCode, encoding="utf8")) - def cta_get_all_position(self, id:int): - ''' + def cta_get_all_position(self, id: int): + """ 获取全部持仓 @id 策略id - ''' + """ return self.api.cta_get_all_position(id, CB_STRATEGY_GET_POSITION(self.on_stra_get_position)) - - def cta_get_position(self, id:int, stdCode:str, bonlyvalid:bool = False, usertag:str = ""): - ''' + + def cta_get_position(self, id: int, stdCode: str, bonlyvalid: bool = False, usertag: str = ""): + """ 获取持仓 @id 策略id @stdCode 合约代码 @bonlyvalid 只读可用持仓, 默认为False @usertag 进场标记, 如果为空则获取该合约全部持仓 @return 指定合约的持仓手数, 正为多, 负为空 - ''' - return self.api.cta_get_position(id, bytes(stdCode, encoding = "utf8"), bonlyvalid, bytes(usertag, encoding = "utf8")) + """ + return self.api.cta_get_position(id, bytes(stdCode, encoding="utf8"), bonlyvalid, + bytes(usertag, encoding="utf8")) - def cta_get_fund_data(self, id:int, flag:int) -> float: - ''' + def cta_get_fund_data(self, id: int, flag: int) -> float: + """ 获取资金数据 @id 策略id @flag 0-动态权益, 1-总平仓盈亏, 2-总浮动盈亏, 3-总手续费 @return 资金数据 - ''' + """ return self.api.cta_get_fund_data(id, flag) - def cta_get_price(self, stdCode:str) -> float: - ''' + def cta_get_price(self, stdCode: str) -> float: + """ 获取最新价格 @stdCode 合约代码 @return 指定合约的最新价格 - ''' - return self.api.cta_get_price(bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_price(bytes(stdCode, encoding="utf8")) - def cta_get_day_price(self, stdCode:str, flag:int = 0) -> float: - ''' + def cta_get_day_price(self, stdCode: str, flag: int = 0) -> float: + """ 获取当日价格 @stdCode 合约代码 @flag 价格标记, 0-开盘价, 1-最高价, 2-最低价, 3-最新价 @return 指定合约的价格 - ''' - return self.api.cta_get_day_price(bytes(stdCode, encoding = "utf8"), flag) + """ + return self.api.cta_get_day_price(bytes(stdCode, encoding="utf8"), flag) - def cta_set_position(self, id:int, stdCode:str, qty:float, usertag:str = "", limitprice:float = 0.0, stopprice:float = 0.0): - ''' + def cta_set_position(self, id: int, stdCode: str, qty: float, usertag: str = "", limitprice: float = 0.0, + stopprice: float = 0.0): + """ 设置目标仓位 @id 策略id @stdCode 合约代码 @qty 目标仓位, 正为多, 负为空 - ''' - self.api.cta_set_position(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8"), limitprice, stopprice) + """ + self.api.cta_set_position(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8"), limitprice, + stopprice) def cta_get_tdate(self) -> int: - ''' + """ 获取当前交易日 @return 当前交易日 - ''' + """ return self.api.cta_get_tdate() def cta_get_date(self) -> int: - ''' + """ 获取当前日期 @return 当前日期 - ''' + """ return self.api.cta_get_date() def cta_get_time(self) -> int: - ''' + """ 获取当前时间 @return 当前时间 - ''' + """ return self.api.cta_get_time() - def cta_get_first_entertime(self, id:int, stdCode:str) -> int: - ''' + def cta_get_first_entertime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的首次进场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.cta_get_first_entertime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_first_entertime(id, bytes(stdCode, encoding="utf8")) - def cta_get_last_entertag(self, id:int, stdCode:str) -> str: - ''' + def cta_get_last_entertag(self, id: int, stdCode: str) -> str: + """ 获取当前持仓的最后进场标记 @stdCode 合约代码 @return 进场标记 - ''' - return bytes.decode(self.api.cta_get_last_entertag(id, bytes(stdCode, encoding = "utf8"))) + """ + return bytes.decode(self.api.cta_get_last_entertag(id, bytes(stdCode, encoding="utf8"))) - def cta_get_last_entertime(self, id:int, stdCode:str) -> int: - ''' + def cta_get_last_entertime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的最后进场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.cta_get_last_entertime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_last_entertime(id, bytes(stdCode, encoding="utf8")) - def cta_get_last_exittime(self, id:int, stdCode:str) -> int: - ''' + def cta_get_last_exittime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的最后出场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.cta_get_last_exittime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.cta_get_last_exittime(id, bytes(stdCode, encoding="utf8")) - def cta_log_text(self, id:int, level:int, message:str): - ''' + def cta_log_text(self, id: int, level: int, message: str): + """ 日志输出 @id 策略ID @level 日志级别 @message 日志内容 - ''' + """ self.api.cta_log_text(id, level, ph.auto_encode(message)) - def cta_get_detail_entertime(self, id:int, stdCode:str, usertag:str) -> int: - ''' + def cta_get_detail_entertime(self, id: int, stdCode: str, usertag: str) -> int: + """ 获取指定标记的持仓的进场时间 @id 策略id @stdCode 合约代码 @usertag 进场标记 @return 进场时间, 格式如201907260932 - ''' - return self.api.cta_get_detail_entertime(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8")) + """ + return self.api.cta_get_detail_entertime(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8")) - def cta_get_detail_cost(self, id:int, stdCode:str, usertag:str) -> float: - ''' + def cta_get_detail_cost(self, id: int, stdCode: str, usertag: str) -> float: + """ 获取指定标记的持仓的开仓价 @id 策略id @stdCode 合约代码 @usertag 进场标记 @return 开仓价 - ''' - return self.api.cta_get_detail_cost(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8")) + """ + return self.api.cta_get_detail_cost(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8")) - def cta_get_detail_profit(self, id:int, stdCode:str, usertag:str, flag:int): - ''' + def cta_get_detail_profit(self, id: int, stdCode: str, usertag: str, flag: int): + """ 获取指定标记的持仓的盈亏 @id 策略id @stdCode 合约代码 @usertag 进场标记 @flag 盈亏记号, 0-浮动盈亏, 1-最大浮盈, 2-最大亏损(负数) @return 盈亏 - ''' - return self.api.cta_get_detail_profit(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8"), flag) + """ + return self.api.cta_get_detail_profit(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8"), + flag) - def cta_save_user_data(self, id:int, key:str, val:str): - ''' + def cta_save_user_data(self, id: int, key: str, val: str): + """ 保存用户数据 @id 策略id @key 数据名 @val 数据值 - ''' - self.api.cta_save_userdata(id, bytes(key, encoding = "utf8"), bytes(val, encoding = "utf8")) + """ + self.api.cta_save_userdata(id, bytes(key, encoding="utf8"), bytes(val, encoding="utf8")) - def cta_load_user_data(self, id:int, key:str, defVal:str = ""): - ''' + def cta_load_user_data(self, id: int, key: str, defVal: str = ""): + """ 加载用户数据 @id 策略id @key 数据名 @defVal 默认值 - ''' - ret = self.api.cta_load_userdata(id, bytes(key, encoding = "utf8"), bytes(defVal, encoding = "utf8")) + """ + ret = self.api.cta_load_userdata(id, bytes(key, encoding="utf8"), bytes(defVal, encoding="utf8")) return bytes.decode(ret) - def cta_sub_ticks(self, id:int, stdCode:str): - ''' + def cta_sub_ticks(self, id: int, stdCode: str): + """ 订阅行情 @id 策略id @stdCode 品种代码 - ''' - self.api.cta_sub_ticks(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.cta_sub_ticks(id, bytes(stdCode, encoding="utf8")) - def cta_sub_bar_events(self, id:int, stdCode:str, period:str): - ''' + def cta_sub_bar_events(self, id: int, stdCode: str, period: str): + """ 订阅K线事件 @id 策略id @stdCode 品种代码 @period 周期 - ''' - self.api.cta_sub_bar_events(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8")) + """ + self.api.cta_sub_bar_events(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8")) - def cta_set_chart_kline(self, id:int, stdCode:str, period:str): - ''' + def cta_set_chart_kline(self, id: int, stdCode: str, period: str): + """ 设置图表K线 @stdCode 合约代码 @period K线周期 - ''' - self.api.cta_set_chart_kline(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8")) + """ + self.api.cta_set_chart_kline(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8")) - def cta_add_chart_mark(self, id:int, price:float, icon:str, tag:str = 'Notag'): - ''' + def cta_add_chart_mark(self, id: int, price: float, icon: str, tag: str = 'Notag'): + """ 添加图表标记 @price 价格, 决定图标出现的位置 @icon 图标, 系统一定的图标ID @tag 标签, 自定义的 - ''' - self.api.cta_add_chart_mark(id, price, bytes(icon, encoding = "utf8"), bytes(tag, encoding = "utf8")) + """ + self.api.cta_add_chart_mark(id, price, bytes(icon, encoding="utf8"), bytes(tag, encoding="utf8")) - def cta_register_index(self, id:int, idxName:str, idxType:int = 1): - ''' + def cta_register_index(self, id: int, idxName: str, idxType: int = 1): + """ 注册指标, on_init调用 @idxName 指标名 @idxType 指标类型, 0-主图指标, 1-副图指标 - ''' - self.api.cta_register_index(id, bytes(idxName, encoding = "utf8"), idxType) + """ + self.api.cta_register_index(id, bytes(idxName, encoding="utf8"), idxType) - def cta_register_index_line(self, id:int, idxName:str, lineName:str, lineType:int = 0) -> bool: - ''' + def cta_register_index_line(self, id: int, idxName: str, lineName: str, lineType: int = 0) -> bool: + """ 注册指标线, on_init调用 @idxName 指标名称 @lineName 线名称 @lineType 线型, 0-曲线, 1-柱子 - ''' - return self.api.cta_register_index_line(id, bytes(idxName, encoding = "utf8"), bytes(lineName, encoding = "utf8"), lineType) + """ + return self.api.cta_register_index_line(id, bytes(idxName, encoding="utf8"), bytes(lineName, encoding="utf8"), + lineType) - def cta_add_index_baseline(self, id:int, idxName:str, lineName:str, value:float) -> bool: - ''' + def cta_add_index_baseline(self, id: int, idxName: str, lineName: str, value: float) -> bool: + """ 添加基准线, on_init调用 @idxName 指标名称 @lineName 线名称 @value 数值 - ''' - return self.api.cta_add_index_baseline(id, bytes(idxName, encoding = "utf8"), bytes(lineName, encoding = "utf8"), value) + """ + return self.api.cta_add_index_baseline(id, bytes(idxName, encoding="utf8"), bytes(lineName, encoding="utf8"), + value) - def cta_set_index_value(self, id:int, idxName:str, lineName:str, val:float) -> bool: - ''' + def cta_set_index_value(self, id: int, idxName: str, lineName: str, val: float) -> bool: + """ 设置指标值, 只有在oncalc的时候才生效 @idxName 指标名称 @lineName 线名称 - ''' - return self.api.cta_set_index_value(id, bytes(idxName, encoding = "utf8"), bytes(lineName, encoding = "utf8"), val) - - '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - '''SEL接口''' - def sel_get_bars(self, id:int, stdCode:str, period:str, count:int, isMain:bool): - ''' + """ + return self.api.cta_set_index_value(id, bytes(idxName, encoding="utf8"), bytes(lineName, encoding="utf8"), val) + + """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + """SEL接口""" + + def sel_get_bars(self, id: int, stdCode: str, period: str, count: int, isMain: bool): + """ 读取K线 @id 策略id @stdCode 合约代码 @period 周期, 如m1/m3/d1等 @count 条数 @isMain 是否主K线 - ''' - return self.api.sel_get_bars(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8"), count, isMain, CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) - - def sel_get_ticks(self, id:int, stdCode:str, count:int): - ''' + """ + return self.api.sel_get_bars(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8"), count, isMain, + CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) + + def sel_get_ticks(self, id: int, stdCode: str, count: int): + """ 读取Tick @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.sel_get_ticks(id, bytes(stdCode, encoding = "utf8"), count, CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) + """ + return self.api.sel_get_ticks(id, bytes(stdCode, encoding="utf8"), count, + CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) - def sel_save_user_data(self, id:int, key:str, val:str): - ''' + def sel_save_user_data(self, id: int, key: str, val: str): + """ 保存用户数据 @id 策略id @key 数据名 @val 数据值 - ''' - self.api.sel_save_userdata(id, bytes(key, encoding = "utf8"), bytes(val, encoding = "utf8")) + """ + self.api.sel_save_userdata(id, bytes(key, encoding="utf8"), bytes(val, encoding="utf8")) - def sel_load_user_data(self, id:int, key:str, defVal:str = ""): - ''' + def sel_load_user_data(self, id: int, key: str, defVal: str = ""): + """ 加载用户数据 @id 策略id @key 数据名 @defVal 默认值 - ''' - ret = self.api.sel_load_userdata(id, bytes(key, encoding = "utf8"), bytes(defVal, encoding = "utf8")) + """ + ret = self.api.sel_load_userdata(id, bytes(key, encoding="utf8"), bytes(defVal, encoding="utf8")) return bytes.decode(ret) - def sel_get_all_position(self, id:int): - ''' + def sel_get_all_position(self, id: int): + """ 获取全部持仓 @id 策略id - ''' + """ return self.api.sel_get_all_position(id, CB_STRATEGY_GET_POSITION(self.on_stra_get_position)) - def sel_get_position(self, id:int, stdCode:str, bonlyvalid:bool = False, usertag:str = ""): - ''' + def sel_get_position(self, id: int, stdCode: str, bonlyvalid: bool = False, usertag: str = ""): + """ 获取持仓 @id 策略id @stdCode 合约代码 @usertag 进场标记, 如果为空则获取该合约全部持仓 @return 指定合约的持仓手数, 正为多, 负为空 - ''' - return self.api.sel_get_position(id, bytes(stdCode, encoding = "utf8"), bonlyvalid, bytes(usertag, encoding = "utf8")) + """ + return self.api.sel_get_position(id, bytes(stdCode, encoding="utf8"), bonlyvalid, + bytes(usertag, encoding="utf8")) - def sel_get_price(self, stdCode:str): - ''' + def sel_get_price(self, stdCode: str): + """ @stdCode 合约代码 @return 指定合约的最新价格 - ''' - return self.api.sel_get_price(bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_price(bytes(stdCode, encoding="utf8")) - def sel_set_position(self, id:int, stdCode:str, qty:float, usertag:str = ""): - ''' + def sel_set_position(self, id: int, stdCode: str, qty: float, usertag: str = ""): + """ 设置目标仓位 @id 策略id @stdCode 合约代码 @qty 目标仓位, 正为多, 负为空 - ''' - self.api.sel_set_position(id, bytes(stdCode, encoding = "utf8"), qty, bytes(usertag, encoding = "utf8")) + """ + self.api.sel_set_position(id, bytes(stdCode, encoding="utf8"), qty, bytes(usertag, encoding="utf8")) def sel_get_tdate(self) -> int: - ''' + """ 获取当前交易日 @return 当前交易日 - ''' + """ return self.api.sel_get_tdate() - + def sel_get_date(self): - ''' + """ 获取当前日期 @return 当前日期 - ''' + """ return self.api.sel_get_date() def sel_get_time(self): - ''' + """ 获取当前时间 @return 当前时间 - ''' + """ return self.api.sel_get_time() - def sel_log_text(self, id:int, level:int, message:str): - ''' + def sel_log_text(self, id: int, level: int, message: str): + """ 日志输出 @id 策略ID @level 日志级别 @message 日志内容 - ''' + """ self.api.sel_log_text(id, level, ph.auto_encode(message)) - def sel_sub_ticks(self, id:int, stdCode:str): - ''' + def sel_sub_ticks(self, id: int, stdCode: str): + """ 订阅行情 @id 策略id @stdCode 品种代码 - ''' - self.api.sel_sub_ticks(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.sel_sub_ticks(id, bytes(stdCode, encoding="utf8")) - def sel_get_day_price(self, stdCode:str, flag:int = 0) -> float: - ''' + def sel_get_day_price(self, stdCode: str, flag: int = 0) -> float: + """ 获取当日价格 @stdCode 合约代码 @flag 价格标记, 0-开盘价, 1-最高价, 2-最低价, 3-最新价 @return 指定合约的价格 - ''' - return self.api.sel_get_day_price(bytes(stdCode, encoding = "utf8"), flag) + """ + return self.api.sel_get_day_price(bytes(stdCode, encoding="utf8"), flag) - def sel_get_fund_data(self, id:int, flag:int) -> float: - ''' + def sel_get_fund_data(self, id: int, flag: int) -> float: + """ 获取资金数据 @id 策略id @flag 0-动态权益, 1-总平仓盈亏, 2-总浮动盈亏, 3-总手续费 @return 资金数据 - ''' + """ return self.api.sel_get_fund_data(id, flag) - def sel_get_position_profit(self, id:int, stdCode:str): - ''' + def sel_get_position_profit(self, id: int, stdCode: str): + """ 获取浮动盈亏 @id 策略id @stdCode 合约代码 @return 指定合约的浮动盈亏 - ''' - return self.api.sel_get_position_profit(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_position_profit(id, bytes(stdCode, encoding="utf8")) - def sel_get_position_avgpx(self, id:int, stdCode:str): - ''' + def sel_get_position_avgpx(self, id: int, stdCode: str): + """ 获取持仓均价 @id 策略id @stdCode 合约代码 @return 指定合约的持仓均价 - ''' - return self.api.sel_get_position_avgpx(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_position_avgpx(id, bytes(stdCode, encoding="utf8")) - def sel_get_first_entertime(self, id:int, stdCode:str) -> int: - ''' + def sel_get_first_entertime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的首次进场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.sel_get_first_entertime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_first_entertime(id, bytes(stdCode, encoding="utf8")) - def sel_get_last_entertime(self, id:int, stdCode:str) -> int: - ''' + def sel_get_last_entertime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的最后进场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.sel_get_last_entertime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_last_entertime(id, bytes(stdCode, encoding="utf8")) - def sel_get_last_entertag(self, id:int, stdCode:str) -> str: - ''' + def sel_get_last_entertag(self, id: int, stdCode: str) -> str: + """ 获取当前持仓的最后进场标记 @stdCode 合约代码 @return 进场标记 - ''' - return bytes.decode(self.api.sel_get_last_entertag(id, bytes(stdCode, encoding = "utf8"))) + """ + return bytes.decode(self.api.sel_get_last_entertag(id, bytes(stdCode, encoding="utf8"))) - def sel_get_last_exittime(self, id:int, stdCode:str) -> int: - ''' + def sel_get_last_exittime(self, id: int, stdCode: str) -> int: + """ 获取当前持仓的最后出场时间 @stdCode 合约代码 @return 进场时间, 格式如201907260932 - ''' - return self.api.sel_get_last_exittime(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.sel_get_last_exittime(id, bytes(stdCode, encoding="utf8")) - def sel_get_detail_entertime(self, id:int, stdCode:str, usertag:str) -> int: - ''' + def sel_get_detail_entertime(self, id: int, stdCode: str, usertag: str) -> int: + """ 获取指定标记的持仓的进场时间 @id 策略id @stdCode 合约代码 @usertag 进场标记 @return 进场时间, 格式如201907260932 - ''' - return self.api.sel_get_detail_entertime(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8")) + """ + return self.api.sel_get_detail_entertime(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8")) - def sel_get_detail_cost(self, id:int, stdCode:str, usertag:str) -> float: - ''' + def sel_get_detail_cost(self, id: int, stdCode: str, usertag: str) -> float: + """ 获取指定标记的持仓的开仓价 @id 策略id @stdCode 合约代码 @usertag 进场标记 @return 开仓价 - ''' - return self.api.sel_get_detail_cost(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8")) + """ + return self.api.sel_get_detail_cost(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8")) - def sel_get_detail_profit(self, id:int, stdCode:str, usertag:str, flag:int): - ''' + def sel_get_detail_profit(self, id: int, stdCode: str, usertag: str, flag: int): + """ 获取指定标记的持仓的盈亏 @id 策略id @stdCode 合约代码 @usertag 进场标记 @flag 盈亏记号, 0-浮动盈亏, 1-最大浮盈, -1-最大亏损(负数), 2-最大浮盈价格, -2-最大浮亏价格 @return 盈亏 - ''' - return self.api.sel_get_detail_profit(id, bytes(stdCode, encoding = "utf8"), bytes(usertag, encoding = "utf8"), flag) + """ + return self.api.sel_get_detail_profit(id, bytes(stdCode, encoding="utf8"), bytes(usertag, encoding="utf8"), + flag) + + """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + """HFT接口""" - '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - '''HFT接口''' - def hft_get_bars(self, id:int, stdCode:str, period:str, count:int): - ''' + def hft_get_bars(self, id: int, stdCode: str, period: str, count: int): + """ 读取K线 @id 策略id @stdCode 合约代码 @period 周期, 如m1/m3/d1等 @count 条数 - ''' - return self.api.hft_get_bars(id, bytes(stdCode, encoding = "utf8"), bytes(period, encoding = "utf8"), count, CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) - - def hft_get_ticks(self, id:int, stdCode:str, count:int): - ''' + """ + return self.api.hft_get_bars(id, bytes(stdCode, encoding="utf8"), bytes(period, encoding="utf8"), count, + CB_STRATEGY_GET_BAR(self.on_stra_get_bar)) + + def hft_get_ticks(self, id: int, stdCode: str, count: int): + """ 读取Tick @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.hft_get_ticks(id, bytes(stdCode, encoding = "utf8"), count, CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) + """ + return self.api.hft_get_ticks(id, bytes(stdCode, encoding="utf8"), count, + CB_STRATEGY_GET_TICK(self.on_stra_get_tick)) - def hft_get_ordque(self, id:int, stdCode:str, count:int): - ''' + def hft_get_ordque(self, id: int, stdCode: str, count: int): + """ 读取委托队列 @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.hft_get_ordque(id, bytes(stdCode, encoding = "utf8"), count, CB_HFTSTRA_GET_ORDQUE(self.on_hftstra_get_order_queue)) + """ + return self.api.hft_get_ordque(id, bytes(stdCode, encoding="utf8"), count, + CB_HFTSTRA_GET_ORDQUE(self.on_hftstra_get_order_queue)) - def hft_get_orddtl(self, id:int, stdCode:str, count:int): - ''' + def hft_get_orddtl(self, id: int, stdCode: str, count: int): + """ 读取逐笔委托 @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.hft_get_orddtl(id, bytes(stdCode, encoding = "utf8"), count, CB_HFTSTRA_GET_ORDDTL(self.on_hftstra_get_order_detail)) + """ + return self.api.hft_get_orddtl(id, bytes(stdCode, encoding="utf8"), count, + CB_HFTSTRA_GET_ORDDTL(self.on_hftstra_get_order_detail)) - def hft_get_trans(self, id:int, stdCode:str, count:int): - ''' + def hft_get_trans(self, id: int, stdCode: str, count: int): + """ 读取逐笔成交 @id 策略id @stdCode 合约代码 @count 条数 - ''' - return self.api.hft_get_trans(id, bytes(stdCode, encoding = "utf8"), count, CB_HFTSTRA_GET_TRANS(self.on_hftstra_get_transaction)) + """ + return self.api.hft_get_trans(id, bytes(stdCode, encoding="utf8"), count, + CB_HFTSTRA_GET_TRANS(self.on_hftstra_get_transaction)) - def hft_save_user_data(self, id:int, key:str, val:str): - ''' + def hft_save_user_data(self, id: int, key: str, val: str): + """ 保存用户数据 @id 策略id @key 数据名 @val 数据值 - ''' - self.api.hft_save_userdata(id, bytes(key, encoding = "utf8"), bytes(val, encoding = "utf8")) + """ + self.api.hft_save_userdata(id, bytes(key, encoding="utf8"), bytes(val, encoding="utf8")) - def hft_load_user_data(self, id:int, key:str, defVal:str = ""): - ''' + def hft_load_user_data(self, id: int, key: str, defVal: str = ""): + """ 加载用户数据 @id 策略id @key 数据名 @defVal 默认值 - ''' - ret = self.api.hft_load_userdata(id, bytes(key, encoding = "utf8"), bytes(defVal, encoding = "utf8")) + """ + ret = self.api.hft_load_userdata(id, bytes(key, encoding="utf8"), bytes(defVal, encoding="utf8")) return bytes.decode(ret) - def hft_get_position(self, id:int, stdCode:str, bonlyvalid:bool = False): - ''' + def hft_get_position(self, id: int, stdCode: str, bonlyvalid: bool = False): + """ 获取持仓 @id 策略id @stdCode 合约代码 @return 指定合约的持仓手数, 正为多, 负为空 - ''' - return self.api.hft_get_position(id, bytes(stdCode, encoding = "utf8"), bonlyvalid) + """ + return self.api.hft_get_position(id, bytes(stdCode, encoding="utf8"), bonlyvalid) - def hft_get_position_profit(self, id:int, stdCode:str): - ''' + def hft_get_position_profit(self, id: int, stdCode: str): + """ 获取持仓盈亏 @id 策略id @stdCode 合约代码 @return 指定持仓的浮动盈亏 - ''' - return self.api.hft_get_position_profit(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.hft_get_position_profit(id, bytes(stdCode, encoding="utf8")) - def hft_get_position_avgpx(self, id:int, stdCode:str): - ''' + def hft_get_position_avgpx(self, id: int, stdCode: str): + """ 获取持仓均价 @id 策略id @stdCode 合约代码 @return 指定持仓的浮动盈亏 - ''' - return self.api.hft_get_position_avgpx(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.hft_get_position_avgpx(id, bytes(stdCode, encoding="utf8")) - def hft_get_undone(self, id:int, stdCode:str): - ''' + def hft_get_undone(self, id: int, stdCode: str): + """ 获取持仓 @id 策略id @stdCode 合约代码 @return 指定合约的持仓手数, 正为多, 负为空 - ''' - return self.api.hft_get_undone(id, bytes(stdCode, encoding = "utf8")) + """ + return self.api.hft_get_undone(id, bytes(stdCode, encoding="utf8")) - def hft_get_price(self, stdCode:str): - ''' + def hft_get_price(self, stdCode: str): + """ @stdCode 合约代码 @return 指定合约的最新价格 - ''' - return self.api.hft_get_price(bytes(stdCode, encoding = "utf8")) + """ + return self.api.hft_get_price(bytes(stdCode, encoding="utf8")) def hft_get_date(self): - ''' + """ 获取当前日期 @return 当前日期 - ''' + """ return self.api.hft_get_date() def hft_get_time(self): - ''' + """ 获取当前时间 @return 当前时间 - ''' + """ return self.api.hft_get_time() def hft_get_secs(self): - ''' + """ 获取当前时间 @return 当前时间 - ''' + """ return self.api.hft_get_secs() - def hft_log_text(self, id:int, level:int, message:str): - ''' + def hft_log_text(self, id: int, level: int, message: str): + """ 日志输出 @id 策略ID @level 日志级别 @message 日志内容 - ''' + """ self.api.hft_log_text(id, level, ph.auto_encode(message)) - def hft_sub_ticks(self, id:int, stdCode:str): - ''' + def hft_sub_ticks(self, id: int, stdCode: str): + """ 订阅实时行情数据 @id 策略ID @stdCode 品种代码 - ''' - self.api.hft_sub_ticks(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.hft_sub_ticks(id, bytes(stdCode, encoding="utf8")) - def hft_sub_order_queue(self, id:int, stdCode:str): - ''' + def hft_sub_order_queue(self, id: int, stdCode: str): + """ 订阅实时委托队列数据 @id 策略ID @stdCode 品种代码 - ''' - self.api.hft_sub_order_queue(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.hft_sub_order_queue(id, bytes(stdCode, encoding="utf8")) - def hft_sub_order_detail(self, id:int, stdCode:str): - ''' + def hft_sub_order_detail(self, id: int, stdCode: str): + """ 订阅逐笔委托数据 @id 策略ID @stdCode 品种代码 - ''' - self.api.hft_sub_order_detail(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.hft_sub_order_detail(id, bytes(stdCode, encoding="utf8")) - def hft_sub_transaction(self, id:int, stdCode:str): - ''' + def hft_sub_transaction(self, id: int, stdCode: str): + """ 订阅逐笔成交数据 @id 策略ID @stdCode 品种代码 - ''' - self.api.hft_sub_transaction(id, bytes(stdCode, encoding = "utf8")) + """ + self.api.hft_sub_transaction(id, bytes(stdCode, encoding="utf8")) - def hft_cancel(self, id:int, localid:int): - ''' + def hft_cancel(self, id: int, localid: int): + """ 撤销指定订单 @id 策略ID @localid 下单时返回的本地订单号 - ''' + """ return self.api.hft_cancel(id, localid) - def hft_cancel_all(self, id:int, stdCode:str, isBuy:bool): - ''' + def hft_cancel_all(self, id: int, stdCode: str, isBuy: bool): + """ 撤销指定品种的全部买入订单or卖出订单 @id 策略ID @stdCode 品种代码 @isBuy 买入or卖出 - ''' - ret = self.api.hft_cancel_all(id, bytes(stdCode, encoding = "utf8"), isBuy) + """ + ret = self.api.hft_cancel_all(id, bytes(stdCode, encoding="utf8"), isBuy) return bytes.decode(ret) - def hft_buy(self, id:int, stdCode:str, price:float, qty:float, userTag:str, flag:int): - ''' + def hft_buy(self, id: int, stdCode: str, price: float, qty: float, userTag: str, flag: int): + """ 买入指令 @id 策略ID @stdCode 品种代码 @price 买入价格, 0为市价 @qty 买入数量 @flag 下单标志, 0-normal, 1-fak, 2-fok - ''' - ret = self.api.hft_buy(id, bytes(stdCode, encoding = "utf8"), price, qty, bytes(userTag, encoding = "utf8"), flag) + """ + ret = self.api.hft_buy(id, bytes(stdCode, encoding="utf8"), price, qty, bytes(userTag, encoding="utf8"), flag) return bytes.decode(ret) - def hft_sell(self, id:int, stdCode:str, price:float, qty:float, userTag:str, flag:int): - ''' + def hft_sell(self, id: int, stdCode: str, price: float, qty: float, userTag: str, flag: int): + """ 卖出指令 @id 策略ID @stdCode 品种代码 @price 卖出价格, 0为市价 @qty 卖出数量 @flag 下单标志, 0-normal, 1-fak, 2-fok - ''' - ret = self.api.hft_sell(id, bytes(stdCode, encoding = "utf8"), price, qty, bytes(userTag, encoding = "utf8"), flag) + """ + ret = self.api.hft_sell(id, bytes(stdCode, encoding="utf8"), price, qty, bytes(userTag, encoding="utf8"), flag) return bytes.decode(ret) - '''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - '''CTA接口''' - def create_cta_context(self, name:str, slippage:int = 0) -> int: - ''' + """""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""" + """CTA接口""" + + def create_cta_context(self, name: str, slippage: int = 0) -> int: + """ 创建策略环境 @name 策略名称 @return 系统内策略ID - ''' - return self.api.create_cta_context(bytes(name, encoding = "utf8"), slippage) + """ + return self.api.create_cta_context(bytes(name, encoding="utf8"), slippage) - def create_hft_context(self, name:str, trader:str, agent:bool, slippage:int = 0) -> int: - ''' + def create_hft_context(self, name: str, trader: str, agent: bool, slippage: int = 0) -> int: + """ 创建策略环境 @name 策略名称 @trader 交易通道ID @agent 数据是否托管 @return 系统内策略ID - ''' - return self.api.create_hft_context(bytes(name, encoding = "utf8"), bytes(trader, encoding = "utf8"), agent, slippage) + """ + return self.api.create_hft_context(bytes(name, encoding="utf8"), bytes(trader, encoding="utf8"), agent, + slippage) - def create_sel_context(self, name:str, date:int, time:int, period:str, trdtpl:str = 'CHINA', session:str = "TRADING", slippage:int = 0) -> int: - ''' + def create_sel_context(self, name: str, date: int, time: int, period: str, trdtpl: str = 'CHINA', + session: str = "TRADING", slippage: int = 0) -> int: + """ 创建策略环境 @name 策略名称 @date 日期,根据周期变化,每日为0,每周为0~6,对应周日到周六,每月为1~31,每年为0101~1231 - @time 时间,精确到分钟 - @period 时间周期,可以是分钟min、天d、周w、月m、年y - @return 系统内策略ID - ''' - return self.api.create_sel_context(bytes(name, encoding = "utf8"), date, time, - bytes(period, encoding = "utf8"), bytes(trdtpl, encoding = "utf8"), bytes(session, encoding = "utf8"), slippage) - - def reg_cta_factories(self, factFolder:str): - return self.api.reg_cta_factories(bytes(factFolder, encoding = "utf8") ) + @time 时间,精确到分钟 + @period 时间周期,可以是分钟min、天d、周w、月m、年y + @return 系统内策略ID + """ + return self.api.create_sel_context(bytes(name, encoding="utf8"), date, time, + bytes(period, encoding="utf8"), bytes(trdtpl, encoding="utf8"), + bytes(session, encoding="utf8"), slippage) - def reg_hft_factories(self, factFolder:str): - return self.api.reg_hft_factories(bytes(factFolder, encoding = "utf8") ) + def reg_cta_factories(self, factFolder: str): + return self.api.reg_cta_factories(bytes(factFolder, encoding="utf8")) - def reg_sel_factories(self, factFolder:str): - return self.api.reg_sel_factories(bytes(factFolder, encoding = "utf8") ) + def reg_hft_factories(self, factFolder: str): + return self.api.reg_hft_factories(bytes(factFolder, encoding="utf8")) - def reg_exe_factories(self, factFolder:str): - return self.api.reg_exe_factories(bytes(factFolder, encoding = "utf8") ) + def reg_sel_factories(self, factFolder: str): + return self.api.reg_sel_factories(bytes(factFolder, encoding="utf8")) - \ No newline at end of file + def reg_exe_factories(self, factFolder: str): + return self.api.reg_exe_factories(bytes(factFolder, encoding="utf8")) diff --git a/wtpy/wrapper/__init__.py b/wtpy/wrapper/__init__.py index 9b8db913..83d482df 100644 --- a/wtpy/wrapper/__init__.py +++ b/wtpy/wrapper/__init__.py @@ -2,9 +2,10 @@ from .WtExecApi import WtExecApi from .WtBtWrapper import WtBtWrapper from .WtDtWrapper import WtDtWrapper -from .ContractLoader import ContractLoader,LoaderType +from .ContractLoader import ContractLoader, LoaderType from .WtDtHelper import WtDataHelper from .WtDtServoApi import WtDtServoApi from .TraderDumper import TraderDumper -__all__ = ["WtWrapper", "WtExecApi", "WtDtWrapper", "WtBtWrapper", "ContractLoader","LoaderType","WtDataHelper","WtDtServoApi","TraderDumper"] \ No newline at end of file +__all__ = ["WtWrapper", "WtExecApi", "WtDtWrapper", "WtBtWrapper", "ContractLoader", "LoaderType", "WtDataHelper", + "WtDtServoApi", "TraderDumper"]