Bridge619

Bridge619

Bridge619

命定的局限尽可永在,不屈的挑战却不可须臾或缺!

100 文章数
11 评论数
来首音乐
光阴似箭
今日已经过去小时
这周已经过去
本月已经过去
今年已经过去个月

TransE知识图谱补全

Bridge619
2022-10-30 / 0 评论 / 1532 阅读 / 0 点赞

1. 使用 TransE 模型对如下知识图谱进行补全。TransE原论文

说明:

(1) 图中实线部分为知识图谱的已知关系;

(2) 需要补全图中的虚线对应的三元组,即( 浙江大学 位于市 杭州市)。

1.1 目的

知识图谱补全是从已知的知识图谱中提取出三元组(h,r,t),为实体和关系进行建模,通过训练出的模型进行链接预测,以达成知识图谱补全的目标。

利用训练集进行transE建模,通过训练为每个实体和关系建立起向量映射,并在测试集中计算MeanRank和Hit10指标进行结果检验。

1.2 构建数据集

分为以下四个文件:

  • 构建实体id分配数据集entity2id.txt

    • 上图中共 个实体,为每一个实体分配一个 ,从 ~
    • 文件中每一行表示一个实体和对应的
    中国	    0
    西湖区	    1
    南京市	    2
    鼓楼区	    3
    南京大学  4
    海淀区     5
    朝阳区	    6
    江苏省	    7
    省份	    8
    清华大学  9
    浙江大学  10
    杭州市     11
    浙江省	    12
    大学    13
    北京市     14
    
  • 构建关系id分配数据集relation2id.txt​

    • 上图中共包含 个关系,为每一个关系分配一个 ,从 ~
    • 文件中每一行表示一个关系和对应的
    包含区	    0
    位于国家	1
    位于市     2
    类别是	    3
    包含单位	4
    临近省份	5
    位于省	    6
    
  • 构建训练数据集train.txt

    • 包含23个三元组,如下所示
    • 每一行表示⼀个三元组, 按头实体、关 系、尾实体顺序(这个顺序其实可以任意,在代码中也可以调整,但是为理解方便起见,一般是该顺序)
    • 另外需要注意的一点是对于本题所用代码,需要保证实体、关 系、尾实体的对齐
    杭州市	包含区	西湖区
    杭州市	位于省	浙江省
    浙江大学	位于省	浙江省
    浙江大学	位于国家	中国
    浙江大学	类别是	大学
    清华大学	类别是	大学
    清华大学	位于市	北京市
    海淀区	包含单位	清华大学
    西湖区	包含单位	浙江大学
    西湖区	位于省	浙江省
    北京市	位于国家	中国
    北京市	包含区	海淀区
    北京市	包含区	朝阳区
    南京市	位于省	江苏省
    南京市	包含区	鼓楼区
    南京市	位于国家	中国
    江苏省	类别是	省份
    浙江省	类别是	省份
    浙江省	临近省份	江苏省
    浙江省	位于国家	中国
    南京大学	位于省	江苏省
    南京大学	位于市	南京市
    鼓楼区	包含单位	南京大学
    
  • 构建测试数据集test.txt

    • 包含1个三元组
    浙江⼤学 位于市 杭州市
    

1.3 TransE算法及原理

1.3.1 原理

TransE将起始实体,关系,指向实体映射成同一空间的向量,如果(head,relation,tail)存在,那么

目标函数为:

1.3.2 算法

伪代码的意思是:

input: 输入模型的参数是训练集的三元组,实体集E,关系集L,margin,向量的维度k

1:初始化: 对于关系按照1的初始化方式初始化即可

2:这里进行了L2范数归一化,也就是除以自身的L2范数

3:同理,也对实体进行了初始化,但是这里没有除以自身的L2范数

4:训练的循环过程中:

5:首先对实体进行了L2范数归一化

6:取一个batch的样本,这里Sbatch代表的是正样本,也就是正确的三元组

7: 初始化三元组对,应该就是创造一个用于储存的列表

8,9,10:这里的意思应该是根据Sbatch的正样本替换头实体或者尾实体构造负样本,然后把对应的正样本三元组和负样本三元组放到一起,组成Tbatch

11:完成正负样本的提取

