当前位置:网站首页 > PyTorch实现Word2Vec
PyTorch实现Word2Vec
2022-07-22 07:18:00 【linxizi0622】
# !/usr/bin/env Python3
# -*- coding: utf-8 -*-
# @version: v1.0
# @Author : Meng Li
# @contact: [email protected]
# @FILE : torch_word2vec.py
# @Time : 2022/7/21 14:12
# @Software : PyCharm
# @site:
# @Description : 自己实现的基于skip-gram算法的Word2Vec预训练语言模型
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
sentences = ["jack like dog", "jack like cat", "jack like animal",
             "dog cat animal", "banana apple cat dog like", "dog fish milk like",
             "dog cat animal like", "jack like apple", "apple like", "jack like banana",
             "apple banana jack movie book music like", "cat dog hate", "cat dog like"]
# Join with a space: joining with "" fuses the last word of one sentence with the
# first word of the next (e.g. "dog" + "jack" -> "dogjack").
sentences_list = " ".join(sentences).split()
# sorted() makes the word -> index mapping reproducible across runs; the iteration
# order of a set of str depends on the per-process hash seed.
vocab = sorted(set(sentences_list))
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}
vocab_size = len(vocab)
window_size = 2  # number of context words taken on EACH side of a center word
embedding_size = 2  # width of the SkipGram hidden layer, i.e. the word-vector size


def make_data(seq_data):
    """Build skip-gram (center, context) training pairs from a list of sentences.

    The sentences are joined into one continuous word stream. Every word that has
    ``window_size`` words on both sides becomes a center word. It is paired with
    each of the ``2 * window_size`` words surrounding it.

    :param seq_data: iterable of whitespace-separated sentences; every word must
        be a key of ``word2idx``.
    :return: ``(inputs, targets)``. ``inputs`` is a float32 tensor of shape
        ``[num_pairs, vocab_size]`` holding the one-hot encoding of each center
        word. ``targets`` is an int64 tensor of shape ``[num_pairs]`` holding the
        vocabulary index of the matching context word.
    :raises KeyError: if a word is not in ``word2idx``.
    """
    words = " ".join(seq_data).split()
    center_ids = []
    context_ids = []
    for center in range(window_size, len(words) - window_size):
        # window_size positions before the center AND window_size positions after
        # it (the upper bound of range() is exclusive, hence the +1).
        context_positions = (list(range(center - window_size, center))
                             + list(range(center + 1, center + window_size + 1)))
        for pos in context_positions:
            center_ids.append(word2idx[words[center]])
            # The training target is the context WORD's vocabulary index,
            # not its position in the stream.
            context_ids.append(word2idx[words[pos]])
    # Build the identity matrix once and gather rows, instead of calling np.eye
    # per pair; an empty id list yields shape (0, vocab_size).
    one_hot = np.eye(vocab_size, dtype=np.float32)[np.asarray(center_ids, dtype=np.int64)]
    return torch.from_numpy(one_hot), torch.tensor(context_ids, dtype=torch.long)
class my_dataset(Dataset):
    """Map-style dataset that pairs element ``i`` of the inputs with element ``i`` of the targets."""

    def __init__(self, input_data, target_data):
        super().__init__()
        # Presumably torch tensors (``__len__`` calls ``.size(0)``), indexed
        # along their first dimension.
        self.input_data, self.target_data = input_data, target_data

    def __len__(self):
        # Number of samples = size of the first dimension of the input tensor.
        return self.input_data.size(0)

    def __getitem__(self, index):
        sample = self.input_data[index]
        label = self.target_data[index]
        return sample, label
