




版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進行舉報或認領(lǐng)
文檔簡介
《自然語言處理技術(shù)》代碼33所示。代碼STYLEREF1\s3SEQ代碼\*ARABIC\s13模型構(gòu)建importtorchimporttorch.nnasnn#定義LSTM模型classLSTMModel(nn.Module):def__init__(self,vocab_size,embed_size,hidden_size,num_classes):super(LSTMModel,self).__init__()#定義詞嵌入層self.embedding=nn.Embedding(vocab_size,embed_size)#定義LSTM層self.lstm=nn.LSTM(embed_size,hidden_size,batch_first=True)#定義全連接層self.fc=nn.Linear(hidden_size,num_classes)defforward(self,x):#將輸入的詞語ID序列轉(zhuǎn)化為詞嵌入向量embedded=self.embedding(x)#將詞嵌入向量輸入到LSTM層中out,_=self.lstm(embedded)#將LSTM層的輸出傳入全連接層out=self.fc(out[:,-1,:])returnout#設(shè)置模型的參數(shù)vocab_size=len(vocab)#詞匯表的大小embed_size=100#詞嵌入向量的維度hidden_size=128#LSTM層的隱藏狀態(tài)的維度num_classes=2#情感分類的類別數(shù)#創(chuàng)建LSTM模型model=LSTMModel(vocab_size,embed_size,hidden_size,num_classes)在構(gòu)建基于LSTM的情感分類模型時,使用了3個關(guān)鍵的神經(jīng)網(wǎng)絡層,主要為詞嵌入層、LSTM層和全連接層,設(shè)置3個神經(jīng)網(wǎng)絡層的函數(shù)及其參數(shù)說明如REF_Ref132317559\h表31所示。表STYLEREF1\s3SEQ表\*ARABIC\s11LSTM模型中的3個關(guān)鍵層函數(shù)說明函數(shù)說明參數(shù)nn.Embedding詞嵌入層。將輸入的詞語ID序列轉(zhuǎn)換為詞嵌入向量vocab_size:詞匯表大小embed_size:詞嵌入向量維度nn.LSTMLSTM層。用于處理序列數(shù)據(jù),捕捉長期依賴關(guān)系embed_size:詞嵌入向量維度hidden_size:隱藏狀態(tài)維度batch_first:指定輸入張量的形狀中是否將批處理大小放在第一個維度nn.Linear全連接層。將LSTM層的輸出轉(zhuǎn)換為最終的類別預測hidden_size:隱藏狀態(tài)維度num_classes:情感分類類別數(shù)模型訓練針對搭建好的模型,設(shè)置模型超參數(shù)并進行訓練設(shè)置,其中,損失值設(shè)為CrossEntropyLoss、優(yōu)化器設(shè)為“Adam”、每批訓練數(shù)據(jù)個數(shù)設(shè)為32、總迭代輪次設(shè)為10,如REF_Ref131514844\h代碼34所示。代碼STYLEREF1\s3SEQ代碼\*ARABIC\s14模型訓練importtorch.optimasoptimfromtorch.utils.dataimportDataset,DataLoaderfromtorch.utils.dataimportTensorDataset,RandomSampler,DataLoader,SequentialSampler#定義自定義數(shù)據(jù)集類classSentimentDataset(Dataset):def__init__(self,data,labels):self.data=dataself.labels=labelsdef__getitem__(self,index):returnself.data[index],self.labels[index]def__len__(self):returnlen(self.data)#設(shè)置超參數(shù)batch_size=32#每次訓練的數(shù)據(jù)批次大小num_epochs=10#迭代次數(shù)#初始化數(shù)據(jù)集和數(shù)據(jù)加載器train_data=TensorDataset(torch.LongTensor(input_ids_train),torch.LongTensor(y_train))#隨機采樣器用于在訓練數(shù)據(jù)中隨機選取樣本組成一個batchtrain_sampler=RandomSampler(train_data)train_loader=DataLoader(train_data,sampler=train_sampler,batch_size=batch_size)test_data=TensorDataset(torch.LongTensor(input_ids_test),torch.LongTensor(y_test))#順序采樣器用于在測試數(shù)據(jù)中按順序選取樣本組成一個batchtest_sampler=SequentialSampler(test_data)test_loader=DataLoader(test_data,sampler=test_sampler,batch_size=batch_size)#初始化模型、損失函數(shù)和優(yōu)化器model=LSTMModel(vocab_size,embed_size,hidden_size,num_classes)criterion=nn.CrossEntropyLoss()#交叉熵損失函數(shù)optimizer=optim.Adam(model.parameters(),lr=0.001)#Adam優(yōu)化器#訓練模型total_step=len(train_loader)#計算總的訓練步數(shù)forepochinrange(num_epochs):fori,(data,labels)inenumerate(train_loader):#正向傳播outputs=model(data)loss=criterion(outputs,labels.squeeze())#反向傳播和優(yōu)化optimizer.zero_grad()loss.backward()optimizer.step()#打印狀態(tài)信息if(i+1)%100==0:print('Epoch[{}/{}],Step[{}/{}],Loss:{:.4f}'.format(epoch+1,num_epochs,i+1,total_step,loss.item()))在訓練深度學習模型時,通常需要將數(shù)據(jù)集分成多個批次進行訓練,在此過程中使用到TensorDataset和DataLoader函數(shù)來實現(xiàn)對數(shù)據(jù)集的處理。TensorDataset和DataLoader函數(shù)的作用及其參數(shù)說明如REF_Ref132317590\h表32所示。表STYLEREF1\s3SEQ表\*ARABIC\s12TensorDataset和DataLoader函數(shù)的作用及其參數(shù)說明函數(shù)說明參數(shù)TensorDataset將輸入數(shù)據(jù)轉(zhuǎn)換為適合訓練的數(shù)據(jù)集格式torch.LongTensor(input_ids_train):訓練數(shù)據(jù)的輸入ID序列torch.LongTensor(y_train):訓練數(shù)據(jù)的標簽DataLoader根據(jù)指定的采樣器和批次大小創(chuàng)建數(shù)據(jù)加載器,用于在訓練過程中加載數(shù)據(jù)train_data:TensorDataset格式的訓練數(shù)據(jù)集train_sampler:數(shù)據(jù)采樣器batch_size:每個批次的樣本數(shù)量同時,在訓練深度學習模型時,還需要選擇一個優(yōu)化器來更新模型參數(shù)。此代碼中使用到Adam優(yōu)化器。Adam優(yōu)化器的常用參數(shù)說明如REF_Ref132317633\h表33所示。表STYLEREF1\s3SEQ表\*ARABIC\s13Adam優(yōu)化器的常用參數(shù)說明參數(shù)名稱參數(shù)說明parameters接收list,表示模型參數(shù),用于更新模型參數(shù),多用于優(yōu)化器的初始化,無默認值lr接收float,表示學習率,又稱為步長因子,控制了權(quán)重的更新比率,較大的值更新前會有更快的初始學習,而較小的值會令訓練收斂到更好的性能。默認為1e-3運行模型訓練代碼后,訓練過程輸出結(jié)果如下。Epoch[1/10],Step[100/195],Loss:0.4767Epoch[2/10],Step[100/195],Loss:0.5186Epoch[3/10],Step[100/195],Loss:0.5767Epoch[4/10],Step[100/195],Loss:0.3618Epoch[5/10],Step[100/195],Loss:0.1925Epoch[6/10],Step[100/195],Loss:0.0633Epoch[7/10],Step[100/195],Loss:0.0628Epoch[8/10],Step[100/195],Loss:0.2164Epoch[9/10],Step[100/195],Loss:0.0051Epoch[10/10],Step[100/195],Loss:0.0053從模型訓練代碼的運行結(jié)果可以看出,模型在測試集上的損失值隨著模型訓練次數(shù)的增加,該值整體呈下降趨勢。模型測試對訓練好的模型使用測試數(shù)據(jù)進行測試,輸出模型的評價結(jié)果,如REF_Ref131514956\h代碼35所示。代碼STYLEREF1\s3SEQ代碼\*ARABIC\s15模型測試fromsklearnimportmetrics#使用測試數(shù)據(jù)進行模型測試model.eval()predict=[]withtorch.no_grad():fordata,labelsintest_loader:outputs=model(data)_,predicted=torch.max(outputs.data,1)predict.extend(predicted.detach().numpy())#模型評價acc=metrics.accuracy_score(y_test,predict)print('測試集的準確率為:',acc)print('精確率、召回率、F1值分別為:')print(metrics.classification_report(y_test,predict))print('混淆矩陣為:')cm=metrics.confusion_matrix(y_test,predict)#混淆矩陣print(cm)在評估分類模型的性能時,混淆矩陣是一個很有用的工具。metrics.confusion_matrix函數(shù)可以計算一個分類模型的混淆矩陣,可以展示各類別之間的實際值與預測值之間的關(guān)系,該函數(shù)的常用參數(shù)說明如REF_Ref132317915\h表34所示。表STYLEREF1\s3SEQ表\*ARABIC\s14metrics.confusion_matrix函數(shù)的常用參數(shù)說明參數(shù)名稱參數(shù)說明y_true接收numpy數(shù)組或list,表示真實的類別標簽,通常是一個一維數(shù)組或列表。無默認值y_pred接收numpy數(shù)組或list,表示預測的類別標簽,通常是一個一維數(shù)組或列表。無默認值在模型測試代碼,輸出模型的評價結(jié)果如下。測試集的準確率為:0.819703799098519精確率、召回率、F1值分別為:precisionrecallf1-scoresupport00.710.770.7451310.880.850.861040accuracy0.821553macroavg0.800.810.801553weig
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預覽,若沒有圖紙預覽就沒有圖紙。
- 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負責。
- 6. 下載文件中如有侵權(quán)或不適當內(nèi)容,請與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準確性、安全性和完整性, 同時也不承擔用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。
最新文檔
- 印刷機安全操作風險預警系統(tǒng)研究考核試卷
- 農(nóng)機租賃市場租賃政策影響考核試卷
- 團隊協(xié)作意識考核試卷
- 公路客運企業(yè)績效與激勵制度的區(qū)域差異性研究考核試卷
- 企業(yè)成本效益分析模型構(gòu)建考核試卷
- 保險業(yè)務風險損失評估考核試卷
- 化學基礎(chǔ)知識 大單元整合(含答案)-2026屆高三一輪復習學案
- 期末綜合試題-2024-2025學年數(shù)學人教版四年級下冊
- 護理理論知識試題+答案
- 2020年成人高考高起專語文現(xiàn)代文寫作鞏固
- 骨科VTE的預防及護理
- 工貿(mào)行業(yè)重大事故隱患判定標準安全試題及答案
- 2025年山東威海中考數(shù)學試卷真題及答案詳解(精校打印版)
- 2025年中國環(huán)烷基變壓器油行業(yè)市場調(diào)查、投資前景及策略咨詢報告
- 新生兒甲狀腺低下及護理
- 2025年全國新高考I卷高考全國一卷真題語文試卷(真題+答案)
- 信息費合同協(xié)議書范本
- 超市外租區(qū)租賃合同3篇
- 辦公樓裝修施工組織機構(gòu)及管理措施
- T/CMES 37005-2023滑道運營管理規(guī)范
- 催收機構(gòu)運營管理制度
評論
0/150
提交評論