12:根据梯度下降更新向量

13:结束循环

(1)初始化

根据维度,为每个实体和关系初始化向量,并归一化

def emb_initialize(self):
        relation_dict = {}
        entity_dict = {}

        for relation in self.relation:
            r_emb_temp = np.random.uniform(-6 / math.sqrt(self.embedding_dim),
                                           6 / math.sqrt(self.embedding_dim),
                                           self.embedding_dim)
            relation_dict[relation] = r_emb_temp / np.linalg.norm(r_emb_temp, ord=2)

        for entity in self.entity:
            e_emb_temp = np.random.uniform(-6 / math.sqrt(self.embedding_dim),
                                           6 / math.sqrt(self.embedding_dim),
                                           self.embedding_dim)
            entity_dict[entity] = e_emb_temp / np.linalg.norm(e_emb_temp, ord=2)

(2)选取batch

设置nbatches为batch数目,batch_size = len(self.triple_list) // nbatches

从训练集中随机选择batch_size个三元组,并随机构成一个错误的三元组S',进行更新

def train(self, epochs):
        nbatches = 2
        batch_size = len(self.triple_list) // nbatches
        print("batch size: ", batch_size)
        for epoch in range(epochs):
            start = time.time()
            self.loss = 0

            # Sbatch:list
            Sbatch = random.sample(self.triple_list, batch_size)
            Tbatch = []

            for triple in Sbatch:
                corrupted_triple = self.Corrupt(triple)
                if (triple, corrupted_triple) not in Tbatch:
                    Tbatch.append((triple, corrupted_triple))
            self.update_embeddings(Tbatch)

(3)梯度下降

定义距离 来表示两个向量之间的距离,一般情况下,我们会取 ,或者

在这里,我们需要定义一个距离,对于正确的三元组 ,距离 越小越好;对于错误的三元组 ,距离 越小越好。

之后,使用梯度下降进行更新

1.3.3 训练结果

选择迭代次数 次,向量维度 ,学习率 进行训练(由于这个数据集太小,所以迭代次数不用设太大)

模型的参数保存在entity_50dim1relation_50dim1

1.3.4 链接预测

通过transE建模后,得到了每个实体关系的嵌入向量,利用嵌入向量,可以进行知识图谱的链接预测

将三元组(head,relation,tail)记为(h,r,t)

链接预测分为三类

  1. 头实体预测:(?,r,t)
  2. 关系预测:(h,?,t)
  3. 尾实体预测:(h,r,?)

原理很简单,利用向量的可加性即可实现。以 (h,r,?) 的预测为例:

假设t'=h+r,则在所有的实体中选择与t'距离最近的向量,即为t的的预测值

1.3.5 链接预测评价指标

1. MR :Mean rank 所有预测样本的平均排名

对于测试集的每个三元组,以预测tail实体为例,将 (h,r,t) 中的t用知识图谱中的每个实体来代替,然后通过distance(h, r, t)函数来计算距离,这样可以得到一系列的距离,之后按照升序将这些分数排列。

distance(h, r, t)函数值是越小越好,那么在上个排列中,排的越前越好。

然后去看每个三元组中正确答案也就是真实的t到底能在上述序列中排多少位,比如说t1排100,t2排200,t3排60.......,之后对这些排名求平均,mean rank就得到了。

2.Hit@10 所有预测样本中排名在10以内的比例

还是按照上述进行函数值排列,然后去看每个三元组正确答案是否排在序列的前十,如果在的话就计数+1

最终排在前十的个数/总个数 就是Hit@10

1.4 代码实现

1.4.1 模型训练

transE.py

import codecs
import copy
import math
import random
import time

import numpy as np

entity2id = {}
relation2id = {}
loss_ls = []