class SkipGram(nn.Module):
    """Skip-gram Word2Vec model: (one-hot) center word -> hidden vector -> context-word logits."""

    def __init__(self, embedding_size, num_words=None):
        """
        :param embedding_size: width of the hidden layer (the word-vector size).
        :param num_words: vocabulary size, i.e. the input width and the number of
            output classes. Defaults to the module-level ``vocab_size``.
        """
        super(SkipGram, self).__init__()
        if num_words is None:
            num_words = vocab_size
        self.embedding_size = embedding_size
        # fc1.weight has shape [embedding_size, num_words]. For a one-hot input,
        # fc1 selects column i of that weight (plus the bias) for word i.
        self.fc1 = torch.nn.Linear(num_words, self.embedding_size)
        self.fc2 = torch.nn.Linear(self.embedding_size, num_words)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, center, context):
        """Return the cross-entropy loss of predicting ``context`` from ``center``.

        :param center: float tensor [batch_size, num_words]; typically one-hot
            encodings of the center words.
        :param context: long tensor [batch_size]; vocabulary indices of the
            context words (the class targets for ``nn.CrossEntropyLoss``).
        :return: scalar loss tensor (mean over the batch, the CrossEntropyLoss default).
        """
        hidden = self.fc1(center)   # [batch_size, embedding_size]
        logits = self.fc2(hidden)   # [batch_size, num_words]
        return self.loss(logits, context)
# ---- Training script: build the skip-gram pairs and fit the model with Adam. ----
batch_size = 2
center_data, context_data = make_data(sentences)
train_data = my_dataset(center_data, context_data)
train_iter = DataLoader(train_data, batch_size, shuffle=True)
epochs = 2000
model = SkipGram(embedding_size=embedding_size)
model.train()
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0
    for center, context in train_iter:
        loss = model(center, context)
        optim.zero_grad()
        loss.backward()
        optim.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Report one line per logging interval: the mean loss over the whole epoch,
    # rather than one line (and one batch's loss) per batch.
    if epoch % 100 == 0 and num_batches:
        print("step {0} loss {1}".format(epoch, epoch_loss / num_batches))
基于PyTorch实现的Word2Vec预训练语言模型
边栏推荐
- Can Flink SQL query ClickHouse?
- Implementation of MATLAB mixer
- [must see for developers] [push kit] collection of typical problems of push service 1
- [HMS core] [FAQ] [account kit] typical problem set 2
- 备战攻防演练,这里有一张腾讯安全重保布防图!
- How do you turn off the touchpad in Win11? Three ways to disable the touchpad in Win11
- subprocess
- 【HMS core】【FAQ】In-App Purchases 常见问题分享
- 网络之物理层
- 为什么有些参数reload就可以生效,而有些参数必须重启数据库?
猜你喜欢
Virtual machine performance test scheme
家庭琐事问题
【HMS core】【FAQ】【Account Kit】典型问题集2
(11) 51 Microcontroller -- storing stopwatch data with the AT24C02 (with a demo of the result)
继承的详解
buu-misc进阶
Critical path problem
The nightmare of concurrent programs -- data races
【如何优化她】教你如何定位不合理的SQL?并优化她~~~
mysql查询中能否同时判断多个字段的值
随机推荐
Critical path implementation
离线安装vscode
Chery Xingtu's product plan was exposed, and the 2.0T turbocharged engine was launched at the end of the year
指令安排问题
mysql查询中能否同时判断多个字段的值
[Computer Networks] (III) Supernetting, routing, and the NAT protocol
win11怎么关闭触控板?win11关闭触控板的三种解决方法
MySQL addition, deletion, modification and query (Advanced)
Conference OA project introduction & Conference release
(11) 51 Microcontroller -- storing stopwatch data with the AT24C02 (with a demo of the result)
继承的详解
Detailed explanation of inheritance
SSRF vulnerability attack intranet redis recurrence
DOM operation of JS -- event proxy
【微服务~远程调用】整合RestTemplate、WebClient、Feign
bokeh参数设置详解
Map insert element
Hcip OSPF interface network type experiment report
postgreSQL数据库部署在linux服务器上,本机查询ms级,用windows上安装的pgadmin查询超级慢20s左右,是网络的问题还是数据库配置问题?
云主机性能测试方案