{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from torch_speech_model import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'datalist/primewords/test.wav.lst'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [2], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdata_loader\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m DataLoader\n\u001b[1;32m 3\u001b[0m feat \u001b[38;5;241m=\u001b[39m SpecAugment()\n\u001b[0;32m----> 4\u001b[0m data_loader \u001b[38;5;241m=\u001b[39m \u001b[43mDataLoader\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtest\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m speechdata \u001b[38;5;241m=\u001b[39m SpeechDataset(data_loader, feat, input_shape\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m1600\u001b[39m,\u001b[38;5;241m200\u001b[39m,\u001b[38;5;241m1\u001b[39m),max_label_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m64\u001b[39m,device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmps\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
"File \u001b[0;32m~/PycharmProjects/Open_Source/ASRT_SpeechRecognition/data_loader.py:48\u001b[0m, in \u001b[0;36mDataLoader.__init__\u001b[0;34m(self, dataset_type)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpinyin_list \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m()\n\u001b[1;32m 47\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mpinyin_dict \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mdict\u001b[39m()\n\u001b[0;32m---> 48\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_load_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/PycharmProjects/Open_Source/ASRT_SpeechRecognition/data_loader.py:58\u001b[0m, in \u001b[0;36mDataLoader._load_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 56\u001b[0m filename_datalist \u001b[38;5;241m=\u001b[39m config[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdataset\u001b[39m\u001b[38;5;124m'\u001b[39m][\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset_type][index][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdata_list\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[1;32m 57\u001b[0m filename_datapath \u001b[38;5;241m=\u001b[39m config[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdataset\u001b[39m\u001b[38;5;124m'\u001b[39m][\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset_type][index][\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mdata_path\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m---> 58\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfilename_datalist\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mr\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mutf-8\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m file_pointer:\n\u001b[1;32m 59\u001b[0m lines \u001b[38;5;241m=\u001b[39m file_pointer\u001b[38;5;241m.\u001b[39mread()\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m'\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 60\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m line \u001b[38;5;129;01min\u001b[39;00m lines:\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'datalist/primewords/test.wav.lst'"
]
}
],
"source": [
"from speech_features import SpecAugment\n",
"from data_loader import DataLoader\n",
"feat = SpecAugment()\n",
"data_loader = DataLoader('test')\n",
"speechdata = SpeechDataset(data_loader, feat, input_shape=(1600,200,1),max_label_length=64,device='mps')"
]
},
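{
"cell_type": "markdown",
"metadata": {},
"source": [
"The traceback above shows `DataLoader._load_data` opening a datalist path read from the config (`config['dataset'][dataset_type][index]['data_list']`). A minimal pre-flight check, assuming only the path reported in the error message, makes the failure explicit before the loader is built:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Path copied from the FileNotFoundError above.\n",
"datalist = 'datalist/primewords/test.wav.lst'\n",
"if not os.path.isfile(datalist):\n",
"    print('missing datalist:', datalist)\n",
"    print('download the primewords set or point data_list in the config at an existing file')"
]
},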
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from model_zoo.speech_model.pytorch_backend import SpeechModel251BN\n",
"\n",
"model = SpeechModel251BN()\n",
"speechModel = ModelSpeech(model, feat, max_label_length=64)"
]
},
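{
"cell_type": "markdown",
"metadata": {},
"source": [
"Cell 2 requested `device='mps'`, yet the training log below reports `cpu`. A standard PyTorch sketch for picking the device explicitly; how `ModelSpeech` consumes the chosen device is an assumption, since its code is not shown here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Prefer Apple's MPS backend when available, else fall back to CPU.\n",
"device = 'mps' if torch.backends.mps.is_available() else 'cpu'\n",
"print('using device:', device)"
]
},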
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"SpeechModel251BN(\n",
" (conv0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn0): BatchNorm2d(32, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn1): BatchNorm2d(32, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn2): BatchNorm2d(64, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn3): BatchNorm2d(64, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn4): BatchNorm2d(128, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn5): BatchNorm2d(128, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn6): BatchNorm2d(128, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn7): BatchNorm2d(128, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn8): BatchNorm2d(128, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (conv9): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)\n",
" (bn9): BatchNorm2d(128, eps=0.0002, momentum=0.1, affine=True, track_running_stats=True)\n",
" (dense0): Linear(in_features=3200, out_features=128, bias=True)\n",
" (dense1): Linear(in_features=128, out_features=1428, bias=True)\n",
" (ctc_loss): CTCLoss()\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
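{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick size check on the architecture printed above. This uses only standard `torch.nn.Module` APIs and makes no assumptions about ASRT internals:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Count all parameters, and separately those updated by the optimizer.\n",
"total = sum(p.numel() for p in model.parameters())\n",
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f'parameters: {total:,} total, {trainable:,} trainable')"
]
},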
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ASRT] torch model successfully initialized to device: cpu\n",
"[ASRT] Epoch 1/5\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/filianore/PycharmProjects/Open_Source/ASRT_SpeechRecognition/model_zoo/speech_model/pytorch_backend.py:89: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" x = F.softmax(self.dense1(x))\n"
]
}
],
"source": [
"from torch import optim\n",
"\n",
"speechModel.train(speechdata, epochs=5, batch_size=32, optimizer=optim.Adam(model.parameters(), lr=0.0001), save_step=1, last_epoch=10)"
]
},
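{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `UserWarning` above points at `F.softmax(self.dense1(x))` in `pytorch_backend.py:89`. A self-contained sketch of the explicit `dim` argument the warning asks for; `log_softmax` is shown as well because PyTorch's `nn.CTCLoss` expects log-probabilities as input:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"x = torch.randn(4, 1428)  # dummy logits shaped like dense1's output\n",
"probs = F.softmax(x, dim=-1)          # explicit dim silences the warning\n",
"log_probs = F.log_softmax(x, dim=-1)  # the form nn.CTCLoss expects"
]
},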
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "'module' object is not callable",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn [7], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_model\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mAdam\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparameters\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0.0001\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mspeechdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/PycharmProjects/Open_Source/ASRT_SpeechRecognition/model_zoo/speech_model/pytorch_backend.py:100\u001b[0m, in \u001b[0;36mSpeechModel251BN.train_model\u001b[0;34m(self, train_loader, optimizer, num_epochs, device)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m 98\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain()\n\u001b[0;32m--> 100\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[43mtqdm\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mnum_epochs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 101\u001b[0m epoch_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.0\u001b[39m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m batch \u001b[38;5;129;01min\u001b[39;00m train_loader:\n",
"\u001b[0;31mTypeError\u001b[0m: 'module' object is not callable"
]
}
],
"source": [
"model.train_model(optimizer=optim.Adam(model.parameters(), lr=0.0001), train_loader=speechdata, num_epochs=5)"
]
},
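{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `TypeError: 'module' object is not callable` fires at `tqdm(range(num_epochs))` inside `train_model`, which usually means the backend does `import tqdm` and then calls the module itself. A minimal sketch of the likely fix; this is an assumption about `pytorch_backend.py`, whose source is not shown here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Assumed fix inside pytorch_backend.py: import the callable, not the module,\n",
"# i.e. replace `import tqdm` with:\n",
"from tqdm import tqdm\n",
"\n",
"for epoch in tqdm(range(5)):  # wraps the loop with a progress bar\n",
"    pass"
]
}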
],
"metadata": {
"kernelspec": {
"display_name": "torch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}