def data_loader(file):
    file1 = file + "train.txt"
    file2 = file + "entity2id.txt"
    file3 = file + "relation2id.txt"
    
    # 需要注意实体及关系名中是否包含中文
    with open(file2, 'r',encoding='UTF-8') as f1, open(file3, 'r',encoding='UTF-8') as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()
        for line in lines1:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            entity2id[line[0]] = line[1]

        for line in lines2:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            relation2id[line[0]] = line[1]

    entity_set = set()
    relation_set = set()
    triple_list = []

    with codecs.open(file1, 'r',encoding='UTF-8') as f:
        content = f.readlines()
        for line in content:
            triple = line.strip().split("\t")
            if len(triple) != 3:
                continue

            # 根据具体的数据集调整h_ t_ r_

            # 题目中所给数据集对应的h_ t_ r_
            h_ = entity2id[triple[0]]
            t_ = entity2id[triple[2]]
            r_ = relation2id[triple[1]]

            # FB15K-237数据集对应的h_ t_ r_
            # h_ = entity2id[triple[0]]
            # t_ = entity2id[triple[1]]
            # r_ = relation2id[triple[2]]

            triple_list.append([h_, t_, r_])

            entity_set.add(h_)
            entity_set.add(t_)

            relation_set.add(r_)

    return entity_set, relation_set, triple_list


def distanceL2(h, r, t):
    # 为方便求梯度,去掉sqrt
    return np.sum(np.square(h + r - t))


def distanceL1(h, r, t):
    return np.sum(np.fabs(h + r - t))


