diff --git a/deepmd/main.py b/deepmd/main.py index d829f11ba2..ed20bb54ae 100644 --- a/deepmd/main.py +++ b/deepmd/main.py @@ -860,6 +860,13 @@ def main_parser() -> argparse.ArgumentParser: default=None, help="Restart the training from the provided prefix of checkpoint files.", ) + parser_train_nvnmd.add_argument( + "-f", + "--init-frz-model", + type=str, + default=None, + help="Initialize the training from the frozen model.", + ) parser_train_nvnmd.add_argument( "-s", "--step", diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 002e7bd3d3..2da4f91588 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -38,8 +38,8 @@ from deepmd.tf.nvnmd.descriptor.se_atten import ( build_davg_dstd, build_op_descriptor, + build_recovered, check_switch_range, - descrpt2r4, filter_GR2D, filter_lower_R42GR, ) @@ -715,6 +715,20 @@ def _pass_filter( inputs_i = inputs inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) type_i = -1 + + # descrpt and recovered_switch for nvnmd + if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: + inputs_i, self.recovered_switch = build_recovered( + inputs_i, + self.t_avg, + self.t_std, + self.atype_nloc, + natoms[0], + self.ntypes, + self.rcut_r_smth, + self.filter_precision, + ) + if len(self.exclude_types): mask = self.build_type_exclude_mask_mixed( self.exclude_types, @@ -753,8 +767,7 @@ def _pass_filter( ) self.negative_mask = -(2 << 32) * (1.0 - self.nmask) inputs_i *= mask - if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: - inputs_i = descrpt2r4(inputs_i, atype) + layer, qmat = self._filter( inputs_i, type_i, @@ -1162,6 +1175,7 @@ def _filter_lower( inputs_i, atype, self.nei_type_vec, + self.recovered_switch, ) elif nvnmd_cfg.restore_descriptor: self.embedding_net_variables = ( diff --git a/deepmd/tf/nvnmd/data/data.py b/deepmd/tf/nvnmd/data/data.py index e1fcaac9f2..a950c89b6e 100644 --- a/deepmd/tf/nvnmd/data/data.py +++ b/deepmd/tf/nvnmd/data/data.py @@ -1,3 +1,5 @@ +import copy + # SPDX-License-Identifier: LGPL-3.0-or-later jdata_sys = {"debug": False} @@ -10,6 +12,7 @@ "neuron": [8, 16, 32], "resnet_dt": False, "axis_neuron": 4, + "seed": 1, "type_one_side": True, # rcut range "rc_lim": 0.5, @@ -39,6 +42,7 @@ # basic config from deepmd model "neuron": [128, 128, 128], "resnet_dt": False, + "seed": 1, "NNODE_FITS": "(M1*M2, neuron, 1)", "nlayer_fit": "len(neuron)+1", "NLAYER": "nlayer_fit", @@ -119,8 +123,8 @@ } # change the configuration according to the max_nnei -jdata_config_v0_ni128 = jdata_config_v0.copy() -jdata_config_v0_ni256 = jdata_config_v0.copy() +jdata_config_v0_ni128 = copy.deepcopy(jdata_config_v0) +jdata_config_v0_ni256 = copy.deepcopy(jdata_config_v0) jdata_config_v0_ni256["ctrl"] = { "MAX_NNEI": 256, "NSTDM": 128, @@ -142,6 +146,7 @@ "rcut_smth": 0.5, "neuron": [8, 16, 32], "resnet_dt": False, + "seed": 1, "axis_neuron": 4, "type_one_side": True, # rcut range @@ -150,6 +155,7 @@ # embedding net size "M1": "neuron[-1]", "M2": "axis_neuron", + "M3": 2, "SEL": 128, "NNODE_FEAS": "(1, neuron)", "nlayer_fea": "len(neuron)", @@ -170,6 +176,7 @@ # basic config from deepmd model "neuron": [128, 128, 128], "resnet_dt": False, + "seed": 1, "NNODE_FITS": "(M1*M2, neuron, 1)", "nlayer_fit": "len(neuron)+1", "NLAYER": "nlayer_fit", @@ -200,7 +207,7 @@ "NSEL": "NSTDM", "NSADV": "NSTDM+1", "VERSION": 1, - "SUB_VERSION": 1, + "SUB_VERSION": 2, }, "nbit": { # general @@ -243,6 +250,7 @@ "NBIT_CFG": 64, "NBIT_NET": 72, "NBIT_MODEL_HEAD": 32, + "NBIT_NSTEP": 8, # nbit for mapt-version 
"NBIT_IDX_S2G": 9, "NBIT_NEIB": 8, @@ -251,18 +259,18 @@ } # change the configuration according to the max_nnei -jdata_config_v1_ni128 = jdata_config_v1.copy() -jdata_config_v1_ni256 = jdata_config_v1.copy() +jdata_config_v1_ni128 = copy.deepcopy(jdata_config_v1) +jdata_config_v1_ni256 = copy.deepcopy(jdata_config_v1) jdata_config_v1_ni256["ctrl"] = { "MAX_NNEI": 256, - "NSTDM": 128, + "NSTDM": 64, "NSTDM_M1": 32, - "NSTDM_M2": 4, + "NSTDM_M2": 2, "NSTDM_M1X": 8, "NSEL": "NSTDM", "NSADV": "NSTDM+1", "VERSION": 1, - "SUB_VERSION": 1, + "SUB_VERSION": 2, } jdata_config_v1_ni256["nbit"]["NBIT_NEIB"] = 9 @@ -283,6 +291,7 @@ }, "nvnmd": { "version": 0, + "device": "vu9p", "max_nnei": 128, # 128 or 256 "net_size": 128, "config_file": "none", @@ -319,12 +328,12 @@ "disp_training": True, "time_training": True, "profiling": False, - "training_data": {"systems": "dataset", "batch_size": 1}, + "training_data": {"systems": "dataset", "set_prefix": "set", "batch_size": 1}, }, } -jdata_deepmd_input_v0_ni128 = jdata_deepmd_input_v0.copy() -jdata_deepmd_input_v0_ni256 = jdata_deepmd_input_v0.copy() +jdata_deepmd_input_v0_ni128 = copy.deepcopy(jdata_deepmd_input_v0) +jdata_deepmd_input_v0_ni256 = copy.deepcopy(jdata_deepmd_input_v0) jdata_deepmd_input_v0_ni256["nvnmd"]["max_nnei"] = 256 jdata_deepmd_input_v1 = { @@ -332,7 +341,9 @@ "descriptor": { "seed": 1, "type": "se_atten", - "tebd_input_mode": "strip", + "stripped_type_embedding": True, + "smooth_type_embdding": True, + "set_davg_zero": False, "sel": 128, "rcut": 7.0, "rcut_smth": 0.5, @@ -349,6 +360,7 @@ }, "nvnmd": { "version": 1, + "device": "vu9p", "max_nnei": 128, # 128 or 256 "net_size": 128, "config_file": "none", @@ -385,14 +397,15 @@ "disp_training": True, "time_training": True, "profiling": False, - "training_data": {"systems": "dataset", "batch_size": 1}, + "training_data": {"systems": "dataset", "set_prefix": "set", "batch_size": 1}, }, } -jdata_deepmd_input_v1_ni128 = jdata_deepmd_input_v1.copy() -jdata_deepmd_input_v1_ni256 = jdata_deepmd_input_v1.copy() +jdata_deepmd_input_v1_ni128 = copy.deepcopy(jdata_deepmd_input_v1) +jdata_deepmd_input_v1_ni256 = copy.deepcopy(jdata_deepmd_input_v1) jdata_deepmd_input_v1_ni256["nvnmd"]["max_nnei"] = 256 + NVNMD_WELCOME = ( r" _ _ __ __ _ _ __ __ ____ ", r"| \ | | \ \ / / | \ | | | \/ | | _ \ ", diff --git a/deepmd/tf/nvnmd/descriptor/se_atten.py b/deepmd/tf/nvnmd/descriptor/se_atten.py index 1ab2369148..b8b56519f5 100644 --- a/deepmd/tf/nvnmd/descriptor/se_atten.py +++ b/deepmd/tf/nvnmd/descriptor/se_atten.py @@ -61,24 +61,72 @@ def check_switch_range(davg, dstd) -> None: # 'init_from_model', 'restart', 'init_from_frz_model', 'finetune' if (davg is not None) or (dstd is not None): if davg is None: - davg = np.zeros([ntype, ndescrpt]) # pylint: disable=no-explicit-dtype + davg = np.zeros([ntype, ndescrpt], dtype=GLOBAL_NP_FLOAT_PRECISION) if dstd is None: - dstd = np.ones([ntype, ndescrpt]) # pylint: disable=no-explicit-dtype + dstd = np.ones([ntype, ndescrpt], dtype=GLOBAL_NP_FLOAT_PRECISION) nvnmd_cfg.get_s_range(davg, dstd) def build_op_descriptor(): r"""Replace se_a.py/DescrptSeA/build.""" if nvnmd_cfg.quantize_descriptor: + # [rij^2, xij, yij, zij] return op_module.prod_env_mat_a_mix_nvnmd_quantize else: return op_module.prod_env_mat_a_mix -def descrpt2r4(inputs, atype): - r"""Replace :math:`r_{ji} \rightarrow r'_{ji}` +def build_recovered( + descrpt, t_avg, t_std, atype, Na, ntypes, rcut_r_smth, filter_precision +): + NIDP = nvnmd_cfg.dscp["NIDP"] + # look up for avg and std + t_avg = 
tf.reshape(t_avg, [ntypes, -1, 4]) + t_std = tf.reshape(t_std, [ntypes, -1, 4]) + avg = tf.reshape(tf.slice(t_avg, [0, 0, 0], [-1, 1, 2]), [-1, 2]) + std = tf.reshape(tf.slice(t_std, [0, 0, 0], [-1, 1, 2]), [-1, 2]) + # look up + avg_lookup = tf.reshape(tf.nn.embedding_lookup(avg, atype), [-1, 1, 2]) + std_lookup = tf.reshape(tf.nn.embedding_lookup(std, atype), [-1, 1, 2]) + avg_s = tf.slice(avg_lookup, [0, 0, 0], [-1, -1, 1]) + std_s = tf.slice(std_lookup, [0, 0, 0], [-1, -1, 1]) + std_h = tf.slice(std_lookup, [0, 0, 1], [-1, -1, 1]) + # [rij^2, xij, yij, zij] -> [sij, hij] + s, h, k, r = descrpt2shkr(descrpt) + s = tf.reshape(s, [-1, NIDP, 1]) + h = tf.reshape(h, [-1, NIDP, 1]) + s_norm = (s - avg_s) / std_s + h_norm = (h - 0) / std_h + s_norm = tf.reshape(s_norm, [-1, 1]) + h_norm = tf.reshape(h_norm, [-1, 1]) + with tf.variable_scope("s", reuse=True): + s_norm = op_module.flt_nvnmd(s_norm) + log.debug("#s: %s", s_norm) + s_norm = tf.ensure_shape(s_norm, [None, 1]) + with tf.variable_scope("h", reuse=True): + h_norm = op_module.flt_nvnmd(h_norm) + log.debug("#h: %s", h_norm) + h_norm = tf.ensure_shape(h_norm, [None, 1]) + # merge into [sji, hji*xji, hji*yji, hji*zji] + Rs = s_norm + Rxyz = op_module.mul_flt_nvnmd(h_norm, r) + Rxyz = tf.ensure_shape(Rxyz, [None, 3]) + with tf.variable_scope("Rxyz", reuse=True): + Rxyz = op_module.flt_nvnmd(Rxyz) + log.debug("#Rxyz: %s", Rxyz) + Rxyz = tf.ensure_shape(Rxyz, [None, 3]) + R4 = tf.concat([Rs, Rxyz], axis=1) + descrpt_norm = tf.reshape(R4, [-1, NIDP * 4]) + # smooth + recovered_switch = k + + return descrpt_norm, recovered_switch + + +def descrpt2shkr(inputs): + r"""Replace :math:`r_{ji} \rightarrow s_{ji} and h_{ji}` where :math:`r_{ji} = (x_{ji}, y_{ji}, z_{ji})` and - :math:`r'_{ji} = (s_{ji}, \frac{s_{ji} x_{ji}}{r_{ji}}, \frac{s_{ji} y_{ji}}{r_{ji}}, \frac{s_{ji} z_{ji}}{r_{ji}})`. + :math:`h_{ji} = \frac{s_{ji} r_{ji}}`. 
""" NIDP = nvnmd_cfg.dscp["NIDP"] ndescrpt = NIDP * 4 @@ -96,17 +144,23 @@ def descrpt2r4(inputs, atype): rji = tf.reshape(tf.slice(inputs_reshape, [0, 1], [-1, 3]), [-1, 3]) with tf.variable_scope("rji", reuse=True): rji = op_module.flt_nvnmd(rji) - rji = tf.ensure_shape(rji, [None, 3]) log.debug("#rji: %s", rji) - - # s & h + rji = tf.ensure_shape(rji, [None, 3]) + # s & h & k u = tf.reshape(u, [-1, 1]) table = GLOBAL_NP_FLOAT_PRECISION( - np.concatenate([nvnmd_cfg.map["s"][0], nvnmd_cfg.map["h"][0]], axis=1) + np.concatenate( + [nvnmd_cfg.map["s"][0], nvnmd_cfg.map["h"][0], nvnmd_cfg.map["k"][0]], + axis=1, + ) ) table_grad = GLOBAL_NP_FLOAT_PRECISION( np.concatenate( - [nvnmd_cfg.map["s_grad"][0], nvnmd_cfg.map["h_grad"][0]], + [ + nvnmd_cfg.map["s_grad"][0], + nvnmd_cfg.map["h_grad"][0], + nvnmd_cfg.map["k_grad"][0], + ], axis=1, ) ) @@ -114,12 +168,14 @@ def descrpt2r4(inputs, atype): table_info = np.array([np.float64(v) for vs in table_info for v in vs]) table_info = GLOBAL_NP_FLOAT_PRECISION(table_info) - s_h = op_module.map_flt_nvnmd(u, table, table_grad, table_info) - s_h = tf.ensure_shape(s_h, [None, 1, 2]) - s = tf.slice(s_h, [0, 0, 0], [-1, -1, 1]) - h = tf.slice(s_h, [0, 0, 1], [-1, -1, 1]) + s_h_k = op_module.map_flt_nvnmd(u, table, table_grad, table_info) + s_h_k = tf.ensure_shape(s_h_k, [None, 1, 3]) + s = tf.slice(s_h_k, [0, 0, 0], [-1, -1, 1]) + h = tf.slice(s_h_k, [0, 0, 1], [-1, -1, 1]) + k = tf.slice(s_h_k, [0, 0, 2], [-1, -1, 1]) s = tf.reshape(s, [-1, 1]) h = tf.reshape(h, [-1, 1]) + k = tf.reshape(k, [-1, 1]) with tf.variable_scope("s_s", reuse=True): s = op_module.flt_nvnmd(s) @@ -130,46 +186,15 @@ def descrpt2r4(inputs, atype): h = op_module.flt_nvnmd(h) log.debug("#h_s: %s", h) h = tf.ensure_shape(h, [None, 1]) - # davg and dstd - # davg = nvnmd_cfg.map["davg"] # is_zero - dstd_inv = nvnmd_cfg.map["dstd_inv"] - atype_expand = tf.reshape(atype, [-1, 1]) - std_inv_sel = tf.nn.embedding_lookup(dstd_inv, atype_expand) - std_inv_sel = tf.reshape(std_inv_sel, [-1, 4]) - std_inv_s = tf.slice(std_inv_sel, [0, 0], [-1, 1]) - std_inv_h = tf.slice(std_inv_sel, [0, 1], [-1, 1]) - s = op_module.mul_flt_nvnmd(std_inv_s, tf.reshape(s, [-1, NIDP])) - h = op_module.mul_flt_nvnmd(std_inv_h, tf.reshape(h, [-1, NIDP])) - s = tf.ensure_shape(s, [None, NIDP]) - h = tf.ensure_shape(h, [None, NIDP]) - s = tf.reshape(s, [-1, 1]) - h = tf.reshape(h, [-1, 1]) - with tf.variable_scope("s", reuse=True): - s = op_module.flt_nvnmd(s) - log.debug("#s: %s", s) - s = tf.ensure_shape(s, [None, 1]) - - with tf.variable_scope("h", reuse=True): - h = op_module.flt_nvnmd(h) - log.debug("#h: %s", h) - h = tf.ensure_shape(h, [None, 1]) - # R2R4 - Rs = s - # Rxyz = h * rji - Rxyz = op_module.mul_flt_nvnmd(h, rji) - Rxyz = tf.ensure_shape(Rxyz, [None, 3]) - with tf.variable_scope("Rxyz", reuse=True): - Rxyz = op_module.flt_nvnmd(Rxyz) - log.debug("#Rxyz: %s", Rxyz) - Rxyz = tf.ensure_shape(Rxyz, [None, 3]) - R4 = tf.concat([Rs, Rxyz], axis=1) - inputs_reshape = R4 - inputs_reshape = tf.reshape(inputs_reshape, [-1, ndescrpt]) - return inputs_reshape + with tf.variable_scope("k", reuse=True): + k = op_module.flt_nvnmd(k) + log.debug("#k: %s", k) + k = tf.ensure_shape(k, [None, 1]) + return s, h, k, rji -def filter_lower_R42GR(inputs_i, atype, nei_type_vec): +def filter_lower_R42GR(inputs_i, atype, nei_type_vec, recovered_switch): r"""Replace se_a.py/DescrptSeA/_filter_lower.""" shape_i = inputs_i.get_shape().as_list() inputs_reshape = tf.reshape(inputs_i, [-1, 4]) @@ -177,8 +202,8 @@ def 
filter_lower_R42GR(inputs_i, atype, nei_type_vec): ntype = nvnmd_cfg.dscp["ntype"] NIDP = nvnmd_cfg.dscp["NIDP"] two_embd_value = nvnmd_cfg.map["gt"] - # print(two_embd_value) two_embd_value = GLOBAL_NP_FLOAT_PRECISION(two_embd_value) + # copy inputs_reshape = op_module.flt_nvnmd(inputs_reshape) inputs_reshape = tf.ensure_shape(inputs_reshape, [None, 4]) @@ -193,12 +218,12 @@ def filter_lower_R42GR(inputs_i, atype, nei_type_vec): table_info = nvnmd_cfg.map["cfg_s2g"] table_info = np.array([np.float64(v) for vs in table_info for v in vs]) table_info = GLOBAL_NP_FLOAT_PRECISION(table_info) - G = op_module.map_flt_nvnmd(s, table, table_grad, table_info) - G = tf.ensure_shape(G, [None, 1, M1]) + Gs = op_module.map_flt_nvnmd(s, table, table_grad, table_info) + Gs = tf.ensure_shape(Gs, [None, 1, M1]) with tf.variable_scope("g_s", reuse=True): - G = op_module.flt_nvnmd(G) - log.debug("#g_s: %s", G) - G = tf.ensure_shape(G, [None, 1, M1]) + Gs = op_module.flt_nvnmd(Gs) + log.debug("#g_s: %s", Gs) + Gs = tf.ensure_shape(Gs, [None, 1, M1]) # t2G atype_expand = tf.reshape(atype, [-1, 1]) idx_i = tf.tile(atype_expand * (ntype + 1), [1, NIDP]) @@ -206,15 +231,23 @@ def filter_lower_R42GR(inputs_i, atype, nei_type_vec): idx = idx_i + idx_j index_of_two_side = tf.reshape(idx, [-1]) two_embd = tf.nn.embedding_lookup(two_embd_value, index_of_two_side) - # two_embd = tf.reshape(two_embd, (-1, shape_i[1] // 4, M1)) two_embd = tf.reshape(two_embd, (-1, M1)) with tf.variable_scope("g_t", reuse=True): two_embd = op_module.flt_nvnmd(two_embd) log.debug("#g_t: %s", two_embd) two_embd = tf.ensure_shape(two_embd, [None, M1]) + # t2G * k(s) + two_embd = two_embd * tf.reshape(recovered_switch, [-1, 1]) + with tf.variable_scope("g_tk", reuse=True): + two_embd = op_module.flt_nvnmd(two_embd) + log.debug("#g_tk: %s", two_embd) + two_embd = tf.ensure_shape(two_embd, [None, M1]) # G_s, G_t -> G - G = tf.reshape(G, [-1, M1]) - G = op_module.mul_flt_nvnmd(G, two_embd) + # G = Gs * Gt + Gs + Gs = tf.reshape(Gs, [-1, M1]) + G2 = op_module.mul_flt_nvnmd(Gs, two_embd) + G2 = tf.ensure_shape(G2, [None, M1]) + G = op_module.add_flt_nvnmd(Gs, G2) G = tf.ensure_shape(G, [None, M1]) with tf.variable_scope("g", reuse=True): G = op_module.flt_nvnmd(G) diff --git a/deepmd/tf/nvnmd/entrypoints/mapt.py b/deepmd/tf/nvnmd/entrypoints/mapt.py index 2e6e56bf51..366b6859f8 100644 --- a/deepmd/tf/nvnmd/entrypoints/mapt.py +++ b/deepmd/tf/nvnmd/entrypoints/mapt.py @@ -7,6 +7,7 @@ import numpy as np from deepmd.tf.env import ( + GLOBAL_NP_FLOAT_PRECISION, op_module, tf, ) @@ -77,7 +78,7 @@ class MapTable: DOI: 10.1038/s41524-022-00773-z """ - def __init__(self, config_file: str, weight_file: str, map_file: str) -> None: + def __init__(self, config_file: str, weight_file: str, map_file: str): self.config_file = config_file self.weight_file = weight_file self.map_file = map_file @@ -91,17 +92,16 @@ def __init__(self, config_file: str, weight_file: str, map_file: str) -> None: # Gs + 1, Gt + 0 # 1 : xyz_scatter = xyz_scatter * two_embd + two_embd ; # Gs + 0, Gt + 1 - self.Gs_Gt_mode = 1 + # 2 : xyz_scatter = xyz_scatter * two_embd * recovered_switch + xyz_scatter; + # Gs + 0, Gt + 0 + self.Gs_Gt_mode = 2 nvnmd_cfg.init_from_jdata(jdata) def build_map(self): - if self.Gs_Gt_mode == 0: - self.shift_Gs = 1 - self.shift_Gt = 0 - if self.Gs_Gt_mode == 1: + if self.Gs_Gt_mode == 2: self.shift_Gs = 0 - self.shift_Gt = 1 + self.shift_Gt = 0 # M = nvnmd_cfg.dscp["M1"] if nvnmd_cfg.version == 0: @@ -137,12 +137,21 @@ def build_map(self): ndim, 1, ) + 
dic_map1["k"], dic_map1["k_grad"] = self.build_map_coef( + cfg_u2s, + u, + dic_u2s["k"], + dic_u2s["k_grad"], + dic_u2s["k_grad_grad"], + ndim, + 1, + ) ## s2g dic_map2 = {} s = np.reshape(dic_s2g["s"], [-1]) cfg_s2g = [ [s[0], s[256], s[1] - s[0], 0, 256], - [s[0], s[4096], s[16] - s[0], 256, 512], + [s[0], s[8192], s[32] - s[0], 256, 512], ] dic_map2["g"], dic_map2["g_grad"] = self.build_map_coef( cfg_s2g, @@ -194,7 +203,7 @@ def mapping(self, x, dic_map, cfgs): val_i = val[ii] nr = np.shape(val_i)[0] nc = np.shape(val_i)[1] // 4 - dat_i = np.zeros([n, nc]) # pylint: disable=no-explicit-dtype + dat_i = np.zeros([n, nc], dtype=GLOBAL_NP_FLOAT_PRECISION) for kk in range(n): xk = x[kk] for cfg in cfgs: @@ -250,7 +259,7 @@ def mapping2(self, x, dic_map, cfgs): dic_val[key] = dats return dic_val - def plot_lines(self, x, dic1, dic2=None) -> None: + def plot_lines(self, x, dic1, dic2=None): r"""Plot lines to see accuracy.""" pass @@ -391,14 +400,19 @@ def build_u2s(self, r2): h = h / std[tt, 1] sl.append(s) hl.append(h) - return sl, hl + return sl, hl, sl if nvnmd_cfg.version == 1: s = vv / r__ h = s / r__ + kk = 1 - rmin * s + k = -kk * kk * kk + 1 + k = tf.clip_by_value(k, 0.0, 1.0) + s = tf.reshape(s, [-1, 1]) h = tf.reshape(h, [-1, 1]) - return [s], [h] + k = tf.reshape(k, [-1, 1]) + return [s], [h], [k] def build_u2s_grad(self): r"""Build gradient of s with respect to u (r^2).""" @@ -409,13 +423,16 @@ def build_u2s_grad(self): # dic_ph = {} dic_ph["u"] = tf.placeholder(tf.float64, [None, 1], "t_u") - dic_ph["s"], dic_ph["h"] = self.build_u2s(dic_ph["u"]) + dic_ph["s"], dic_ph["h"], dic_ph["k"] = self.build_u2s(dic_ph["u"]) dic_ph["s_grad"], dic_ph["s_grad_grad"] = self.build_grad( dic_ph["u"], dic_ph["s"], ndim, 1 ) dic_ph["h_grad"], dic_ph["h_grad_grad"] = self.build_grad( dic_ph["u"], dic_ph["h"], ndim, 1 ) + dic_ph["k_grad"], dic_ph["k_grad_grad"] = self.build_grad( + dic_ph["u"], dic_ph["k"], ndim, 1 + ) return dic_ph def run_u2s(self): @@ -436,17 +453,21 @@ def run_u2s(self): # N = NUM_MAPT N = 512 N2 = int(rc_max**2) - # N+1 ranther than N for calculating difference + # N+1 ranther than N for calculating defference keys = list(dic_ph.keys()) vals = list(dic_ph.values()) - u = N2 * np.reshape(np.arange(0, N + 1) / N, [-1, 1]) # pylint: disable=no-explicit-dtype + u = N2 * np.reshape( + np.arange(0, N + 1, dtype=GLOBAL_NP_FLOAT_PRECISION) / N, [-1, 1] + ) res_lst = run_sess(sess, vals, feed_dict={dic_ph["u"]: u}) res_dic = dict(zip(keys, res_lst)) - u2 = N2 * np.reshape(np.arange(0, N * 16 + 1) / (N * 16), [-1, 1]) # pylint: disable=no-explicit-dtype + u2 = N2 * np.reshape( + np.arange(0, N * 16 + 1, dtype=np.float64) / (N * 16), [-1, 1] + ) res_lst2 = run_sess(sess, vals, feed_dict={dic_ph["u"]: u2}) - res_dic2 = dict(zip(keys, res_lst2)) # reference for compare + res_dic2 = dict(zip(keys, res_lst2)) # reference for commpare # change value for tt in range(ndim): @@ -456,6 +477,9 @@ def run_u2s(self): res_dic["h"][tt][0] = 0 res_dic["h_grad"][tt][0] = 0 res_dic["h_grad_grad"][tt][0] = 0 + res_dic["k"][tt][0] = 0 + res_dic["k_grad"][tt][0] = 0 + res_dic["k_grad_grad"][tt][0] = 0 # res_dic2["s"][tt][0] = -avg[tt, 0] / std[tt, 0] res_dic2["s_grad"][tt][0] = 0 @@ -463,6 +487,13 @@ def run_u2s(self): res_dic2["h"][tt][0] = 0 res_dic2["h_grad"][tt][0] = 0 res_dic2["h_grad_grad"][tt][0] = 0 + res_dic2["k"][tt][0] = 0 + res_dic2["k_grad"][tt][0] = 0 + res_dic2["k_grad_grad"][tt][0] = 0 + # + if nvnmd_cfg.version == 1: + res_dic["s"][tt][0] = 0 + res_dic2["s"][tt][0] = 0 sess.close() return 
res_dic, res_dic2 @@ -521,27 +552,34 @@ def run_s2g(self): dic_ph = self.build_s2g_grad() sess = get_sess() - N = 4096 - N2 = 16 + N = 8192 + N2 = 32 log.info(f"the range of s is [{smin}, {smax}]") # check - if (smax - smin) > 16.0: - log.warning("the range of s is over the limit (smax - smin) > 16.0") + if (smax - smin) > 32.0: + log.warning("the range of s is over the limit (smax - smin) > 32.0") prec = N / N2 # the lower limit of switch function - if nvnmd_cfg.version == 0: - smin_ = np.floor(smin * prec - 1) / prec - if nvnmd_cfg.version == 1: - smin_ = 0 + smin_ = np.floor(smin * prec - 1) / prec # keys = list(dic_ph.keys()) vals = list(dic_ph.values()) - s = N2 * np.reshape(np.arange(0, N + 1) / N, [-1, 1]) + smin_ # pylint: disable=no-explicit-dtype + s = ( + N2 + * np.reshape( + np.arange(0, N + 1, dtype=GLOBAL_NP_FLOAT_PRECISION) / N, [-1, 1] + ) + + smin_ + ) res_lst = run_sess(sess, vals, feed_dict={dic_ph["s"]: s}) res_dic = dict(zip(keys, res_lst)) - s2 = N2 * np.reshape(np.arange(0, N * 16 + 1) / (N * 16), [-1, 1]) + smin_ # pylint: disable=no-explicit-dtype + s2 = ( + N2 + * np.reshape(np.arange(0, N * 16 + 1, dtype=np.float64) / (N * 16), [-1, 1]) + + smin_ + ) res_lst2 = run_sess(sess, vals, feed_dict={dic_ph["s"]: s2}) res_dic2 = dict(zip(keys, res_lst2)) @@ -566,7 +604,9 @@ def build_t2g(self): dic_ph["t_one_hot"] = ebd_type wbs = [get_type_embedding_weight(nvnmd_cfg.weight, ll) for ll in range(1, 5)] ebd_type = self.build_embedding_net(dic_ph["t_one_hot"], wbs, None) - last_type = tf.cast(tf.zeros([1, ebd_type.shape[1]]), filter_precision) # pylint: disable=no-explicit-dtype + last_type = tf.cast( + tf.zeros([1, ebd_type.shape[1]], dtype=filter_precision), filter_precision + ) ebd_type = tf.concat([ebd_type, last_type], 0) dic_ph["t_ebd"] = ebd_type # type_embedding of i, j atoms -> two_side_type_embedding diff --git a/deepmd/tf/nvnmd/entrypoints/train.py b/deepmd/tf/nvnmd/entrypoints/train.py index c690190c0d..c16cea9762 100644 --- a/deepmd/tf/nvnmd/entrypoints/train.py +++ b/deepmd/tf/nvnmd/entrypoints/train.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging import os from typing import ( @@ -65,12 +66,12 @@ def normalized_input(fn, PATH_CNN, CONFIG_CNN): # model jdata_model = { "descriptor": { - "seed": 1, + "seed": jdata_nvnmd_.get("seed", 1), "sel": jdata_nvnmd_["sel"], "rcut": jdata_nvnmd_["rcut"], "rcut_smth": jdata_nvnmd_["rcut_smth"], }, - "fitting_net": {"seed": 1}, + "fitting_net": {"seed": jdata_nvnmd_.get("seed", 1)}, "type_map": [], } jdata_model["type_map"] = f.get(jdata_nvnmd_, "type_map", []) @@ -120,6 +121,7 @@ def train_nvnmd( INPUT: str, init_model: Optional[str], restart: Optional[str], + init_frz_model: Optional[str], step: str, skip_neighbor_stat: bool = False, **kwargs, @@ -141,16 +143,17 @@ def train_nvnmd( FioDic().save(INPUT_CNN, jdata) nvnmd_cfg.save(CONFIG_CNN) # train cnn - jdata = jdata_cmd_train.copy() + jdata = copy.deepcopy(jdata_cmd_train) jdata["INPUT"] = INPUT_CNN jdata["log_path"] = LOG_CNN jdata["init_model"] = init_model + jdata["init_frz_model"] = init_frz_model jdata["restart"] = restart jdata["skip_neighbor_stat"] = skip_neighbor_stat train(**jdata) tf.reset_default_graph() # freeze - jdata = jdata_cmd_freeze.copy() + jdata = copy.deepcopy(jdata_cmd_freeze) jdata["checkpoint_folder"] = PATH_CNN jdata["output"] = FRZ_MODEL_CNN jdata["nvnmd_weight"] = WEIGHT_CNN @@ -180,14 +183,14 @@ def train_nvnmd( FioDic().save(INPUT_QNN, jdata) nvnmd_cfg.save(CONFIG_QNN) # train qnn - jdata = 
jdata_cmd_train.copy() + jdata = copy.deepcopy(jdata_cmd_train) jdata["INPUT"] = INPUT_QNN jdata["log_path"] = LOG_QNN jdata["skip_neighbor_stat"] = skip_neighbor_stat train(**jdata) tf.reset_default_graph() # freeze - jdata = jdata_cmd_freeze.copy() + jdata = copy.deepcopy(jdata_cmd_freeze) jdata["checkpoint_folder"] = PATH_QNN jdata["output"] = FRZ_MODEL_QNN jdata["nvnmd_weight"] = WEIGHT_QNN diff --git a/deepmd/tf/nvnmd/entrypoints/wrap.py b/deepmd/tf/nvnmd/entrypoints/wrap.py index ced97bdbf1..d6a0025ae5 100755 --- a/deepmd/tf/nvnmd/entrypoints/wrap.py +++ b/deepmd/tf/nvnmd/entrypoints/wrap.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging from typing import ( Optional, @@ -7,11 +8,13 @@ import numpy as np from deepmd.tf.env import ( + GLOBAL_NP_FLOAT_PRECISION, op_module, tf, ) from deepmd.tf.nvnmd.data.data import ( jdata_deepmd_input_v0, + jdata_deepmd_input_v1_ni256, jdata_sys, ) from deepmd.tf.nvnmd.utils.config import ( @@ -76,8 +79,14 @@ def __init__( self.weight_file = weight_file self.map_file = map_file self.model_file = model_file - - jdata = jdata_deepmd_input_v0["nvnmd"] + # init according to local file + loc_config = np.load(config_file, allow_pickle=True) + loc_version = loc_config[0]["ctrl"]["VERSION"] + jdata = ( + jdata_deepmd_input_v1_ni256["nvnmd"] + if loc_version == 1 + else jdata_deepmd_input_v0["nvnmd"] + ) jdata["config_file"] = config_file jdata["weight_file"] = weight_file jdata["map_file"] = map_file @@ -238,14 +247,16 @@ def wrap_head(self, nhs, nws): def wrap_dscp(self): r"""Wrap the configuration of descriptor. - version 0: - [NBIT_IDX_S2G-1:0] SHIFT_IDX_S2G - [NBIT_NEIB*NTYPE-1:0] SELs - [NBIT_FIXD*M1*NTYPE*NTYPE-1:0] GSs - [NBIT_FLTE-1:0] NEXPO_DIV_NI + version 0: + [NBIT_IDX_S2G-1:0] SHIFT_IDX_S2G + [NBIT_NEIB*NTYPE-1:0] SELs + [NBIT_FIXD*M1*NTYPE*NTYPE-1:0] GSs + [NBIT_FLTE-1:0] NEXPO_DIV_NI - version 1: - [NBIT_FLTE-1:0] NEXPO_DIV_NI + version 1: + [NBIT_IDX_S2G-1:0] SHIFT_IDX_S2G + [NBIT_FLTE-1:0] NEXPO_DIV_NI + [NBIT_NSTEP-1:0] NSTEP """ dscp = nvnmd_cfg.dscp nbit = nvnmd_cfg.nbit @@ -292,7 +303,7 @@ def wrap_dscp(self): cfgs = mapt["cfg_u2s"] cfgs = np.array([np.float64(v) for vs in cfgs for v in vs]) feed_dict = { - t_x: np.ones([1, 1]) * 0.0, # pylint: disable=no-explicit-dtype + t_x: np.ones([1, 1], dtype=GLOBAL_NP_FLOAT_PRECISION) * 0.0, t_table: mi, t_table_grad: mi * 0.0, t_table_info: cfgs, @@ -304,7 +315,7 @@ def wrap_dscp(self): cfgs = mapt["cfg_s2g"] cfgs = np.array([np.float64(v) for vs in cfgs for v in vs]) feed_dict = { - t_x: np.ones([1, 1]) * si, # pylint: disable=no-explicit-dtype + t_x: np.ones([1, 1], dtype=GLOBAL_NP_FLOAT_PRECISION) * si, t_table: mi, t_table_grad: mi * 0.0, t_table_info: cfgs, @@ -312,7 +323,7 @@ def wrap_dscp(self): gi = run_sess(sess, t_y, feed_dict=feed_dict) gsi = np.reshape(si, [-1]) * np.reshape(gi, [-1]) else: - gsi = np.zeros(M1) # pylint: disable=no-explicit-dtype + gsi = np.zeros(M1, dtype=GLOBAL_NP_FLOAT_PRECISION) for ii in range(M1): GSs.extend( e.dec2bin(e.qr(gsi[ii], NBIT_FIXD_FL), NBIT_FIXD, True) @@ -324,10 +335,20 @@ def wrap_dscp(self): ln2_NIX = -int(np.log2(NIX)) bs = e.dec2bin(ln2_NIX, NBIT_FLTE, signed=True)[0] + bs if nvnmd_cfg.version == 1: + NBIT_IDX_S2G = nbit["NBIT_IDX_S2G"] NBIT_FLTE = nbit["NBIT_FLTE"] + NBIT_NSTEP = nbit["NBIT_NSTEP"] NIX = dscp["NIX"] + # shift_idx_s2g + x_st, x_ed, x_dt, N0, N1 = mapt["cfg_s2g"][0] + shift_idx_s2g = int(np.round(-x_st / x_dt)) + bs = e.dec2bin(shift_idx_s2g, NBIT_IDX_S2G)[0] + bs + # NI ln2_NIX = 
-int(np.log2(NIX)) bs = e.dec2bin(ln2_NIX, NBIT_FLTE, signed=True)[0] + bs + # NSTEP + nstep = dscp["NSTEP"] + bs = e.dec2bin(nstep, NBIT_NSTEP)[0] + bs return bs def wrap_fitn(self): @@ -391,6 +412,7 @@ def wrap_fitn(self): bwc.append(bwct) # bfps, bbps = [], [] + numdata = 4 if NSTDM == 32 else 2 for ss in range(NSEL): tt = ss // NSTDM sc = ss % NSTDM @@ -416,7 +438,12 @@ def wrap_fitn(self): for rr in range(nrs) for cc in range(nc) ] - bbp += [bdc[ll][tt][sc * ncs * 0 + cc] for cc in range(ncs)] + bbp += [ + "".join( + [bdc[ll][tt][sc * ncs * 0 + cc] for cc in range(ncs)] + * numdata + ) + ] # fix bug-adjust to multi data bbp += [bb[ll][tt][sc * ncs * 0 + cc] for cc in range(ncs)] else: # fp @@ -453,10 +480,10 @@ def wrap_weight(self, weight, NBIT_DISP, NBIT_WEIGHT): NBIT_WEIGHT_FL = NBIT_WEIGHT - 2 sh = weight.shape nr, nc = sh[0], sh[1] - nrs = np.zeros(nr) # pylint: disable=no-explicit-dtype - ncs = np.zeros(nc) # pylint: disable=no-explicit-dtype - wrs = np.zeros([nr, nc]) # pylint: disable=no-explicit-dtype - wcs = np.zeros([nr, nc]) # pylint: disable=no-explicit-dtype + nrs = np.zeros(nr, dtype=GLOBAL_NP_FLOAT_PRECISION) + ncs = np.zeros(nc, dtype=GLOBAL_NP_FLOAT_PRECISION) + wrs = np.zeros([nr, nc], dtype=GLOBAL_NP_FLOAT_PRECISION) + wcs = np.zeros([nr, nc], dtype=GLOBAL_NP_FLOAT_PRECISION) e = Encode() # row for ii in range(nr): @@ -503,26 +530,37 @@ def wrap_map(self): dsws = [] feas = [] gras = [] + for tt in range(ntype_max): - if tt < ntype: - swt = np.concatenate([maps["s"][tt], maps["h"][tt]], axis=1) - dsw = np.concatenate([maps["s_grad"][tt], maps["h_grad"][tt]], axis=1) - fea = maps["g"][tt] - gra = maps["g_grad"][tt] + ttt = tt if tt < ntype else 0 + kkk = 1 if tt < ntype else 0 + if nvnmd_cfg.version == 0: + swt = np.concatenate([maps["s"][ttt], maps["h"][ttt]], axis=1) + dsw = np.concatenate([maps["s_grad"][ttt], maps["h_grad"][ttt]], axis=1) else: - swt = np.concatenate([maps["s"][0], maps["h"][0]], axis=1) - dsw = np.concatenate([maps["s_grad"][0], maps["h_grad"][0]], axis=1) - fea = maps["g"][0] - gra = maps["g_grad"][0] - swt *= 0 - dsw *= 0 - fea *= 0 - gra *= 0 - swts.append(swt.copy()) - dsws.append(dsw.copy()) - feas.append(fea.copy()) - gras.append(gra.copy()) + swt = np.concatenate( + [maps["s"][ttt], maps["h"][ttt], maps["k"][ttt]], axis=1 + ) + dsw = np.concatenate( + [maps["s_grad"][ttt], maps["h_grad"][ttt], maps["k_grad"][ttt]], + axis=1, + ) + + fea = maps["g"][ttt] + gra = maps["g_grad"][ttt] + + swt *= kkk + dsw *= kkk + fea *= kkk + gra *= kkk + + swts.append(copy.deepcopy(swt)) + dsws.append(copy.deepcopy(dsw)) + feas.append(copy.deepcopy(fea)) + gras.append(copy.deepcopy(gra)) mapts = [swts, dsws, feas, gras] + # k = 2**23 + # print(dsws[0][42] * k) # reshape if nvnmd_cfg.version == 0: nmerges = [2 * 2, 2 * 2, 4 * 2, 4 * 2] @@ -539,7 +577,7 @@ def wrap_map(self): bs = e.merge_bin(bs, nmerges[ii]) bss.append(bs) if nvnmd_cfg.version == 1: - ndim = [2, 2, M1, M1] + ndim = [3, 3, M1, M1] bss = [] for ii in range(len(mapts)): nd = ndim[ii] @@ -551,7 +589,12 @@ def wrap_map(self): # bs = e.flt2bin(d, NBIT_FLTE, NBIT_FLTF) bss.append(bs) - bswt, bdsw, bfea, bgra = bss + ( + bswt, + bdsw, + bfea, + bgra, + ) = bss return bswt, bdsw, bfea, bgra def wrap_lut(self): @@ -572,18 +615,19 @@ def wrap_lut(self): NBIT_DATA_FL = nvnmd_cfg.nbit["NBIT_FIT_DATA_FL"] e = Encode() - # std - d = maps["dstd_inv"] - d2 = np.zeros([ntype_max, 2]) # pylint: disable=no-explicit-dtype + # avg & std + d_avg = maps["davg_opp"] + d_std = maps["dstd_inv"] + # d2 = 
np.zeros([ntype_max, 3]) + d2 = np.zeros([ntype_max, 4], dtype=GLOBAL_NP_FLOAT_PRECISION) for ii in range(ntype): - _d = d[ii, :2] - _d = np.reshape(_d, [-1, 2]) - _d = np.concatenate([_d[:, 0], _d[:, 1]], axis=0) - d2[ii] = _d + d2[ii, 0] = d_avg[ii, 0] + d2[ii, 1] = d_std[ii, 0] + d2[ii, 2] = d_std[ii, 1] bstd = e.flt2bin(d2, NBIT_FLTE, NBIT_FLTF) # gtt d = maps["gt"] - d2 = np.zeros([ntype_max**2, M1]) # pylint: disable=no-explicit-dtype + d2 = np.zeros([ntype_max**2, M1], dtype=GLOBAL_NP_FLOAT_PRECISION) for ii in range(ntype): for jj in range(ntype): _d = d[ii * (ntype + 1) + jj] @@ -595,7 +639,7 @@ def wrap_lut(self): d = maps["t_ebd"] w = get_type_weight(weight, 0) nd = w.shape[1] - d2 = np.zeros([ntype_max, nd]) # pylint: disable=no-explicit-dtype + d2 = np.zeros([ntype_max, nd], dtype=GLOBAL_NP_FLOAT_PRECISION) for ii in range(ntype): _d = d[ii] _d = np.reshape(_d, [1, -1]) diff --git a/deepmd/tf/nvnmd/utils/config.py b/deepmd/tf/nvnmd/utils/config.py index 41bd650b06..9f90d7facf 100644 --- a/deepmd/tf/nvnmd/utils/config.py +++ b/deepmd/tf/nvnmd/utils/config.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import copy import logging import numpy as np @@ -41,21 +42,23 @@ class NvnmdConfig: DOI: 10.1038/s41524-022-00773-z """ - def __init__(self, jdata: dict) -> None: + def __init__(self, jdata: dict): self.version = 0 + self.device = "vu9p" self.enable = False self.map = {} - self.config = jdata_config_v0.copy() + self.config = copy.deepcopy(jdata_config_v0) self.save_path = "nvnmd/config.npy" self.weight = {} self.init_from_jdata(jdata) - def init_from_jdata(self, jdata: dict = {}) -> None: + def init_from_jdata(self, jdata: dict = {}): r"""Initialize this class with `jdata` loaded from input script.""" if jdata == {}: return None self.version = jdata["version"] + self.device = jdata["device"] self.max_nnei = jdata["max_nnei"] self.net_size = jdata["net_size"] self.map_file = jdata["map_file"] @@ -66,7 +69,6 @@ def init_from_jdata(self, jdata: dict = {}) -> None: self.restore_fitting_net = jdata["restore_fitting_net"] self.quantize_descriptor = jdata["quantize_descriptor"] self.quantize_fitting_net = jdata["quantize_fitting_net"] - # load data if self.enable: self.map = FioDic().load(self.map_file, {}) @@ -78,7 +80,7 @@ def init_from_jdata(self, jdata: dict = {}) -> None: # if load the file, set net_size self.init_net_size() - def init_value(self) -> None: + def init_value(self): r"""Initialize member with dict.""" self.dscp = self.config["dscp"] self.fitn = self.config["fitn"] @@ -87,7 +89,7 @@ def init_value(self) -> None: self.ctrl = self.config["ctrl"] self.nbit = self.config["nbit"] - def update_config(self) -> None: + def update_config(self): r"""Update config from dict.""" self.config["dscp"] = self.dscp self.config["fitn"] = self.fitn @@ -96,7 +98,7 @@ def update_config(self) -> None: self.config["ctrl"] = self.ctrl self.config["nbit"] = self.nbit - def init_train_mode(self, mod="cnn") -> None: + def init_train_mode(self, mod="cnn"): r"""Configure for taining cnn or qnn.""" if mod == "cnn": self.restore_descriptor = False @@ -109,7 +111,7 @@ def init_train_mode(self, mod="cnn") -> None: self.quantize_descriptor = True self.quantize_fitting_net = True - def init_from_config(self, jdata) -> None: + def init_from_config(self, jdata): r"""Initialize member element one by one.""" if "ctrl" in jdata.keys(): if "VERSION" in jdata["ctrl"].keys(): @@ -128,48 +130,51 @@ def init_from_config(self, jdata) -> None: self.config["nbit"] = 
self.init_nbit(self.config["nbit"], self.config) self.init_value() - def init_config_by_version(self, version, max_nnei) -> None: + def init_config_by_version(self, version, max_nnei): r"""Initialize version-dependent parameters.""" self.version = version self.max_nnei = max_nnei log.debug(f"#Set nvnmd version as {self.version} ") if self.version == 0: if self.max_nnei == 128: - self.jdata_deepmd_input = jdata_deepmd_input_v0_ni128.copy() - self.config = jdata_config_v0_ni128.copy() + self.jdata_deepmd_input = copy.deepcopy(jdata_deepmd_input_v0_ni128) + self.config = copy.deepcopy(jdata_config_v0_ni128) elif self.max_nnei == 256: - self.jdata_deepmd_input = jdata_deepmd_input_v0_ni256.copy() - self.config = jdata_config_v0_ni256.copy() + self.jdata_deepmd_input = copy.deepcopy(jdata_deepmd_input_v0_ni256) + self.config = copy.deepcopy(jdata_config_v0_ni256) else: log.error("The max_nnei only can be set as 128|256 for version 0") if self.version == 1: if self.max_nnei == 128: - self.jdata_deepmd_input = jdata_deepmd_input_v1_ni128.copy() - self.config = jdata_config_v1_ni128.copy() + self.jdata_deepmd_input = copy.deepcopy(jdata_deepmd_input_v1_ni128) + self.config = copy.deepcopy(jdata_config_v1_ni128) elif self.max_nnei == 256: - self.jdata_deepmd_input = jdata_deepmd_input_v1_ni256.copy() - self.config = jdata_config_v1_ni256.copy() + self.jdata_deepmd_input = copy.deepcopy(jdata_deepmd_input_v1_ni256) + self.config = copy.deepcopy(jdata_config_v1_ni256) else: log.error("The max_nnei only can be set as 128|256 for version 1") - def init_net_size(self) -> None: + def init_net_size(self): r"""Initialize net_size.""" self.net_size = self.config["fitn"]["neuron"][0] if self.enable: self.config["fitn"]["neuron"] = [self.net_size] * 3 - def init_from_deepmd_input(self, jdata) -> None: + def init_from_deepmd_input(self, jdata): r"""Initialize members with input script of deepmd.""" fioObj = FioDic() self.config["dscp"] = fioObj.update(jdata["descriptor"], self.config["dscp"]) self.config["fitn"] = fioObj.update(jdata["fitting_net"], self.config["fitn"]) self.config["dscp"] = self.init_dscp(self.config["dscp"], self.config) self.config["fitn"] = self.init_fitn(self.config["fitn"], self.config) + log.info(self.config["dscp"]) dp_in = {"type_map": fioObj.get(jdata, "type_map", [])} self.config["dpin"] = fioObj.update(dp_in, self.config["dpin"]) # + log.info(self.config["dscp"]) self.init_net_size() self.init_value() + log.info(self.config["dscp"]) def init_dscp(self, jdata: dict, jdata_parent: dict = {}) -> dict: r"""Initialize members about descriptor.""" @@ -180,7 +185,9 @@ def init_dscp(self, jdata: dict, jdata_parent: dict = {}) -> dict: jdata["SEL"] = (jdata["sel"] + [0, 0, 0, 0])[0:4] for s in jdata["sel"]: if s > self.max_nnei: - log.error("The sel cannot be greater than the max_nnei") + log.error( + f"The sel ({jdata['sel']}) cannot be greater than the max_nnei ({self.max_nnei})" + ) exit(1) jdata["NNODE_FEAS"] = [1] + jdata["neuron"] jdata["nlayer_fea"] = len(jdata["neuron"]) @@ -193,20 +200,51 @@ def init_dscp(self, jdata: dict, jdata_parent: dict = {}) -> dict: jdata["ntype"] = len(jdata["sel"]) jdata["ntypex"] = 1 if (jdata["same_net"]) else jdata["ntype"] if self.version == 1: - # embedding jdata["M1"] = jdata["neuron"][-1] jdata["M2"] = jdata["axis_neuron"] + # embedding jdata["SEL"] = jdata["sel"] if jdata["sel"] > self.max_nnei: - log.error("The sel cannot be greater than the max_nnei") + log.error( + f"The sel ({jdata['sel']}) cannot be greater than the max_nnei 
({self.max_nnei})" + ) exit(1) jdata["NNODE_FEAS"] = [1] + jdata["neuron"] jdata["nlayer_fea"] = len(jdata["neuron"]) jdata["same_net"] = 1 if jdata["type_one_side"] else 0 # neighbor - jdata["NI"] = self.max_nnei + jdata["NI"] = jdata["sel"] jdata["NIDP"] = int(jdata["sel"]) jdata["NIX"] = 2 ** int(np.ceil(np.log2(jdata["NIDP"] / 1.5))) + if jdata["sel"] <= 128: + if self.device == "vu13p": + jdata["NSTEP"] = 0 + else: + jdata["NSTEP"] = 0 + elif 128 < jdata["sel"] <= 160: + if self.device == "vu13p": + jdata["NSTEP"] = 8 + else: + jdata["NSTEP"] = 16 + # jdata["NSTEP"] = jdata["NI"]/2 - self.config["ctrl"]["NSTDM"] + elif 160 < jdata["sel"] <= 192: + if self.device == "vu13p": + jdata["NSTEP"] = 16 + else: + jdata["NSTEP"] = 32 + elif 192 < jdata["sel"] <= 224: + if self.device == "vu13p": + jdata["NSTEP"] = 24 + else: + jdata["NSTEP"] = 48 + elif 224 < jdata["sel"] <= 256: + if self.device == "vu13p": + jdata["NSTEP"] = 32 + else: + jdata["NSTEP"] = 64 + if jdata["sel"] > 256: + log.error(f"The sel ({jdata['sel']}) should be less than 256") + exit(1) # type jdata["ntype"] = jdata["ntype"] return jdata @@ -245,6 +283,16 @@ def init_ctrl(self, jdata: dict, jdata_parent: dict = {}) -> dict: jdata["NSEL"] = jdata["NSTDM"] * ntype_max jdata["VERSION"] = 0 if self.version == 1: + if self.device == "vu13p": + jdata["NSTDM"] = 32 + jdata["NSTDM_M1"] = jdata["NSTDM"] // 2 + jdata["NSTDM_M2"] = 2 + jdata["MAX_NNEI"] = 256 + elif self.device == "vu9p": + jdata["NSTDM"] = 64 + jdata["NSTDM_M1"] = jdata["NSTDM"] // 2 + jdata["NSTDM_M2"] = 2 + jdata["MAX_NNEI"] = 256 jdata["NSADV"] = jdata["NSTDM"] + 1 jdata["NSEL"] = jdata["NSTDM"] jdata["VERSION"] = 1 @@ -269,22 +317,25 @@ def init_nbit(self, jdata: dict, jdata_parent: dict = {}) -> dict: jdata["NBIT_SEL"] = int(np.ceil(np.log2(NSEL))) return jdata - def save(self, file_name=None) -> None: + def save(self, file_name=None): r"""Save all configuration to file.""" if file_name is None: file_name = self.save_path else: self.save_path = file_name self.update_config() + # fix debug config_file not correspond + # load_config = FioDic().load(self.config_file, self.config) + # self.init_from_config(load_config) FioDic().save(file_name, self.config) - def set_ntype(self, ntype) -> None: + def set_ntype(self, ntype): r"""Set the number of type.""" self.dscp["ntype"] = ntype self.config["dscp"]["ntype"] = ntype nvnmd_cfg.save() - def get_s_range(self, davg, dstd) -> None: + def get_s_range(self, davg, dstd): r"""Get the range of switch function.""" rmin = nvnmd_cfg.dscp["rcut_smth"] rmax = nvnmd_cfg.dscp["rcut"] @@ -301,8 +352,8 @@ def get_s_range(self, davg, dstd) -> None: nvnmd_cfg.save() # check log.info(f"the range of s is [{smin}, {smax}]") - if smax - smin > 16.0: - log.warning("the range of s is over the limit (smax - smin) > 16.0") + if smax - smin > 32.0: + log.warning("the range of s is over the limit (smax - smin) > 32.0") log.warning( "Please reset the rcut_smth as a bigger value to fix this warning" ) @@ -311,10 +362,26 @@ def get_dscp_jdata(self): r"""Generate `model/descriptor` in input script.""" dscp = self.dscp jdata = self.jdata_deepmd_input["model"]["descriptor"] - jdata["sel"] = dscp["sel"] + if self.version == 0: + jdata["sel"] = dscp["sel"] + if self.version == 1: + if dscp["sel"] <= 128: + jdata["sel"] = 128 + elif 128 < dscp["sel"] <= 160: + jdata["sel"] = 160 + elif 160 < dscp["sel"] <= 192: + jdata["sel"] = 192 + elif 192 < dscp["sel"] <= 224: + jdata["sel"] = 224 + elif 224 < dscp["sel"] <= 256: + jdata["sel"] = 256 + else: + 
log.error(f"The input sel ({dscp['sel']!s}) should be less than 256") + exit(1) jdata["rcut"] = dscp["rcut"] jdata["rcut_smth"] = dscp["rcut_smth"] jdata["neuron"] = dscp["neuron"] + jdata["seed"] = dscp["seed"] jdata["type_one_side"] = dscp["type_one_side"] jdata["axis_neuron"] = dscp["axis_neuron"] return jdata @@ -324,13 +391,16 @@ def get_fitn_jdata(self): fitn = self.fitn jdata = self.jdata_deepmd_input["model"]["fitting_net"] jdata["neuron"] = fitn["neuron"] + jdata["seed"] = fitn["seed"] return jdata def get_model_jdata(self): r"""Generate `model` in input script.""" jdata = self.jdata_deepmd_input["model"] + log.info(jdata) jdata["descriptor"] = self.get_dscp_jdata() jdata["fitting_net"] = self.get_fitn_jdata() + log.info(jdata) if len(self.dpin["type_map"]) > 0: jdata["type_map"] = self.dpin["type_map"] return jdata @@ -339,6 +409,7 @@ def get_nvnmd_jdata(self): r"""Generate `nvnmd` in input script.""" jdata = self.jdata_deepmd_input["nvnmd"] jdata["net_size"] = self.net_size + jdata["device"] = self.device jdata["max_nnei"] = self.max_nnei jdata["config_file"] = self.config_file jdata["weight_file"] = self.weight_file @@ -364,7 +435,7 @@ def get_training_jdata(self): def get_deepmd_jdata(self): r"""Generate input script with member element one by one.""" - jdata = self.jdata_deepmd_input.copy() + jdata = copy.deepcopy(self.jdata_deepmd_input) jdata["model"] = self.get_model_jdata() jdata["nvnmd"] = self.get_nvnmd_jdata() jdata["learning_rate"] = self.get_learning_rate_jdata() @@ -380,7 +451,7 @@ def get_dp_init_weights(self): dic[key2] = self.weight[key] return dic - def disp_message(self) -> None: + def disp_message(self): r"""Display the log of NVNMD.""" NVNMD_CONFIG = ( f"enable: {self.enable}", diff --git a/deepmd/tf/utils/type_embed.py b/deepmd/tf/utils/type_embed.py index 9b7b17528d..54e07eebdb 100644 --- a/deepmd/tf/utils/type_embed.py +++ b/deepmd/tf/utils/type_embed.py @@ -187,6 +187,8 @@ def build( ) ebd_type = tf.reshape(ebd_type, [ntypes, -1]) name = "type_embed_net" + suffix + if nvnmd_cfg.enable: + self.use_tebd_bias = True if ( nvnmd_cfg.enable and (nvnmd_cfg.version == 1) diff --git a/deepmd/utils/argcheck_nvnmd.py b/deepmd/utils/argcheck_nvnmd.py index da7252f3f7..97b17bef96 100644 --- a/deepmd/utils/argcheck_nvnmd.py +++ b/deepmd/utils/argcheck_nvnmd.py @@ -9,6 +9,7 @@ def nvnmd_args(fold_subdoc: bool = False) -> Argument: "configuration the nvnmd version (0 | 1), 0 for 4 types, 1 for 32 types" ) doc_max_nnei = "configuration the max number of neighbors, 128|256 for version 0, 128 for version 1" + doc_device = "hardware used by model, vu9p or vu13p" doc_net_size_file = ( "configuration the number of nodes of fitting_net, just can be set as 128" ) @@ -26,6 +27,7 @@ def nvnmd_args(fold_subdoc: bool = False) -> Argument: doc_quantize_fitting_net = "enable the quantizatioin of fitting_net" args = [ Argument("version", int, optional=False, default=0, doc=doc_version), + Argument("device", str, optional=False, default="none", doc=doc_device), Argument("max_nnei", int, optional=False, default=128, doc=doc_max_nnei), Argument("net_size", int, optional=False, default=128, doc=doc_net_size_file), Argument("map_file", str, optional=False, default="none", doc=doc_map_file), diff --git a/source/tests/tf/nvnmd/config.npy b/source/tests/tf/nvnmd/config.npy new file mode 100644 index 0000000000..72e66a795d Binary files /dev/null and b/source/tests/tf/nvnmd/config.npy differ diff --git a/source/tests/tf/nvnmd/out/map_v0_cnn_part_1.npy 
b/source/tests/tf/nvnmd/out/map_v0_cnn_part_1.npy new file mode 100644 index 0000000000..bffa6755df Binary files /dev/null and b/source/tests/tf/nvnmd/out/map_v0_cnn_part_1.npy differ diff --git a/source/tests/tf/nvnmd/out/map_v0_cnn_part_2.npy b/source/tests/tf/nvnmd/out/map_v0_cnn_part_2.npy new file mode 100644 index 0000000000..8700ed9716 Binary files /dev/null and b/source/tests/tf/nvnmd/out/map_v0_cnn_part_2.npy differ diff --git a/source/tests/tf/nvnmd/out/map_v0_cnn_part_3.npy b/source/tests/tf/nvnmd/out/map_v0_cnn_part_3.npy new file mode 100644 index 0000000000..c6242c444f Binary files /dev/null and b/source/tests/tf/nvnmd/out/map_v0_cnn_part_3.npy differ diff --git a/source/tests/tf/nvnmd/out/map_v1_cnn_part_1.npy b/source/tests/tf/nvnmd/out/map_v1_cnn_part_1.npy new file mode 100644 index 0000000000..04e8f2a3a4 Binary files /dev/null and b/source/tests/tf/nvnmd/out/map_v1_cnn_part_1.npy differ diff --git a/source/tests/tf/nvnmd/out/map_v1_cnn_part_2.npy b/source/tests/tf/nvnmd/out/map_v1_cnn_part_2.npy new file mode 100644 index 0000000000..147169c6c5 Binary files /dev/null and b/source/tests/tf/nvnmd/out/map_v1_cnn_part_2.npy differ diff --git a/source/tests/tf/nvnmd/out/weight_v1_qnn.npy b/source/tests/tf/nvnmd/out/weight_v1_qnn.npy new file mode 100644 index 0000000000..e10c483b6f Binary files /dev/null and b/source/tests/tf/nvnmd/out/weight_v1_qnn.npy differ diff --git a/source/tests/tf/nvnmd/ref/config_v1_cnn.npy b/source/tests/tf/nvnmd/ref/config_v1_cnn.npy index f50cec9c98..72e66a795d 100644 Binary files a/source/tests/tf/nvnmd/ref/config_v1_cnn.npy and b/source/tests/tf/nvnmd/ref/config_v1_cnn.npy differ diff --git a/source/tests/tf/test_nvnmd_entrypoints.py b/source/tests/tf/test_nvnmd_entrypoints.py index eaf8bfafd5..c81de7477f 100644 --- a/source/tests/tf/test_nvnmd_entrypoints.py +++ b/source/tests/tf/test_nvnmd_entrypoints.py @@ -51,6 +51,16 @@ class TestNvnmdEntrypointsV0(tf.test.TestCase): def test_mapt_cnn_v0(self) -> None: config_file = str(tests_path / "nvnmd" / "ref" / "config_v0_cnn.npy") weight_file = str(tests_path / "nvnmd" / "ref" / "weight_v0_cnn.npy") + output_filename = f"{tests_path}/nvnmd/out/map_v0_cnn.npy" + parts = [ + f"{tests_path}/nvnmd/out/map_v0_cnn_part_1.npy", + f"{tests_path}/nvnmd/out/map_v0_cnn_part_2.npy", + f"{tests_path}/nvnmd/out/map_v0_cnn_part_3.npy", + ] + with open(output_filename, "wb") as output_file: + for part_filename in parts: + with open(part_filename, "rb") as part_file: + output_file.write(part_file.read()) map_file = str(tests_path / "nvnmd" / "out" / "map_v0_cnn.npy") # mapt mapObj = MapTable(config_file, weight_file, map_file) @@ -61,7 +71,7 @@ def test_mapt_cnn_v0(self) -> None: pred = mapObj.mapping2(x, {"s": mapt["s"]}, mapt["cfg_u2s"]) pred = np.reshape(pred["s"], [-1]) ref_dout = [ - -0.36629248, + 0, 11.73139954, 7.64562607, 5.61323166, @@ -93,7 +103,7 @@ def test_mapt_cnn_v0(self) -> None: -0.36629248, -0.36629248, -0.36629248, - -0.37758207, + 0, 12.93425751, 8.43843079, 6.2020607, @@ -472,20 +482,20 @@ def test_model_qnn_v0(self) -> None: idx = np.array([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]) pred = np.reshape(valuedic["o_descriptor"], [-1]) ref_dout = [ - 0.00614393, - -0.00593019, - 0.00424719, - 0.0053246, - 0.03973877, - 0.00422275, - 0.0081141, - 0.01380706, - 0.04038167, - 0.01963985, - 0.01557279, - 0.00587749, - 0.03684294, - 0.02304173, + 0.06064451, + -0.03320849, + 0.02500141, + 0.00765049, + 0.0825026, + 0.01446712, + 0.10439038, + 0.26883912, + 0.13117445, + 0.19348001, 
+ 0.21398795, + 0.18065619, + 0.32556486, + 0.1505754, ] np.testing.assert_almost_equal(pred[idx], ref_dout, 8) # o_rmat @@ -509,7 +519,7 @@ def test_model_qnn_v0(self) -> None: np.testing.assert_almost_equal(pred[idx], ref_dout, 8) # o_energy pred = valuedic["o_energy"] - ref_dout = -62.60181403 + ref_dout = -56.20791733 np.testing.assert_almost_equal(pred, ref_dout, 8) def tearDown(self) -> None: @@ -522,10 +532,19 @@ class TestNvnmdEntrypointsV1(tf.test.TestCase): def test_mapt_cnn_v1(self) -> None: config_file = str(tests_path / "nvnmd" / "ref" / "config_v1_cnn.npy") weight_file = str(tests_path / "nvnmd" / "ref" / "weight_v1_cnn.npy") + output_filename = f"{tests_path}/nvnmd/out/map_v1_cnn.npy" + parts = [ + f"{tests_path}/nvnmd/out/map_v1_cnn_part_1.npy", + f"{tests_path}/nvnmd/out/map_v1_cnn_part_2.npy", + ] + with open(output_filename, "wb") as output_file: + for part_filename in parts: + with open(part_filename, "rb") as part_file: + output_file.write(part_file.read()) map_file = str(tests_path / "nvnmd" / "out" / "map_v1_cnn.npy") # mapt mapObj = MapTable(config_file, weight_file, map_file) - mapObj.Gs_Gt_mode = 0 + mapObj.Gs_Gt_mode = 2 mapt = mapObj.build_map() # N = 32 @@ -573,134 +592,134 @@ def test_mapt_cnn_v1(self) -> None: pred = mapObj.mapping2(x, {"g": mapt["g"]}, mapt["cfg_s2g"]) pred = np.reshape(pred["g"], [-1]) ref_dout = [ - -1.0770483, - 0.02810931, - 1.62244892, - 1.1394949, - -1.37506485, - 1.87449265, - 1.45126152, - 2.25518417, - 1.89931679, - -0.38976216, - 0.36672592, - -0.11820012, - 0.71667051, - 0.3249805, - 0.02166232, - 2.50204468, - -0.55733442, - -1.39325333, - 1.50640583, - 0.78623056, - -1.35564613, - 1.54273891, - -0.34428048, - 3.64800453, - 2.21231842, - -0.76567078, - -0.49742508, - 1.26832676, - 0.60595608, - 0.03836584, - 1.15446472, - 1.80911732, - -0.97592735, - 0.13686812, - 1.54618168, - 0.59849787, - -0.47609806, - 1.73630333, - 2.79491806, - 2.56436539, - 1.61628342, - -0.10759199, - -0.38582754, - -0.38886118, - 1.72347736, - 0.29652929, - 1.62401485, - 2.978302, - -0.64212656, - -1.15452957, - 1.23154068, - 1.07660294, - 0.04454666, - 1.29101658, - 1.15926933, - 3.78246689, - 2.15491486, - -0.55593348, - -0.87728548, - 0.87736368, - 1.34770393, - -0.12231946, - 2.63235283, - 2.44036293, - -0.54117918, - 1.20576477, - 1.79139519, - -0.1107251, - 0.19127345, - 1.56829548, - 3.01862144, - 2.87240219, - 1.26001167, - 0.92324543, - -0.9365592, - -0.41182208, - 2.55228806, - 0.55107069, - 2.24660492, - 3.25927734, - -0.65150642, - -0.01702949, - 1.26624203, - 0.9379921, - 1.49566174, - 1.32353592, - 1.52009106, - 3.78927612, - 2.0195179, - 0.921103, - -1.12800884, - 0.7468524, - 2.01961136, - 0.12407303, - 2.90290642, - 3.03822327, - -0.40549946, - 1.8200779, - 1.85752392, - -0.30378056, - 0.28259802, - 1.466959, - 2.96389198, - 2.99284935, - 1.08252144, - 1.4072113, - -1.21003246, - -0.34711146, - 2.75219727, - 0.66153288, - 2.36572838, - 3.28229713, - -0.75278902, - 0.53803587, - 1.28469372, - 0.904603, - 1.79652405, - 1.37996578, - 1.48132515, - 3.78080368, - 1.90426636, - 1.70287418, - -1.30566597, - 0.78082895, - 2.23096085, - 0.25823808, - 2.89975548, - 3.24079514, + -2.07704735, + -0.97189045, + 0.62244892, + 0.13949502, + -2.37506485, + 0.87449265, + 0.451262, + 1.25518513, + 0.89931679, + -1.38976192, + -0.6332736, + -1.11819935, + -0.28332901, + -0.67501926, + -0.97833729, + 1.50204468, + -1.55733395, + -2.39325333, + 0.50640631, + -0.21376932, + -2.35564613, + 0.54273939, + -1.34428024, + 2.64800453, + 1.21231842, + -1.76567078, + 
-1.49742508, + 0.26832676, + -0.39404368, + -0.96163368, + 0.15446472, + 0.80911779, + -1.97592735, + -0.863132, + 0.54618216, + -0.40150189, + -1.47609806, + 0.73630381, + 1.79491806, + 1.56436443, + 0.61628389, + -1.10759163, + -1.38582611, + -1.38885975, + 0.72347689, + -0.70347071, + 0.62401438, + 1.978302, + -1.64212608, + -2.15452957, + 0.2315408, + 0.07660317, + -0.9554534, + 0.29101682, + 0.15927005, + 2.78246689, + 1.15491581, + -1.55593395, + -1.877285, + -0.1226362, + 0.34770393, + -1.12232018, + 1.63235283, + 1.44036198, + -1.54117966, + 0.20576489, + 0.79139519, + -1.11072445, + -0.80872631, + 0.56829643, + 2.01862144, + 1.87240219, + 0.26001167, + -0.07675433, + -1.93655872, + -1.41182232, + 1.5522871, + -0.44892907, + 1.24660492, + 2.25927734, + -1.65150547, + -1.01702976, + 0.26624179, + -0.06200758, + 0.49566174, + 0.32353592, + 0.52009106, + 2.78927612, + 1.01951885, + -0.07889676, + -2.12800598, + -0.25314736, + 1.01961136, + -0.87592697, + 1.90290642, + 2.03822136, + -1.4054985, + 0.8200779, + 0.85752344, + -1.3037796, + -0.71740198, + 0.46695948, + 1.96389294, + 1.9928503, + 0.08252192, + 0.40721178, + -2.2100296, + -1.3471117, + 1.75219727, + -0.33846688, + 1.36572742, + 2.28229713, + -1.75278759, + -0.46196413, + 0.28469372, + -0.09539658, + 0.79652452, + 0.3799665, + 0.48132539, + 2.78080368, + 0.90426683, + 0.70287418, + -2.30566406, + -0.21917057, + 1.23096085, + -0.74176168, + 1.89975643, + 2.24079514, ] np.testing.assert_almost_equal(pred, ref_dout, 8) @@ -788,20 +807,20 @@ def test_model_qnn_v1(self) -> None: idx = np.array([1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]) pred = np.reshape(valuedic["o_descriptor"], [-1]) ref_dout = [ - -0.03495526, - 0.19181037, - 0.00139165, - 0.03920531, - 0.3982904, - 0.05152893, - 0.28467178, - 1.33868217, - 0.83964777, - 1.16189384, - 1.16278744, - 0.93079185, - 1.37950325, - 0.90435696, + 0.06811976, + -0.05252528, + 0.0639708, + 0.04032826, + 0.09842145, + 0.06800854, + 0.06667721, + 0.09088826, + 0.05695295, + 0.106408, + 0.06601942, + 0.04956532, + 0.05257654, + 0.07719779, ] np.testing.assert_almost_equal(pred[idx], ref_dout, 8) # o_rmat @@ -825,7 +844,7 @@ def test_model_qnn_v1(self) -> None: np.testing.assert_almost_equal(pred[idx], ref_dout, 8) # o_energy pred = valuedic["o_energy"] - ref_dout = 60.73941362 + ref_dout = [-13.23571336] np.testing.assert_almost_equal(pred, ref_dout, 8) # test freeze sess = self.cached_session().__enter__() @@ -872,7 +891,7 @@ def test_wrap_qnn_v1(self) -> None: idx = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384] idx = [i + 128 * 4 for i in idx] pred = [data[i] for i in idx] - red_dout = [249, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 254, 95, 24, 176] + red_dout = [1, 242, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 254, 95, 24, 176] np.testing.assert_equal(pred, red_dout) def tearDown(self) -> None:
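
Note on the recovered switch used above: the se_atten.py and mapt.py hunks add a switch value k(s) (built in `build_u2s` as `kk = 1 - rmin * s; k = -kk**3 + 1`, clipped to [0, 1]) and change the two-side combination to `Gs_Gt_mode == 2`, i.e. `G = Gs * (Gt * k) + Gs`. A minimal NumPy sketch of that arithmetic follows; `switch_k` and `combine_two_side` are hypothetical helper names, not deepmd API, and the toy shapes stand in for the real (nloc*nnei, M1) tensors.

    import numpy as np

    def switch_k(s, rmin):
        # k(s) = 1 - (1 - rmin*s)^3, clipped to [0, 1]; mirrors the build_u2s hunk.
        kk = 1.0 - rmin * s
        return np.clip(1.0 - kk ** 3, 0.0, 1.0)

    def combine_two_side(Gs, Gt, k):
        # Gs_Gt_mode == 2: G = Gs * (Gt * k) + Gs, so G falls back to Gs as k -> 0.
        return Gs * (Gt * k[:, None]) + Gs

    s = np.linspace(0.0, 2.0, 5)        # toy switch-function inputs
    Gs = np.ones((5, 4))                # toy single-side embedding, M1 = 4
    Gt = np.full((5, 4), 0.5)           # toy two-side (type) embedding
    G = combine_two_side(Gs, Gt, switch_k(s, rmin=0.5))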
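The config.py hunk derives `NSTEP` from `sel` and the new `device` field (`vu9p` or `vu13p`). The same table, written as a hypothetical standalone helper (`nstep_for` is illustrative only) to make the block-of-32 pattern explicit:

    def nstep_for(sel: int, device: str = "vu9p") -> int:
        # Mirrors the init_dscp hunk: no extra steps up to 128 neighbors, then one
        # block per additional 32 neighbors; vu13p needs half the steps of vu9p.
        if sel > 256:
            raise ValueError(f"The sel ({sel}) should be less than 256")
        if sel <= 128:
            return 0
        blocks = -(-(sel - 128) // 32)  # ceil((sel - 128) / 32)
        return blocks * (8 if device == "vu13p" else 16)

    assert [nstep_for(n) for n in (128, 160, 192, 224, 256)] == [0, 16, 32, 48, 64]
    assert [nstep_for(n, "vu13p") for n in (160, 256)] == [8, 32]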
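Several hunks in data.py, config.py, and the entrypoints swap `dict.copy()` for `copy.deepcopy()` on the nested config templates (`jdata_config_*`, `jdata_deepmd_input_*`, `jdata_cmd_*`). A minimal sketch of the aliasing hazard a shallow copy leaves behind; `nested_template` is an illustrative name, not a deepmd symbol.

    import copy

    # A shallow copy shares the nested dicts, so an in-place edit such as
    # jdata_config_v1_ni256["nbit"]["NBIT_NEIB"] = 9 would leak into the ni128 template.
    nested_template = {"nbit": {"NBIT_NEIB": 8}}

    shallow = nested_template.copy()
    deep = copy.deepcopy(nested_template)

    shallow["nbit"]["NBIT_NEIB"] = 9
    assert nested_template["nbit"]["NBIT_NEIB"] == 9  # mutated through the shared inner dict
    assert deep["nbit"]["NBIT_NEIB"] == 8             # deep copy stays independent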