目录

mMEGNet

misaraty 更新 | 2026-02-28
前言
下载:mMEGNet

modified MEGNet(mMEGNet)

该脚本实现了一个基于MEGNet图神经网络的预测流程,用于预测晶体材料的超导临界温度Tc。程序从data.xlsx读取结构ID和对应的Tc标签,从cif目录加载相应结构文件,并可选择加入元素电负性均值和标准差作为全局特征。数据可采用train/test划分或KFold交叉验证方式进行训练,并支持EarlyStopping和学习率调度。脚本输出RMSEMAER2指标,保存预测散点图;在megnet_tuned模式下,使用Optuna优化网络宽度和学习率,并自动保存每次trialfold的日志及模型文件。

使用方法

运行命令:python mMEGNet_v14.py

  • TARGET_COL:预测目标列名称(例如'tc'表示超导临界温度)。

  • method:训练模式选择('megnet_default' 为固定结构,'megnet_tuned' 为使用Optuna进行超参数优化)。

  • batch_size:每次梯度更新的样本数,影响收敛稳定性和显存占用。

  • lr:初始学习率,决定优化步长。

  • train_ratio / test_ratio:训练集和测试集划分比例。

  • USE_EN_GLOBAL:是否加入基于电负性的全局状态特征。

  • n1, n2, n3:控制MEGNet网络宽度(模型容量);数值越大模型越复杂,但可能增加过拟合风险。

  • epochs:最大训练轮数。

  • n_folds:交叉验证折数,用于控制模型评估的稳健性(推荐5,或 'none' 表示单次划分)。

[!NOTE] 请将 ~\anaconda3\Lib\site-packages\megnet\data\graph.py 中的 get_atom_features 函数替换为以下实现:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
    @staticmethod
    def get_atom_features(structure) -> List[Any]:
        z_list = []
        for site in structure:
            try:
                if getattr(site, "is_ordered", True):
                    z = int(site.specie.Z)
                else:
                    items = list(site.species.items())  # [(Element, occ), ...]
                    items.sort(key=lambda kv: (float(kv[1]), getattr(kv[0], "Z", 0)), reverse=True)
                    z = int(getattr(items[0][0], "Z", 0))
            except Exception:
                tok = str(site.species_string).split()[0]
                try:
                    from pymatgen.core.periodic_table import Element
                    z = int(Element(tok).Z)
                except Exception:
                    z = 0
            z_list.append(z)
        return np.array(z_list, dtype="int32").tolist()

该修改通过在分数占位(无序)结构中选择占据比例最高的元素,避免在读取包含部分占位原子的CIF文件时出现报错。

引用

论文正式发表后补充。