class TransE:
    def __init__(self, entity_set, relation_set, triple_list,
                 embedding_dim=100, learning_rate=0.01, margin=1, L1=True):
        self.embedding_dim = embedding_dim
        self.learning_rate = learning_rate
        self.margin = margin
        self.entity = entity_set
        self.relation = relation_set
        self.triple_list = triple_list
        self.L1 = L1

        self.loss = 0

    def emb_initialize(self):
        relation_dict = {}
        entity_dict = {}

        for relation in self.relation:
            r_emb_temp = np.random.uniform(-6 / math.sqrt(self.embedding_dim),
                                           6 / math.sqrt(self.embedding_dim),
                                           self.embedding_dim)
            relation_dict[relation] = r_emb_temp / np.linalg.norm(r_emb_temp, ord=2)

        for entity in self.entity:
            e_emb_temp = np.random.uniform(-6 / math.sqrt(self.embedding_dim),
                                           6 / math.sqrt(self.embedding_dim),
                                           self.embedding_dim)
            entity_dict[entity] = e_emb_temp / np.linalg.norm(e_emb_temp, ord=2)

        self.relation = relation_dict
        self.entity = entity_dict

    def train(self, epochs):
        nbatches = 2 # 根据训练样本数给出合适的数值
        batch_size = len(self.triple_list) // nbatches
        print("batch size: ", batch_size)
        for epoch in range(epochs):
            start = time.time()
            self.loss = 0

            # Sbatch:list
            Sbatch = random.sample(self.triple_list, batch_size)
            Tbatch = []

            for triple in Sbatch:
                corrupted_triple = self.Corrupt(triple)
                if (triple, corrupted_triple) not in Tbatch:
                    Tbatch.append((triple, corrupted_triple))
            self.update_embeddings(Tbatch)

            end = time.time()
            print("epoch: ", epoch, "cost time: %s" % (round((end - start), 3)))
            print("loss: ", self.loss)
            loss_ls.append(self.loss)

            # 保存临时结果
            if epoch % 20 == 0:
                with codecs.open("entity_temp", "w") as f_e:
                    for e in self.entity.keys():
                        f_e.write(e + "\t")
                        f_e.write(str(list(self.entity[e])))
                        f_e.write("\n")
                with codecs.open("relation_temp", "w") as f_r:
                    for r in self.relation.keys():
                        f_r.write(r + "\t")
                        f_r.write(str(list(self.relation[r])))
                        f_r.write("\n")

        print("写入文件...")
        with codecs.open("entity_50dim", "w") as f1:
            for e in self.entity.keys():
                f1.write(e + "\t")
                f1.write(str(list(self.entity[e])))
                f1.write("\n")

        with codecs.open("relation_50dim", "w") as f2:
            for r in self.relation.keys():
                f2.write(r + "\t")
                f2.write(str(list(self.relation[r])))
                f2.write("\n")

        with codecs.open("loss", "w") as f3:
            f3.write(str(loss_ls))

        print("写入完成")

    def Corrupt(self, triple):
        corrupted_triple = copy.deepcopy(triple)
        seed = random.random()
        if seed > 0.5:
            # 替换head
            rand_head = triple[0]
            while rand_head == triple[0]:
                rand_head = random.sample(self.entity.keys(), 1)[0]
            corrupted_triple[0] = rand_head

        else:
            # 替换tail
            rand_tail = triple[1]
            while rand_tail == triple[1]:
                rand_tail = random.sample(self.entity.keys(), 1)[0]
            corrupted_triple[1] = rand_tail
        return corrupted_triple

    def update_embeddings(self, Tbatch):
        copy_entity = copy.deepcopy(self.entity)
        copy_relation = copy.deepcopy(self.relation)

        for triple, corrupted_triple in Tbatch:
            # 取copy里的vector累积更新
            h_correct_update = copy_entity[triple[0]]
            t_correct_update = copy_entity[triple[1]]
            relation_update = copy_relation[triple[2]]

            h_corrupt_update = copy_entity[corrupted_triple[0]]
            t_corrupt_update = copy_entity[corrupted_triple[1]]

            # 取原始的vector计算梯度
            h_correct = self.entity[triple[0]]
            t_correct = self.entity[triple[1]]
            relation = self.relation[triple[2]]

            h_corrupt = self.entity[corrupted_triple[0]]
            t_corrupt = self.entity[corrupted_triple[1]]

            if self.L1:
                dist_correct = distanceL1(h_correct, relation, t_correct)
                dist_corrupt = distanceL1(h_corrupt, relation, t_corrupt)
            else:
                dist_correct = distanceL2(h_correct, relation, t_correct)
                dist_corrupt = distanceL2(h_corrupt, relation, t_corrupt)

            err = self.hinge_loss(dist_correct, dist_corrupt)

            if err > 0:
                self.loss += err

                grad_pos = 2 * (h_correct + relation - t_correct)
                grad_neg = 2 * (h_corrupt + relation - t_corrupt)

                #  梯度计算
                if self.L1:
                    for i in range(len(grad_pos)):
                        if (grad_pos[i] > 0):
                            grad_pos[i] = 1
                        else:
                            grad_pos[i] = -1

                    for i in range(len(grad_neg)):
                        if (grad_neg[i] > 0):
                            grad_neg[i] = 1
                        else:
                            grad_neg[i] = -1

                #  梯度下降
                # head系数为正,减梯度;tail系数为负,加梯度
                h_correct_update -= self.learning_rate * grad_pos
                t_correct_update -= (-1) * self.learning_rate * grad_pos

                # corrupt项整体为负,因此符号与correct相反
                if triple[0] == corrupted_triple[0]:  # 若替换的是尾实体,则头实体更新两次
                    h_correct_update -= (-1) * self.learning_rate * grad_neg
                    t_corrupt_update -= self.learning_rate * grad_neg

                elif triple[1] == corrupted_triple[1]:  # 若替换的是头实体,则尾实体更新两次
                    h_corrupt_update -= (-1) * self.learning_rate * grad_neg
                    t_correct_update -= self.learning_rate * grad_neg

                # relation更新两次
                relation_update -= self.learning_rate * grad_pos
                relation_update -= (-1) * self.learning_rate * grad_neg

        # batch norm
        for i in copy_entity.keys():
            copy_entity[i] /= np.linalg.norm(copy_entity[i])
        for i in copy_relation.keys():
            copy_relation[i] /= np.linalg.norm(copy_relation[i])

        # 达到批量更新的目的
        self.entity = copy_entity
        self.relation = copy_relation

    def hinge_loss(self, dist_correct, dist_corrupt):
        return max(0, dist_correct - dist_corrupt + self.margin)


if __name__ == '__main__':
    file1 = "./data/"
    entity_set, relation_set, triple_list = data_loader(file1)
    print("load file...")
    print("Complete load. entity : %d , relation : %d , triple : %d" % (
        len(entity_set), len(relation_set), len(triple_list)))

    transE = TransE(entity_set, relation_set, triple_list, embedding_dim=50, learning_rate=0.01, margin=1, L1=True)
    transE.emb_initialize()
    transE.train(epochs=200)

运行结果:

损失函数变化曲线如下:

1.4.2 链接预测及评价

test.py

import codecs
import random

import numpy as np

entity2id = {}
relation2id = {}
entityId2vec = {}
relationId2vec = {}


