{"id":1852,"date":"2025-02-24T16:55:40","date_gmt":"2025-02-24T08:55:40","guid":{"rendered":"https:\/\/www.forillusion.com\/?p=1852"},"modified":"2025-02-24T16:55:41","modified_gmt":"2025-02-24T08:55:41","slug":"6-4-rnn-scratch","status":"publish","type":"post","link":"https:\/\/www.forillusion.com\/index.php\/6-4-rnn-scratch\/","title":{"rendered":"6.4 \u5faa\u73af\u795e\u7ecf\u7f51\u7edc\u7684\u4ece\u96f6\u5f00\u59cb\u5b9e\u73b0"},"content":{"rendered":"\n<p><div class=\"has-toc have-toc\"><\/div><\/p>\n\n\n\n<p>\u8bfb\u53d6\u5468\u6770\u4f26\u4e13\u8f91\u6b4c\u8bcd\u6570\u636e\u96c6\uff1a<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import time\nimport math\nimport numpy as np\nimport torch\nfrom torch import nn, optim\nimport torch.nn.functional as F\nimport zipfile\nimport random\n\n# import d2lzh_pytorch as d2l\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nprint(device)\n\ndef load_data_jay_lyrics():\n    \"\"\"\u52a0\u8f7d\u5468\u6770\u4f26\u6b4c\u8bcd\u6570\u636e\u96c6\"\"\"\n    with zipfile.ZipFile('\u673a\u5668\u5b66\u4e60\/data\/jaychou_lyrics.txt.zip') as zin:\n        with zin.open('jaychou_lyrics.txt') as f:\n            corpus_chars = f.read().decode('utf-8')\n    corpus_chars = corpus_chars.replace('\\n', ' ').replace('\\r', ' ')\n    # corpus_chars = corpus_chars&#91;0:10000]\n    idx_to_char = list(set(corpus_chars))\n    char_to_idx = dict(&#91;(char, i) for i, char in enumerate(idx_to_char)])\n    vocab_size = len(char_to_idx)\n    corpus_indices = &#91;char_to_idx&#91;char] for char in corpus_chars]\n    return corpus_indices, char_to_idx, idx_to_char, vocab_size\n\n(corpus_indices, char_to_idx, idx_to_char, vocab_size) = load_data_jay_lyrics()<\/code><\/pre>\n\n\n\n<h2 
class=\"wp-block-heading\">one-hot\u5411\u91cf<\/h2>\n\n\n\n<p>\u5047\u8bbe\u8bcd\u5178\u4e2d\u4e0d\u540c\u5b57\u7b26\u7684\u6570\u91cf\u4e3a$N$\uff08\u5373\u8bcd\u5178\u5927\u5c0f<code>vocab_size<\/code>\uff09\uff0c\u6bcf\u4e2a\u5b57\u7b26\u5df2\u7ecf\u540c\u4e00\u4e2a\u4ece0\u5230$N-1$\u7684\u8fde\u7eed\u6574\u6570\u503c\u7d22\u5f15\u4e00\u4e00\u5bf9\u5e94\u3002\u5982\u679c\u4e00\u4e2a\u5b57\u7b26\u7684\u7d22\u5f15\u662f\u6574\u6570$i$, \u90a3\u4e48\u521b\u5efa\u4e00\u4e2a\u51680\u7684\u957f\u4e3a$N$\u7684\u5411\u91cf\uff0c\u5e76\u5c06\u5176\u4f4d\u7f6e\u4e3a$i$\u7684\u5143\u7d20\u8bbe\u62101\u3002\u8be5\u5411\u91cf\u5c31\u662f\u5bf9\u539f\u5b57\u7b26\u7684one-hot\u5411\u91cf\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def one_hot(x, n_class, dtype=torch.float32): \n    # X shape: (batch), output shape: (batch, n_class)\n    x = x.long()\n    res = torch.zeros(x.shape&#91;0], n_class, dtype=dtype, device=x.device)\n    res.scatter_(1, x.view(-1, 1), 1)\n    return res<\/code><\/pre>\n\n\n\n<p>\u6bcf\u6b21\u91c7\u6837\u7684\u5c0f\u6279\u91cf\u7684\u5f62\u72b6\u662f(\u6279\u91cf\u5927\u5c0f, \u65f6\u95f4\u6b65\u6570)\u3002\u4e0b\u9762\u7684\u51fd\u6570\u5c06\u8fd9\u6837\u7684\u5c0f\u6279\u91cf\u53d8\u6362\u6210\u6570\u4e2a\u53ef\u4ee5\u8f93\u5165\u8fdb\u7f51\u7edc\u7684\u5f62\u72b6\u4e3a(\u6279\u91cf\u5927\u5c0f, \u8bcd\u5178\u5927\u5c0f)\u7684\u77e9\u9635\uff0c\u77e9\u9635\u4e2a\u6570\u7b49\u4e8e\u65f6\u95f4\u6b65\u6570\u3002\u4e5f\u5c31\u662f\u8bf4\uff0c\u65f6\u95f4\u6b65$t$\u7684\u8f93\u5165\u4e3a$\\boldsymbol{X}_t \\in \\mathbb{R}^{n \\times d}$\uff0c\u5176\u4e2d$n$\u4e3a\u6279\u91cf\u5927\u5c0f\uff0c$d$\u4e3a\u8f93\u5165\u4e2a\u6570\uff0c\u5373one-hot\u5411\u91cf\u957f\u5ea6\uff08\u8bcd\u5178\u5927\u5c0f\uff09\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def to_onehot(X, n_class):  \n    # X shape: (batch, seq_len), output: seq_len elements of (batch, n_class)\n    return &#91;one_hot(X&#91;:, i], n_class) for i in 
range(X.shape&#91;1])]<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">\u521d\u59cb\u5316\u6a21\u578b\u53c2\u6570<\/h2>\n\n\n\n<p>\u9690\u85cf\u5355\u5143\u4e2a\u6570 <code>num_hiddens<\/code>\u662f\u4e00\u4e2a\u8d85\u53c2\u6570\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size\nprint('will use', device)\n\ndef get_params():\n    def _one(shape):\n        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)\n        return torch.nn.Parameter(ts, requires_grad=True)\n\n    # \u9690\u85cf\u5c42\u53c2\u6570\n    W_xh = _one((num_inputs, num_hiddens))\n    W_hh = _one((num_hiddens, num_hiddens))\n    b_h = torch.nn.Parameter(torch.zeros(num_hiddens, device=device, requires_grad=True))\n    # \u8f93\u51fa\u5c42\u53c2\u6570\n    W_hq = _one((num_hiddens, num_outputs))\n    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, requires_grad=True))\n    return nn.ParameterList(&#91;W_xh, W_hh, b_h, W_hq, b_q])<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">\u5b9a\u4e49\u6a21\u578b<\/h2>\n\n\n\n<p>\u6839\u636e\u5faa\u73af\u795e\u7ecf\u7f51\u7edc\u7684\u8ba1\u7b97\u8868\u8fbe\u5f0f\u5b9e\u73b0\u8be5\u6a21\u578b\u3002\u9996\u5148\u5b9a\u4e49<code>init_rnn_state<\/code>\u51fd\u6570\u6765\u8fd4\u56de\u521d\u59cb\u5316\u7684\u9690\u85cf\u72b6\u6001\u3002\u5b83\u8fd4\u56de\u7531\u4e00\u4e2a\u5f62\u72b6\u4e3a(\u6279\u91cf\u5927\u5c0f, \u9690\u85cf\u5355\u5143\u4e2a\u6570)\u7684\u503c\u4e3a0\u7684<code>Tensor<\/code>\u7ec4\u6210\u7684\u5143\u7ec4\u3002\u4f7f\u7528\u5143\u7ec4\u662f\u4e3a\u4e86\u66f4\u4fbf\u4e8e\u5904\u7406\u9690\u85cf\u72b6\u6001\u542b\u6709\u591a\u4e2a<code>Tensor<\/code>\u7684\u60c5\u51b5\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def init_rnn_state(batch_size, num_hiddens, device):\n    return (torch.zeros((batch_size, num_hiddens), device=device), 
)<\/code><\/pre>\n\n\n\n<p>\u4e0b\u9762\u7684<code>rnn<\/code>\u51fd\u6570\u5b9a\u4e49\u4e86\u5728\u4e00\u4e2a\u65f6\u95f4\u6b65\u91cc\u5982\u4f55\u8ba1\u7b97\u9690\u85cf\u72b6\u6001\u548c\u8f93\u51fa\u3002\u8fd9\u91cc\u7684\u6fc0\u6d3b\u51fd\u6570\u4f7f\u7528\u4e86tanh\u51fd\u6570\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def rnn(inputs, state, params):\n    # inputs\u548coutputs\u7686\u4e3anum_steps\u4e2a\u5f62\u72b6\u4e3a(batch_size, vocab_size)\u7684\u77e9\u9635\n    W_xh, W_hh, b_h, W_hq, b_q = params\n    H, = state\n    outputs = &#91;]\n    for X in inputs: # \u6bcf\u6b21\u540c\u65f6\u5904\u7406\u6240\u6709\u6279\u91cf\u7684\u540c\u4e00\u4e2a\u65f6\u95f4\u6b65\n        # print(X)\n        H = torch.tanh(torch.matmul(X, W_xh) + torch.matmul(H, W_hh) + b_h)\n        Y = torch.matmul(H, W_hq) + b_q\n        outputs.append(Y)\n    return outputs, (H,)<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">\u5b9a\u4e49\u9884\u6d4b\u51fd\u6570<\/h2>\n\n\n\n<p>\u4ee5\u4e0b\u51fd\u6570\u57fa\u4e8e\u524d\u7f00<code>prefix<\/code>\uff08\u542b\u6709\u6570\u4e2a\u5b57\u7b26\u7684\u5b57\u7b26\u4e32\uff09\u6765\u9884\u6d4b\u63a5\u4e0b\u6765\u7684<code>num_chars<\/code>\u4e2a\u5b57\u7b26\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>def predict_rnn(prefix, num_chars, rnn, params, init_rnn_state,\n                num_hiddens, vocab_size, device, idx_to_char, char_to_idx):\n    # prefix: \u957f\u5ea6\u4e3aprefix\u7684\u5b57\u7b26\u4e32\n    # num_chars: \u9884\u6d4b\u540e\u7eed\u7684\u5b57\u7b26\u4e2a\u6570\n    # rnn: \u5faa\u73af\u795e\u7ecf\u7f51\u7edc\u6a21\u578b\n    # params: \u5faa\u73af\u795e\u7ecf\u7f51\u7edc\u6a21\u578b\u7684\u53c2\u6570\n    # init_rnn_state: \u521d\u59cb\u5316\u9690\u85cf\u72b6\u6001\u7684\u51fd\u6570\n    # num_hiddens: \u9690\u85cf\u5355\u5143\u4e2a\u6570\n    # vocab_size: \u8bcd\u5178\u5927\u5c0f\uff0c\u5373one-hot\u5411\u91cf\u957f\u5ea6\n    # device: \u8bbe\u5907\u540d\uff0c\u5982'cpu'\u6216'cuda'\n    # idx_to_char: 
\u7d22\u5f15\u5230\u5b57\u7b26\u7684\u6620\u5c04\n    # char_to_idx: \u5b57\u7b26\u5230\u7d22\u5f15\u7684\u6620\u5c04\n    state = init_rnn_state(1, num_hiddens, device) # \u521d\u59cb\u5316\u9690\u85cf\u72b6\u6001\n    output = &#91;char_to_idx&#91;prefix&#91;0]]] # output\u8bb0\u5f55prefix\u52a0\u4e0a\u9884\u6d4b\u7684num_chars\u4e2a\u5b57\u7b26\u7684\u7d22\u5f15\uff0c\u8fd9\u91cc\u53ea\u8bb0\u5f55\u4e86prefix\u7684\u7b2c\u4e00\u4e2a\u5b57\u7b26\n    for t in range(num_chars + len(prefix) - 1):\n        # \u5c06\u4e0a\u4e00\u65f6\u95f4\u6b65\u7684\u8f93\u51fa\u4f5c\u4e3a\u5f53\u524d\u65f6\u95f4\u6b65\u7684\u8f93\u5165\n        X = to_onehot(torch.tensor(&#91;&#91;output&#91;-1]]], device=device), vocab_size)\n        # \u8ba1\u7b97\u8f93\u51fa\u548c\u66f4\u65b0\u9690\u85cf\u72b6\u6001\n        (Y, state) = rnn(X, state, params)\n        # \u4e0b\u4e00\u4e2a\u65f6\u95f4\u6b65\u7684\u8f93\u5165\u662fprefix\u91cc\u7684\u5b57\u7b26\u6216\u8005\u5f53\u524d\u7684\u6700\u4f73\u9884\u6d4b\u5b57\u7b26\n        if t &lt; len(prefix) - 1:\n            output.append(char_to_idx&#91;prefix&#91;t + 1]])\n        else:\n            output.append(int(Y&#91;0].argmax(dim=1).item()))\n    return ''.join(&#91;idx_to_char&#91;i] for i in output])<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">\u88c1\u526a\u68af\u5ea6<\/h2>\n\n\n\n<p>\u5faa\u73af\u795e\u7ecf\u7f51\u7edc\u4e2d\u8f83\u5bb9\u6613\u51fa\u73b0\u68af\u5ea6\u8870\u51cf\u6216\u68af\u5ea6\u7206\u70b8\u3002\u4e3a\u4e86\u5e94\u5bf9\u68af\u5ea6\u7206\u70b8\uff0c\u53ef\u4ee5\u88c1\u526a\u68af\u5ea6\uff08clip gradient\uff09\u3002\u5047\u8bbe\u628a\u6240\u6709\u6a21\u578b\u53c2\u6570\u68af\u5ea6\u7684\u5143\u7d20\u62fc\u63a5\u6210\u4e00\u4e2a\u5411\u91cf $\\boldsymbol{g}$\uff0c\u5e76\u8bbe\u88c1\u526a\u7684\u9608\u503c\u662f$\\theta$\u3002\u88c1\u526a\u540e\u7684\u68af\u5ea6<\/p>\n\n\n\n<p>$$<br>\\min\\left(\\frac{\\theta}{|\\boldsymbol{g}|}, 
1\\right)\\boldsymbol{g}<br>$$<\/p>\n\n\n\n<p>\u7684$L_2$\u8303\u6570\u4e0d\u8d85\u8fc7$\\theta$\u3002<\/p>\n\n\n\n<p>\u5982\u679c\u68af\u5ea6\u7684\u8303\u6570 $|\\boldsymbol{g}|$ \u5c0f\u4e8e\u6216\u7b49\u4e8e\u9608\u503c $\\theta$\uff0c\u5219\u4e0d\u9700\u8981\u88c1\u526a\uff0c\u76f4\u63a5\u4f7f\u7528\u539f\u59cb\u68af\u5ea6 $\\boldsymbol{g}$\u3002<\/p>\n\n\n\n<p>\u5982\u679c\u68af\u5ea6\u7684\u8303\u6570 $|\\boldsymbol{g}|$ \u5927\u4e8e\u9608\u503c $\\theta$\uff0c\u5219\u9700\u8981\u5bf9\u68af\u5ea6\u8fdb\u884c\u7f29\u653e\uff0c\u88c1\u526a\u540e\u7684\u68af\u5ea6\u4e3a\uff1a<\/p>\n\n\n\n<p>$$<br>\\frac{\\theta}{|\\boldsymbol{g}|} \\boldsymbol{g}<br>$$<\/p>\n\n\n\n<p>\u8fd9\u76f8\u5f53\u4e8e\u5c06\u68af\u5ea6\u5411\u91cf\u6309\u6bd4\u4f8b\u7f29\u5c0f\uff0c\u4f7f\u5176\u8303\u6570\u53d8\u4e3a $\\theta$\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code># \u672c\u51fd\u6570\u5df2\u4fdd\u5b58\u5728d2lzh_pytorch\u5305\u4e2d\u65b9\u4fbf\u4ee5\u540e\u4f7f\u7528\ndef grad_clipping(params, theta, device):\n    norm = torch.tensor(&#91;0.0], device=device)\n    for param in params:\n        norm += (param.grad.data ** 2).sum()\n    norm = norm.sqrt().item()\n    if norm &gt; theta:\n        for param in params:\n            param.grad.data *= (theta \/ norm)<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">\u56f0\u60d1\u5ea6<\/h2>\n\n\n\n<p>\u901a\u5e38\u4f7f\u7528\u56f0\u60d1\u5ea6\uff08perplexity\uff09\u6765\u8bc4\u4ef7\u8bed\u8a00\u6a21\u578b\u7684\u597d\u574f\u3002\u56f0\u60d1\u5ea6\u662f\u5bf9\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\u505a\u6307\u6570\u8fd0\u7b97\u540e\u5f97\u5230\u7684\u503c\u3002\u7279\u522b\u5730\uff0c<\/p>\n\n\n\n<ul 
class=\"wp-block-list\">\n<li>\u6700\u4f73\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u603b\u662f\u628a\u6807\u7b7e\u7c7b\u522b\u7684\u6982\u7387\u9884\u6d4b\u4e3a1\uff0c\u6b64\u65f6\u56f0\u60d1\u5ea6\u4e3a1\uff1b<\/li>\n\n\n\n<li>\u6700\u574f\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u603b\u662f\u628a\u6807\u7b7e\u7c7b\u522b\u7684\u6982\u7387\u9884\u6d4b\u4e3a0\uff0c\u6b64\u65f6\u56f0\u60d1\u5ea6\u4e3a\u6b63\u65e0\u7a77\uff1b<\/li>\n\n\n\n<li>\u57fa\u7ebf\u60c5\u51b5\u4e0b\uff0c\u6a21\u578b\u603b\u662f\u9884\u6d4b\u6240\u6709\u7c7b\u522b\u7684\u6982\u7387\u90fd\u76f8\u540c\uff0c\u6b64\u65f6\u56f0\u60d1\u5ea6\u4e3a\u7c7b\u522b\u4e2a\u6570\u3002<\/li>\n<\/ul>\n\n\n\n<p>\u663e\u7136\uff0c\u4efb\u4f55\u4e00\u4e2a\u6709\u6548\u6a21\u578b\u7684\u56f0\u60d1\u5ea6\u5fc5\u987b\u5c0f\u4e8e\u7c7b\u522b\u4e2a\u6570\u3002\u5728\u672c\u4f8b\u4e2d\uff0c\u56f0\u60d1\u5ea6\u5fc5\u987b\u5c0f\u4e8e\u8bcd\u5178\u5927\u5c0f<code>vocab_size<\/code>\u3002<\/p>\n\n\n\n<h2 class=\"wp-block-heading\">\u5b9a\u4e49\u6a21\u578b\u8bad\u7ec3\u51fd\u6570<\/h2>\n\n\n\n<pre class=\"wp-block-code\"><code>def train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,\n                          vocab_size, device, corpus_indices, idx_to_char,\n                          char_to_idx, is_random_iter, num_epochs, num_steps,\n                          lr, clipping_theta, batch_size, pred_period,\n                          pred_len, prefixes):\n    # rnn : \u5faa\u73af\u795e\u7ecf\u7f51\u7edc\u6a21\u578b\n    # get_params : \u83b7\u53d6\u6a21\u578b\u53c2\u6570\u7684\u51fd\u6570\n    # init_rnn_state : \u521d\u59cb\u5316\u9690\u85cf\u72b6\u6001\u7684\u51fd\u6570\n    # num_hiddens : \u9690\u85cf\u5355\u5143\u4e2a\u6570\n    # vocab_size : \u8bcd\u5178\u5927\u5c0f\uff0c\u5373one-hot\u5411\u91cf\u957f\u5ea6\n    # device : \u8bbe\u5907\u540d\uff0c\u5982'cpu'\u6216'cuda'\n    # corpus_indices : \u5b57\u7b26\u7d22\u5f15\u5e8f\u5217\n    # idx_to_char : \u7d22\u5f15\u5230\u5b57\u7b26\u7684\u6620\u5c04\n    # 
char_to_idx : \u5b57\u7b26\u5230\u7d22\u5f15\u7684\u6620\u5c04\n    # is_random_iter : \u662f\u5426\u4f7f\u7528\u968f\u673a\u91c7\u6837\n    # num_epochs : \u8fed\u4ee3\u6b21\u6570\n    # num_steps : \u65f6\u95f4\u6b65\u6570\n    # lr : \u5b66\u4e60\u7387\n    # clipping_theta : \u68af\u5ea6\u88c1\u526a\u9608\u503c\n    # batch_size : \u6279\u91cf\u5927\u5c0f\n    # pred_period : \u9884\u6d4b\u5468\u671f\uff0c\u6bcf\u95f4\u9694\u591a\u5c11\u4e2a\u8fed\u4ee3\u5468\u671f\u540e\u9884\u6d4b\u4e00\u6b21\n    # pred_len : \u9884\u6d4b\u957f\u5ea6\n    # prefixes : \u9884\u6d4b\u65f6\u4f7f\u7528\u7684\u524d\u7f00\n\n    if is_random_iter:\n        data_iter_fn = data_iter_random # 6.3\u4e2d\u7684\u968f\u673a\u91c7\u6837\u51fd\u6570\n    else:\n        data_iter_fn = data_iter_consecutive # 6.3\u4e2d\u7684\u76f8\u90bb\u91c7\u6837\u51fd\u6570\n    params = get_params() # \u83b7\u53d6\u6a21\u578b\u53c2\u6570\n    loss = nn.CrossEntropyLoss() # \u5b9a\u4e49\u4ea4\u53c9\u71b5\u635f\u5931\u51fd\u6570\n\n    for epoch in range(num_epochs):\n        if not is_random_iter:  # \u5982\u4f7f\u7528\u76f8\u90bb\u91c7\u6837\uff0c\u5728epoch\u5f00\u59cb\u65f6\u521d\u59cb\u5316\u9690\u85cf\u72b6\u6001\n            state = init_rnn_state(batch_size, num_hiddens, device) # \u521d\u59cb\u5316\u9690\u85cf\u72b6\u6001\n        l_sum, n, start = 0.0, 0, time.time() # \u521d\u59cb\u5316\u635f\u5931\uff0c\u5b57\u7b26\u6570\uff0c\u5f00\u59cb\u65f6\u95f4\n        data_iter = data_iter_fn(corpus_indices, batch_size, num_steps, device) # \u83b7\u53d6\u6570\u636e\u8fed\u4ee3\u5668\n        for X, Y in data_iter:\n            if is_random_iter:  # \u5982\u4f7f\u7528\u968f\u673a\u91c7\u6837\uff0c\u5728\u6bcf\u4e2a\u5c0f\u6279\u91cf\u66f4\u65b0\u524d\u521d\u59cb\u5316\u9690\u85cf\u72b6\u6001\n                state = init_rnn_state(batch_size, num_hiddens, device)\n            else:  \n            # 
\u5426\u5219\u9700\u8981\u4f7f\u7528detach\u51fd\u6570\u4ece\u8ba1\u7b97\u56fe\u5206\u79bb\u9690\u85cf\u72b6\u6001, \u8fd9\u662f\u4e3a\u4e86\n            # \u4f7f\u6a21\u578b\u53c2\u6570\u7684\u68af\u5ea6\u8ba1\u7b97\u53ea\u4f9d\u8d56\u4e00\u6b21\u8fed\u4ee3\u8bfb\u53d6\u7684\u5c0f\u6279\u91cf\u5e8f\u5217(\u9632\u6b62\u68af\u5ea6\u8ba1\u7b97\u5f00\u9500\u592a\u5927)\n                for s in state:\n                    s.detach_() # \u4ece\u8ba1\u7b97\u56fe\u5206\u79bb\u9690\u85cf\u72b6\u6001\n\n            inputs = to_onehot(X, vocab_size) # one-hot\u5411\u91cf\n            (outputs, state) = rnn(inputs, state, params) # outputs\u6709num_steps\u4e2a\u5f62\u72b6\u4e3a(batch_size, vocab_size)\u7684\u77e9\u9635\n            outputs = torch.cat(outputs, dim=0) # \u62fc\u63a5\u4e4b\u540e\u5f62\u72b6\u4e3a(num_steps * batch_size, vocab_size)\n            y = torch.transpose(Y, 0, 1).contiguous().view(-1)  # Y\u7684\u5f62\u72b6\u662f(batch_size, num_steps)\uff0c\u8f6c\u7f6e\u540e\u518d\u53d8\u6210\u957f\u5ea6\u4e3a batch * num_steps \u7684\u5411\u91cf\uff0c\u8fd9\u6837\u8ddf\u8f93\u51fa\u7684\u884c\u4e00\u4e00\u5bf9\u5e94\n            l = loss(outputs, y.long()) # \u4f7f\u7528\u4ea4\u53c9\u71b5\u635f\u5931\u8ba1\u7b97\u5e73\u5747\u5206\u7c7b\u8bef\u5dee\n\n            # \u68af\u5ea6\u6e050\n            if params&#91;0].grad is not None:\n                for param in params:\n                    param.grad.data.zero_()\n            l.backward() # \u53cd\u5411\u4f20\u64ad\n            grad_clipping(params, clipping_theta, device)  # \u88c1\u526a\u68af\u5ea6\n            sgd(params, lr, 1)  # \u56e0\u4e3a\u8bef\u5dee\u5df2\u7ecf\u53d6\u8fc7\u5747\u503c\uff0c\u68af\u5ea6\u4e0d\u7528\u518d\u505a\u5e73\u5747\n            l_sum += l.item() * y.shape&#91;0] # \u7edf\u8ba1\u603b\u7684\u635f\u5931\n            n += y.shape&#91;0] # \u7edf\u8ba1\u603b\u7684\u9884\u6d4b\u6570\u91cf\n\n        if (epoch + 1) % pred_period == 0:\n            print('\u7b2c %d 
\u4e2a\u8fed\u4ee3\u5468\u671f\uff0c\u56f0\u60d1\u5ea6 %.2f\uff0c\u8017\u65f6 %.2f \u79d2' % (\n                epoch + 1, math.exp(l_sum \/ n), time.time() - start)) # l_sum \/ n \u662f\u5e73\u5747\u635f\u5931\n            for prefix in prefixes: # \u4f7f\u7528\u9884\u6d4b\u6a21\u578b\n                print(' -', predict_rnn(prefix, pred_len, rnn, params, init_rnn_state,\n                    num_hiddens, vocab_size, device, idx_to_char, char_to_idx))<\/code><\/pre>\n\n\n\n<h2 class=\"wp-block-heading\">\u8bad\u7ec3\u6a21\u578b\u5e76\u521b\u4f5c\u6b4c\u8bcd<\/h2>\n\n\n\n<p>\u8bbe\u7f6e\u6a21\u578b\u8d85\u53c2\u6570\u3002\u5c06\u6839\u636e\u524d\u7f00\u201c\u5206\u5f00\u201d\u548c\u201c\u4e0d\u5206\u5f00\u201d\u5206\u522b\u521b\u4f5c\u957f\u5ea6\u4e3a50\u4e2a\u5b57\u7b26\uff08\u4e0d\u8003\u8651\u524d\u7f00\u957f\u5ea6\uff09\u7684\u4e00\u6bb5\u6b4c\u8bcd\u3002\u6bcf\u8fc750\u4e2a\u8fed\u4ee3\u5468\u671f\u4fbf\u6839\u636e\u5f53\u524d\u8bad\u7ec3\u7684\u6a21\u578b\u521b\u4f5c\u4e00\u6bb5\u6b4c\u8bcd\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>num_epochs, num_steps, batch_size, lr, clipping_theta = 250, 35, 32, 1e2, 1e-2\npred_period, pred_len, prefixes = 50, 50, &#91;'\u5206\u5f00', '\u4e0d\u5206\u5f00']<\/code><\/pre>\n\n\n\n<p>\u4e0b\u9762\u91c7\u7528\u968f\u673a\u91c7\u6837\u8bad\u7ec3\u6a21\u578b\u5e76\u521b\u4f5c\u6b4c\u8bcd\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,\n                      vocab_size, device, corpus_indices, idx_to_char,\n                      char_to_idx, True, num_epochs, num_steps, lr,\n                      clipping_theta, batch_size, pred_period, pred_len,\n                      prefixes)<\/code><\/pre>\n\n\n\n<p>\u63a5\u4e0b\u6765\u91c7\u7528\u76f8\u90bb\u91c7\u6837\u8bad\u7ec3\u6a21\u578b\u5e76\u521b\u4f5c\u6b4c\u8bcd\u3002<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>train_and_predict_rnn(rnn, get_params, init_rnn_state, num_hiddens,\n                  
    vocab_size, device, corpus_indices, idx_to_char,\n                      char_to_idx, False, num_epochs, num_steps, lr,\n                      clipping_theta, batch_size, pred_period, pred_len,\n                      prefixes)<\/code><\/pre>\n","protected":false},"excerpt":{"rendered":"<p>\u8bfb\u53d6\u5468\u6770\u4f26\u4e13\u8f91\u6b4c\u8bcd\u6570\u636e\u96c6\uff1a one-hot\u5411\u91cf \u5047\u8bbe\u8bcd\u5178\u4e2d\u4e0d\u540c\u5b57\u7b26\u7684\u6570\u91cf\u4e3a$N$\uff08\u5373\u8bcd\u5178\u5927\u5c0fvocab_size\uff09\uff0c\u6bcf\u4e2a\u5b57\u7b26\u5df2\u7ecf\u540c &#8230;<\/p>","protected":false},"author":1,"featured_media":1854,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[46,3],"tags":[45,44,12,22],"class_list":["post-1852","post","type-post","status-publish","format-standard","has-post-thumbnail","hentry","category-46","category-3","tag-45","tag-44","tag-12","tag-22"],"_links":{"self":[{"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/posts\/1852","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/comments?post=1852"}],"version-history":[{"count":1,"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/posts\/1852\/revisions"}],"predecessor-version":[{"id":1855,"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/posts\/1852\/revisions\/1855"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/media\/1854"}],"wp:attachment":[{"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/media?parent=1852"}],"
wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/categories?post=1852"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/www.forillusion.com\/index.php\/wp-json\/wp\/v2\/tags?post=1852"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}