mMEGNet
modified MEGNet(mMEGNet)
该脚本实现了一个基于MEGNet图神经网络的预测流程,用于预测晶体材料的超导临界温度Tc。程序从data.xlsx读取结构ID和对应的Tc标签,从cif目录加载相应结构文件,并可选择加入元素电负性均值和标准差作为全局特征。数据可采用train/test划分或KFold交叉验证方式进行训练,并支持EarlyStopping和学习率调度。脚本输出RMSE、MAE和R2指标,保存预测散点图;在megnet_tuned模式下,使用Optuna优化网络宽度和学习率,并自动保存每次trial和fold的日志及模型文件。
使用方法
运行命令:python mMEGNet_v14.py。
-
TARGET_COL:预测目标列名称(例如'tc'表示超导临界温度)。 -
method:训练模式选择('megnet_default'为固定结构,'megnet_tuned'为使用Optuna进行超参数优化)。 -
batch_size:每次梯度更新的样本数,影响收敛稳定性和显存占用。 -
lr:初始学习率,决定优化步长。 -
train_ratio/test_ratio:训练集和测试集划分比例。 -
USE_EN_GLOBAL:是否加入基于电负性的全局状态特征。 -
n1, n2, n3:控制MEGNet网络宽度(模型容量);数值越大模型越复杂,但可能增加过拟合风险。 -
epochs:最大训练轮数。 -
n_folds:交叉验证折数,用于控制模型评估的稳健性(推荐5,或'none'表示单次划分)。
[!NOTE] 请将
~\anaconda3\Lib\site-packages\megnet\data\graph.py中的get_atom_features函数替换为以下实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20@staticmethod def get_atom_features(structure) -> List[Any]: z_list = [] for site in structure: try: if getattr(site, "is_ordered", True): z = int(site.specie.Z) else: items = list(site.species.items()) # [(Element, occ), ...] items.sort(key=lambda kv: (float(kv[1]), getattr(kv[0], "Z", 0)), reverse=True) z = int(getattr(items[0][0], "Z", 0)) except Exception: tok = str(site.species_string).split()[0] try: from pymatgen.core.periodic_table import Element z = int(Element(tok).Z) except Exception: z = 0 z_list.append(z) return np.array(z_list, dtype="int32").tolist()该修改通过在分数占位(无序)结构中选择占据比例最高的元素,避免在读取包含部分占位原子的
CIF文件时出现报错。
引用
论文正式发表后补充。