def data_loader(file):
    file1 = file + "test.txt"
    file2 = file + "entity2id.txt"
    file3 = file + "relation2id.txt"

    with open(file2, 'r') as f1, open(file3, 'r') as f2:
        lines1 = f1.readlines()
        lines2 = f2.readlines()
        for line in lines1:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            entity2id[line[0]] = line[1]

        for line in lines2:
            line = line.strip().split('\t')
            if len(line) != 2:
                continue
            relation2id[line[0]] = line[1]

    entity_set = set()
    relation_set = set()
    triple_list = []

    with codecs.open(file1, 'r') as f:
        content = f.readlines()
        for line in content:
            triple = line.strip().split("\t")
            if len(triple) != 3:
                continue

            h_ = entity2id[triple[0]]
            t_ = entity2id[triple[1]]
            r_ = relation2id[triple[2]]

            triple_list.append([h_, t_, r_])

            entity_set.add(h_)
            entity_set.add(t_)

            relation_set.add(r_)

    return entity_set, relation_set, triple_list


def transE_loader(file):
    # file1 = file + "entity_50dim_batch400"
    file1 = file + "entity_50dim"
    # file2 = file + "relation50dim_batch400"
    file2 = file + "relation_50dim"
    with codecs.open(file1, 'r') as f:
        content = f.readlines()
        for line in content:
            line = line.strip().split("\t")
            entityId2vec[line[0]] = eval(line[1])
    with codecs.open(file2, 'r') as f:
        content = f.readlines()
        for line in content:
            line = line.strip().split("\t")
            relationId2vec[line[0]] = eval(line[1])


def distance(h, r, t):
    h = np.array(h)
    r = np.array(r)
    t = np.array(t)
    s = h + r - t
    return np.linalg.norm(s)


def mean_rank(entity_set, triple_list):
    # triple_batch = random.sample(triple_list, 100)
    triple_batch = triple_list
    mean = 0
    hit10 = 0
    hit3 = 0
    for triple in triple_batch:
        dlist = []
        h = triple[0]
        t = triple[1]
        r = triple[2]
        dlist.append((t, distance(entityId2vec[h], relationId2vec[r], entityId2vec[t])))
        for t_ in entity_set:
            if t_ != t:
                dlist.append((t_, distance(entityId2vec[h], relationId2vec[r], entityId2vec[t_])))
        dlist = sorted(dlist, key=lambda val: val[1])
        for index in range(len(dlist)):
            if dlist[index][0] == t:
                mean += index + 1
                if index < 3:
                    hit3 += 1
                if index < 10:
                    hit10 += 1
                print(index)
                break
    print("mean rank:", mean / len(triple_batch))
    print("hit@3:", hit3 / len(triple_batch))
    print("hit@10:", hit10 / len(triple_batch))


if __name__ == '__main__':
    file1 = "./data/"
    print("load file...")
    entity_set, relation_set, triple_list = data_loader(file1)
    print("Complete load. entity : %d , relation : %d , triple : %d" % (
        len(entity_set), len(relation_set), len(triple_list)))
    print("load transE vec...")
    transE_loader("./")
    print("Complete load.")
    mean_rank(entity_set, triple_list)

运行结果:

由评价指标可知:预测正确

2. TransE知识图谱补全,FB15K-237数据集

2.1 根据具体的数据集,调整相应的参数即可

  • 调整
  • 调整
  • 传入对应数据集地址及调整

test.py也需要传入正确的数据集地址

2.2 运行结果

2.2.1 transE.py 运行结果

2.2.2 损失函数变化曲线

2.2.3 test.py 运行结果

2.2.4 分析

经过transE建模后,在测试集的13584个实体,961个关系的 59071个三元组中,测试结果如下:

mean rank: 7.591836734693878
hit@3: 0.5102040816326531
hit@10: 0.8163265306122449

一方面可以看出训练后的结果是有效的,但不是十分优秀,可能与transE模型的局限性有关,transE只能处理一对一的关系,不适合一对多/多对一关系。

3. 完整数据集及代码

Bridge619/KG_TransE (github)


参考:transE算法 简单实现 FB15k

文章不错,扫码支持一下吧~
上一篇 下一篇
评论