ASRT_SpeechRecognition/test_train.ipynb

173 lines
11 KiB
Plaintext
Raw Normal View History

2024-09-24 15:05:44 +08:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ASRT PyTorch training smoke test\n",
"\n",
"Loads the `test` split with SpecAugment features, wraps `SpeechModel251BN` in `ModelSpeech`, and runs a short training loop.\n",
"\n",
"- **Data prerequisite:** `DataLoader` opens the `data_list` file(s) configured for the requested split (e.g. `datalist/primewords/test.wav.lst`). These paths are relative, so they are resolved against the kernel's working directory; if the files are missing, `DataLoader('test')` raises `FileNotFoundError`.\n",
"- **TODO (pytorch_backend.py):** `F.softmax(self.dense1(x))` omits `dim`, which emits a deprecation `UserWarning` during training; pass the class axis explicitly (presumably `dim=-1` — confirm against the tensor layout).\n",
"- **TODO (pytorch_backend.py):** `SpeechModel251BN.train_model` calls `tqdm(...)` on the module object and fails with `TypeError: 'module' object is not callable` (presumably `import tqdm` should be `from tqdm import tqdm`). Until that is fixed, train through `ModelSpeech.train` as below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# All imports in one place, with explicit names instead of a wildcard import.\n",
"import torch\n",
"from torch import optim\n",
"\n",
"from data_loader import DataLoader\n",
"from model_zoo.speech_model.pytorch_backend import SpeechModel251BN\n",
"from speech_features import SpecAugment\n",
"from torch_speech_model import ModelSpeech, SpeechDataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run configuration: tune values here rather than inside the cells below.\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Prefer Apple's MPS backend when it is available; otherwise fall back to CPU so the notebook still runs.\n",
"DEVICE = 'mps' if torch.backends.mps.is_available() else 'cpu'\n",
"\n",
"INPUT_SHAPE = (1600, 200, 1)\n",
"MAX_LABEL_LENGTH = 64\n",
"\n",
"EPOCHS = 5\n",
"BATCH_SIZE = 32\n",
"LEARNING_RATE = 1e-4\n",
"SAVE_STEP = 1\n",
"# NOTE(review): value kept from the original run. No checkpoint is loaded in this notebook, so confirm\n",
"# what ModelSpeech.train does with last_epoch (e.g. offsetting epoch numbers of a resumed run) before relying on it.\n",
"LAST_EPOCH = 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the test split"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"feat = SpecAugment()\n",
"data_loader = DataLoader('test')\n",
"speechdata = SpeechDataset(data_loader, feat, input_shape=INPUT_SHAPE, max_label_length=MAX_LABEL_LENGTH, device=DEVICE)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build the model\n",
"\n",
"**NOTE(review):** an earlier run logged `torch model successfully initialized to device: cpu` while the dataset was built for `mps`. Confirm that the model and the batches end up on the same device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = SpeechModel251BN()\n",
"speechModel = ModelSpeech(model, feat, max_label_length=MAX_LABEL_LENGTH)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
"speechModel.train(speechdata, epochs=EPOCHS, batch_size=BATCH_SIZE, optimizer=optimizer, save_step=SAVE_STEP, last_epoch=LAST_EPOCH)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}