前面我们完成了自己训练一个小模型,今天我们结合论文来学习一下Transformer的理论知识~
概述Transformer模型于2017年在论文《注意力就是你所需要的一切》中首次提出。Transformer架构旨在训练语言翻译目的模型。然而,OpenAI的团队发现transformer架构是角色预测的关键解决方案。一旦对整个互联网数据进行训练,该模型就有可能理解任何文本的上下文,并连贯地完成任何句子,就像人类一样。
该模型由两部分组成:编码器和解码器。通常,仅编码器体系结构擅长从文本中提取信息以执行分类和回归等任务,而仅解码器模型则专门用于生成文本。例如,专注于文本生成的GPT属于仅解码器模型的范畴。
让我们在训练模型时了解架构的关键思想。
我画了一张图来说明类似GPT的仅解码器转换器架构的训练过程:
首先,我们需要一系列输入字符作为训练数据。这些输入被转换为矢量嵌入格式。
接下来,我们在向量嵌入中添加位置编码,以捕获每个字符在序列中的位置。
随后,模型通过一系列计算操作处理这些输入嵌入,最终为给定的输入文本生成可能的下一个字符的概率分布。
该模型根据训练数据集中的实际后续特征评估预测结果,并相应地调整概率或“权重”。
最后,该模型迭代地完善了这一过程,不断更新其参数以提高未来预测的精度。
让我们深入了解每个步骤的细节。
1Tokenization标记化是转换器模型的第一步,该模型:
将输入句子转换为数字表示格式。
标记化是将文本划分为称为标记的较小单元的过程,这些单元可以是单词、子单词、短语或字符。因为将短语分解成更小的部分有助于模型识别文本的底层结构并更有效地处理它。
例如:
上面这句话可以切成:
Chapter,,,,,,,,,,,,1``:``Building``Rap``port``and``Capturing
它被标记为10个数字:
如您所见,数字220用于表示空格字符。有许多方法可以将字符标记为整数。对于我们的示例数据集,我们将使用tiktoken库。
出于演示目的,我将使用一个小型教科书数据集(来自HuggingFace),其中包含460k个字符用于我们的训练。
文件大小:450Kb
词汇量:3,771(表示唯一单词/子单词)
我们的训练数据包含3,771个不同字符的词汇量。用于标记我们的教科书数据集的最大数量是,它被映射到一个字符。100069``Clar
一旦我们有了标记化映射,我们就可以为数据集中每个字符找到相应的整数索引。我们将利用这些分配的整数索引作为标记,而不是在与模型交互时使用整个单词。
2WordEmbeddings首先,让我们构建一个包含词汇表中所有字符的查找表。从本质上讲,该表由一个填充了随机初始化数字的矩阵组成。
给定我们拥有的最大标记数是,并考虑维度为64(原始论文使用512维,表示为d_model),生成的查找表变为100,069×64矩阵,这称为标记嵌入查找表。表示如下:100069
TokenEmbeddingLook-UpTable:01234567895455565758596061626300.6257650.0255100.9545140.064349-0.502401-0.202555-1.567081-1.0979560.235958-0.2397780.4208120.2775960.7788981.5332691.609736-0.403228-0.2749281.4738400.0688261.3327081-0.4970060.465756-0.257259-1.0672590.835319-1.956048-0.800265-0.504499-1.4266640.9059420.008287-0.252325-0.6576260.318449-0.549586-1.464924-0.557690-0.693927-0.3252471.24393321.3471211.690980-0.124446-1.6823661.134614-0.0823840.2893160.8357730.306655-0.7472330.543340-0.843840-0.6874812.1382190.5114121.2190900.097527-0.978587-0.432050-1.49375031.078523-0.614952-0.4588530.5674820.095883-1.5699570.373957-0.142067-1.242306-0.961821-0.8824410.6387201.119174-1.907924-0.5275631.080655-2.2152070.203201-1.115814-1.25869140.814849-0.0642971.4236530.261726-0.1331770.2118931.4497903.055426-1.783010-0.8323390.6654150.723436-1.3184540.785860-1.1501111.313207-0.3349490.1497431.306531-0.046524100064-0.898191-1.906910-0.9069101.8385322.121814-1.6544440.0827780.0645360.3451210.2622470.4389560.1633140.4919961.721039-0.1243161.2282420.3689631.0582800.406413-0.3262231000651.354992-1.203096-2.184551-1.745679-0.005853-0.8605061.0107840.355051-1.489120-1.9361921.354665-1.338872-0.2639050.2849060.202743-0.487176-0.4219590.490739-1.0564572.636806100066-0.4361160.450023-1.3815220.6255080.4155760.628877-0.595811-1.074244-1.512645-2.0274220.4365220.0689741.3058520.005790-0.583766-0.7970040.144952-0.2797721.522029-0.6296721000670.1471020.578953-0.668165-0.0114430.2366210.348374-0.7060881.368070-1.428709-0.6201891.130942-0.739860-1.546209-1.475937-0.145684-1.7448290.637790-1.0644551.290440-1.1105201000680.415268-0.3455750.441546-0.5790851.110969-1.3036910.143943-0.714082-1.4265121.646982-2.5025351.4094180.159812-0.9113230.856282-0.404213-0.0127411.3334260.3722550.722526[100,069rowsx64columns]
其中每行代表一个字符(按其标记编号索引),每列代表一个维度。
现在,您可以将“维度”视为角色的特征或方面。在我们的例子中,我们指定了64个维度,这意味着我们将能够以64种不同的方式理解一个角色的文本含义,例如将其分类为名词、动词、形容词等。
假设,现在我们有一个16context_length的训练输入示例,即:
".Bymasteringtheartofidentifyingunderlyingmotivationsanddesires,weequipourselveswith"
现在,我们通过使用其整数索引来查找嵌入表,从而检索每个标记化字符(或单词)的嵌入向量。因此,我们得到了它们各自的输入嵌入:
[627,1383,88861,279,1989,315,25607,16940,65931,323,32097,11,584,26458,13520,449]
在变压器架构中,多个输入序列同时并行处理,通常称为多批处理。让我们将batch_size设置为4。因此,我们将一次处理四个随机选择的句子作为我们的输入。
InputSequenceBatch:01234567891011123838886127919893065935842645874931196446396611107421329622867633748886526997268672647016031666126449311724666044[4rowsx16columns]
每行代表一个句子;每列是该句子从第0位到第15位的字符。
结果,我们现在有一个矩阵,表示4批16个字符的输入。该矩阵的形状为(batch_size,context_length)=[4,16]。
回顾一下,我们将输入嵌入查找表定义为大小为100,069×64的矩阵。下一步是获取我们的输入序列矩阵并将其映射到这个嵌入矩阵上,以获得我们的输入嵌入。
在这里,我们将重点分解输入序列矩阵的每一行,从第一行开始。首先,我们将此开始行从其原始尺寸(1,context_length)=[1,16]重塑为(context_length,1)=[16,1]的新格式。随后,我们将这个重组后的行覆盖在我们之前建立的嵌入矩阵大小(vocab_size,d_model)=[100069,64]上,从而将匹配的嵌入向量替换为给定上下文窗口中存在的每个字符。生成的输出是形状为(context_length,d_model)=[16,64]的矩阵。
输入序列批处理的第一行:
InputEmbedding:01234567895455565758596061626301.051807-0.704369-0.913199-1.1515640.582201-0.8985820.984299-0.075260-0.004821-0.7436421.1513780.1195950.601200-0.9403520.2899600.5797490.4286230.263096-0.773865-0.7342201-0.293959-1.278850-0.0507310.8625620.200148-1.7326250.374076-1.1285070.281203-1.073113-0.062417-0.4405990.8002830.7830431.602350-0.676059-0.2465311.005652-1.0186670.6040922-0.2921960.109248-0.131576-0.7005360.326451-1.885801-0.1508340.348330-0.7772810.9867690.3824801.315575-0.1440371.2801031.1128290.438884-0.275823-2.2266980.1089840.70188130.4279420.878749-0.1769510.5487720.226408-0.070323-1.8652351.4733641.0328850.6961731.2701871.028823-0.872329-0.147387-0.0832870.142618-0.375903-0.1018870.989520-0.0625604-1.064934-0.1315700.514266-0.7590370.2940440.9571250.976445-1.477583-1.376966-1.1713440.2311121.2786870.2546880.5162870.6217530.2191791.345463-0.9278670.5101720.65685152.514588-1.0012510.391298-0.8457120.046932-0.0367321.3964510.934358-0.876228-0.0244400.0898040.646096-0.2069350.187104-1.288239-1.0681430.696718-0.373597-0.334495-0.46221860.498423-0.349237-1.061968-0.0930991.374657-0.512061-1.238927-1.342982-1.6116352.0714450.0255050.6380720.104059-0.600942-0.367796-0.4721890.8439340.706170-1.676522-0.26637971.684027-0.651413-0.7680500.599159-0.3815950.9287992.1885721.579998-0.122685-1.026440-0.3136721.276962-1.142109-0.1451391.207923-0.058557-0.3528061.506868-2.2966421.3786788-0.041210-0.834533-1.243622-0.675754-1.7765860.038765-2.7130902.423366-1.7118150.621387-1.0637581.525688-1.7620230.1610980.0268060.4623470.7329750.4797500.942445-1.05057590.7087541.0585100.2975600.2105480.4605511.0161412.5548970.2540320.935956-0.250423-0.5528350.0841240.4373480.5962280.5121680.289721-0.028321-0.932675-0.4112351.03575410-0.5845531.3956760.7273540.6413520.693481-2.113973-0.786199-0.3277581.278788-0.1561181.204587-0.131655-0.595295-0.433438-0.8636843.2722470.1015910.619058-0.982174-1.17412511-0.7538280.098016-0.9453220.708373-1.4937440.3947320.075629-0.049392-1.0055640.3563532.452891-0.2335710.398788-1.597272-1.919085-0.405561-0.2666441.2370221.079494-2.29241412-0.6118640.0068101.989711-0.446170-0.6701080.045619-0.0928341.226774-1.407549-0.0966951.181310-0.407162-0.086341-0.5306280.0429211.3694780.823999-0.3129570.5917550.51631413-0.5845531.3956760.7273540.6413520.693481-2.113973-0.786199-0.3277581.278788-0.1561181.204587-0.131655-0.595295-0.433438-0.8636843.2722470.1015910.619058-0.982174-1.17412514-1.1740900.096075-0.7491950.395859-0.622460-1.2911260.0944310.680156-0.4807420.7093180.7866630.2377331.5137970.2966960.069533-0.2367191.098030-0.442940-0.5831771.151497150.401740-0.5295873.016675-1.134723-0.256546-0.2198960.6379362.000511-0.418684-0.242720-0.442287-1.519394-1.007496-0.5174800.307449-0.316039-0.880636-1.424680-1.9016441.968463[16rowsx64columns]
矩阵显示映射后的四行之一
我们对其余的3行执行相同的操作,最终我们有4组x[16行x64列]。
这会导致形状为(batch_size,context_length,d_model)=[4,16,64]的输入嵌入矩阵。
从本质上讲,为每个单词提供唯一的嵌入允许模型适应语言的变化并管理具有多种含义或形式的单词。
让我们继续前进,理解我们的输入嵌入矩阵作为我们模型的预期输入格式,即使我们还没有完全掌握起作用的基本数学原理。
3PositionalEncoding在我看来,位置编码是变压器架构中最具挑战性的概念。
总结一下位置编码解决了什么问题:
我们希望每个单词都带有一些关于它在句子中的位置的信息。
我们希望模型将看起来彼此接近的单词视为“接近”,将距离较远的单词视为“遥远”。
我们希望位置编码表示模型可以学习的模式。
位置编码描述序列中实体的位置或位置,以便为每个位置分配唯一的表示形式。
位置编码是另一个数字向量,它被添加到每个标记化字符的输入嵌入中。位置编码是正弦波和余弦波,其频率根据标记化字符的位置而变化。
在原始论文中,引入的位置编码计算方法是:
PE(pos,2i)=sin(pos/10000^(2i/d_model))PE(pos,2i+1)=cos(pos/10000^(2i/d_model))
其中是位置,从0到d_model/2。是我们在训练模型时定义的模型维度(在我们的例子中是64,在原始论文中他们使用512)。pos``i``d_model
事实上,这个位置编码矩阵只创建一次,并重复用于每个输入序列。
让我们看一下位置编码矩阵:
让我们多谈谈位置编码技巧。
PositionEmbeddingLook-UpTable:01234567895455565758596061626300.0000001.0000000.0000001.0000000.0000001.0000000.0000001.0000000.0000001.0000000.0000001.0000000.0000001.0000000.0000001.0000000.0000001.0000000.0000001.00000010.8414710.5403020.6815610.7317610.5331680.8460090.4093090.9123960.3109840.9504150.0004221.0000000.0003161.0000000.0002371.0000000.0001781.0000000.0001331.00000020.909297-0.4161470.9974800.0709480.9021310.4314630.7469040.6649320.5911270.8065780.0008431.0000000.0006321.0000000.0004741.0000000.0003561.0000000.0002671.00000030.141120-0.9899920.778273-0.6279270.993253-0.1159660.9536350.3009670.8126490.5827540.0012650.9999990.0009491.0000000.0007111.0000000.0005331.0000000.0004001.0000004-0.756802-0.6536440.141539-0.9899330.778472-0.6276800.993281-0.1157300.9535810.3011370.0016870.9999990.0012650.9999990.0009491.0000000.0007111.0000000.0005331.0000005-0.9589240.283662-0.571127-0.8208620.323935-0.9460790.858896-0.5121500.999947-0.0103420.0021080.9999980.0015810.9999990.0011860.9999990.0008891.0000000.0006671.0000006-0.2794150.960170-0.977396-0.211416-0.230368-0.9731040.574026-0.8188370.947148-0.3207960.0025300.9999970.0018970.9999980.0014230.9999990.0010670.9999990.0008001.00000070.6569870.753902-0.8593130.511449-0.713721-0.7004300.188581-0.9820580.800422-0.5994370.0029520.9999960.0022140.9999980.0016600.9999990.0012450.9999990.0009331.00000080.989358-0.145500-0.2802280.959933-0.977262-0.212036-0.229904-0.9732130.574318-0.8186320.0033740.9999940.0025300.9999970.0018970.9999980.0014230.9999990.0010670.99999990.412118-0.9111300.4491940.893434-0.9398240.341660-0.608108-0.7938540.291259-0.9566440.0037950.9999930.0028460.9999960.0021340.9999980.0016000.9999990.0012000.99999910-0.544021-0.8390720.9376330.347628-0.6129370.790132-0.879767-0.475405-0.020684-0.9997860.0042170.9999910.0031620.9999950.0023710.9999970.0017780.9999980.0013340.99999911-0.9999900.0044260.923052-0.384674-0.0972760.995257-0.997283-0.073661-0.330575-0.9437800.0046390.9999890.0034780.9999940.0026090.9999970.0019560.9999980.0014670.99999912-0.5365730.8438540.413275-0.9106060.4483430.893862-0.9400670.340989-0.607683-0.7941790.0050600.9999870.0037950.9999930.0028460.9999960.0021340.9999980.0016000.999999130.4201670.907447-0.318216-0.9480180.8558810.517173-0.7181440.695895-0.824528-0.5658210.0054820.9999850.0041110.9999920.0030830.9999950.0023120.9999970.0017340.999998140.9906070.136737-0.878990-0.4768390.999823-0.018796-0.3703950.928874-0.959605-0.2813490.0059040.9999830.0044270.9999900.0033200.9999950.0024900.9999970.0018670.999998150.650288-0.759688-0.9682060.2501540.835838-0.5489750.0422490.999107-0.9995190.0310220.0063250.9999800.0047430.9999890.0035570.9999940.0026670.9999960.0020000.999998[16rowsx64columns]
据我了解,位置值是根据它们在序列中的相对位置建立的。此外,由于每个输入句子的上下文长度一致,它使我们能够在各种输入中回收相同的位置编码。因此,必须谨慎地创建序列号,以防止过大的幅度对输入嵌入产生负面影响,确保相邻位置表现出微小的差异,而远处的位置显示出它们之间的较大差异。
使用正弦和余弦向量的组合,该模型可以看到独立于词嵌入的位置编码向量,而不会混淆输入嵌入(语义)信息。很难想象这在神经元网络内部是如何工作的,但它是有效的。
我们可以可视化我们的位置嵌入数字并查看模式。
每条垂直线是我们从0到64的维度;每行代表一个字符。这些值介于-1和1之间,因为它们来自正弦和余弦函数。颜色越深表示值越接近-1,颜色越亮表示值越接近1。绿色表示介于两者之间的值。
让我们回到我们的位置编码矩阵,正如你所看到的,这个位置编码表与输入嵌入表[4,16,64]中的每个批处理具有相同的形状,它们都是(context_length,d_model)=[16,64]。
由于两个具有相同形状的矩阵可以相加,因此我们可以将位置信息添加到每个输入嵌入行中,以获得最终输入嵌入矩阵。
batch0:01234567895455565758596061626301.0518070.295631-0.913199-0.1515640.5822010.1014180.9842990.924740-0.0048210.2563581.1513781.1195950.6012000.0596480.2899601.5797490.4286231.263096-0.7738650.26578010.547512-0.7385480.6308301.5943230.733316-0.8866160.783385-0.2161110.592187-0.122698-0.0619950.5594010.8005991.7830431.6025870.323941-0.2463532.005651-1.0185341.60409220.617101-0.3068990.865904-0.6295881.228581-1.4543390.5960701.013263-0.1861541.7933480.3833242.315575-0.1434042.2801021.1133031.438884-0.275467-1.2266980.1092511.70188130.569062-0.1112430.601322-0.0791541.219661-0.186289-0.9116001.7743321.8455331.2789271.2714522.028822-0.8713800.852612-0.0825751.142617-0.3753690.8981130.9899200.9374404-1.821736-0.7852140.655805-1.7489691.0725160.3294451.969725-1.593312-0.423386-0.8702060.2327992.2786850.2559531.5162870.6227011.2191781.3461750.0721330.5107051.65685151.555663-0.717588-0.179829-1.6665740.370867-0.9828112.2553470.4222080.123719-0.0347820.0919121.646094-0.2053541.187103-1.287054-0.0681440.6976070.626403-0.3338280.53778260.2190070.610934-2.039364-0.3045161.144289-1.485164-0.664902-2.161820-0.6644871.7506490.0280361.6380680.1059570.399056-0.3663730.5278100.8450011.706170-1.6757220.73362172.3410130.102489-1.6273631.110608-1.0953160.2283692.3771530.5979400.677737-1.625878-0.3107202.276958-1.1398950.8548591.2095830.941441-0.3515622.506867-2.2957082.37867880.948148-0.980033-1.5238500.284180-2.753848-0.173272-2.9429951.450153-1.137498-0.197246-1.0603852.525683-1.7594941.1610950.0287031.4623460.7343971.4797490.943511-0.05057591.1208720.1473800.7467531.103982-0.4792731.3578011.946789-0.5398221.227215-1.207067-0.5490401.0841170.4401941.5962240.5143031.289719-0.0267210.067324-0.4100352.03575310-1.1285740.5566041.6649860.9889800.080544-1.323841-1.665967-0.8031631.258105-1.1559041.2088040.868336-0.5921320.566557-0.8613134.2722440.1033691.619057-0.980840-0.17412611-1.7538180.102441-0.0222700.323699-1.5910201.389990-0.921654-0.123053-1.336139-0.5874272.4575300.7664190.402266-0.597278-1.9164760.594436-0.2646882.2370201.080961-1.29241512-1.1484370.8506642.402985-1.356776-0.2217650.939481-1.0329021.567763-2.015232-0.8908741.1863700.592825-0.0825460.4693650.0457672.3694740.8261330.6870410.5933551.51631313-0.1643862.3031230.409138-0.3066661.549362-1.596800-1.5043430.3681370.454260-0.7219381.2100690.868330-0.5911840.566554-0.8606014.2722430.1039031.619056-0.980440-0.17412714-0.1834820.232812-1.628186-0.0809810.377364-1.309922-0.2759641.609030-1.4403470.4279690.7925661.2377151.5182241.2966860.0728530.7632761.1005200.557057-0.5813102.151496151.052028-1.2892752.048469-0.8845700.579293-0.7688710.6801852.999618-1.418203-0.211697-0.435962-0.519414-1.0027520.4825080.3110060.683955-0.877969-0.424683-1.8996432.968462[16rowsx64columns]batch1:0123456789545556575859606162630-0.2642360.9656811.909974-0.338721-0.5541960.254583-0.5761111.766522-0.6525870.455450-1.0164260.458762-0.5132900.6184110.8772292.5265910.6145510.662366-1.2469071.12806611.732205-0.8581780.3240081.022650-1.1728650.513133-0.1216112.6300850.0724252.3322960.7376601.9882252.5446611.9954710.4478633.1744280.4449890.8604262.1377971.5375802-1.348308-1.0802211.7533940.1561930.4406521.015287-0.7906441.2155372.0370300.4765600.2969411.100837-0.1531941.329375-0.1889581.229344-1.3019190.938138-0.860689-0.86013730.601103-0.1564190.850114-0.324190-0.311584-2.232454-0.9031120.2426870.8019082.502464-0.3970071.150545-0.4739070.318961-1.9701261.967961-0.1868310.1318730.947445-0.2815734-1.821736-0.7852140.655805-1.7489691.0725160.3294451.969725-1.593312-0.423386-0.8702060.2327992.2786850.2559531.5162870.6227011.2191781.3461750.0721330.5107051.65685151.555663-0.717588-0.179829-1.6665740.370867-0.9828112.2553470.4222080.123719-0.0347820.0919121.646094-0.2053541.187103-1.287054-0.0681440.6976070.626403-0.3338280.53778260.5998410.943214-1.397184-0.607349-0.333995-1.222589-0.731189-0.9977061.8486110.2542380.3409861.3831131.6745922.229903-0.1574150.362868-0.4937621.9041360.0279031.19601770.0722341.386670-0.985962-1.1844860.958293-0.295773-1.529277-0.7278441.5105031.268154-0.3564590.3823310.138104-0.360916-0.6384481.305404-0.7564420.2991500.154600-0.4661548-0.008645-1.066763-0.7165552.148885-0.709739-0.1372660.3854010.6991391.907906-2.3575670.490190-1.2154121.2164590.659227-0.282908-0.9122660.5955691.2107010.7374070.8016729-0.006332-0.9499280.1926893.158421-1.292153-0.8302480.966141-2.0565140.0423641.4859270.480763-0.3185540.0058373.031636-0.4481171.0594030.5981060.8714270.3273211.09092110-1.152681-0.710162-0.456591-0.468090-0.2925660.747535-0.149907-0.3955230.170872-2.372754-1.2674610.043283-0.1149801.083042-0.2887761.4423180.7755910.728716-0.576776-0.72725711-0.955986-0.2774750.946888-0.2426871.2577440.3699940.4600730.728078-0.165204-0.761762-0.3079832.078995-1.0677921.8056370.6089681.722982-0.371174-0.6031820.2853871.11293212-0.8443470.8832241.222388-0.811387-0.5935570.157268-0.6503151.289236-1.472027-0.447092-0.5364332.465097-0.8229051.2727860.7036642.687270-0.9243880.596134-0.3671380.812242130.7764701.549248-0.2396930.1337830.7672551.996130-0.436228-0.327975-0.6507430.507769-0.8217931.387792-1.0521052.1236031.4210922.066746-0.7477660.627081-1.749071-0.679443141.2775790.6539450.045632-0.4097900.8297080.249433-0.6820510.601958-1.932014-2.0773970.1606111.0378560.6568320.992817-0.6840561.031199-0.1808664.579140-1.1235550.181580150.356328-2.038538-1.0189381.1127161.035987-2.2816000.416325-0.129400-0.718316-1.042091-0.0560920.5593810.8050261.7830321.6059070.323934-0.2438632.005648-1.0166671.604090[16rowsx64columns]batch2:01234567895455565758596061626300.6458541.291073-1.5889311.814376-0.1852700.846816-1.6868620.982995-0.9731081.2972030.8526001.5332310.6927292.437029-0.1781370.4934130.5974841.9091551.2578212.64432511.732205-0.8581780.3240081.022650-1.1728650.513133-0.1216112.6300850.0724252.3322960.7376601.9882252.5446611.9954710.4478633.1744280.4449890.8604262.1377971.53758023.298391-0.3639080.376535-0.2766921.262433-0.5956591.6945410.542514-0.4647560.368460-0.1694741.4208090.3044881.689731-1.128037-0.024476-1.3568082.160992-2.110703-0.47240430.626955-2.9885240.9155781.1235030.6359830.0780060.466728-0.9307652.1892861.5054992.4966491.6915780.6426642.0892051.9261871.185045-0.9699520.666007-0.0306410.66757440.396447-2.1164150.384262-1.6327790.859029-0.7265992.121946-1.3140460.744388-0.227106-1.9373522.3786200.0292201.215336-0.405487-0.834419-1.2198250.000676-0.8212930.3407975-2.1330140.379737-1.320323-0.425003-0.298524-2.2372050.9533270.1680060.5192050.6989760.7887711.2377311.5153781.2966950.0707180.7632811.0989200.557059-0.5825102.1514976-0.3909180.634039-1.3504610.0321290.1064280.3704101.2923870.986316-0.0953960.555067-1.792372-0.3575990.9122760.0887460.8669500.927208-0.3816432.5321190.464615-1.0442997-0.4079470.622332-0.345048-0.247587-0.4196770.2566951.165026-2.459640-0.576545-1.7707810.2340642.2786820.2569011.5162850.6234131.2191771.3467080.0721330.5111051.65685183.503946-1.1467510.1110700.114221-0.930330-0.2487691.166547-0.038856-0.301910-0.8430720.0931771.646091-0.2044051.187101-1.286342-0.0681450.6981410.626402-0.3334280.5377819-1.946920-0.4437880.5601033.584257-0.134643-1.538940-1.059084-0.1286792.503847-2.244587-0.6435521.608934-0.488734-0.2912531.633294-0.0187630.696360-0.6577610.6923951.741288100.3765200.583786-0.7050470.8555480.4714730.687240-0.6056460.4630471.619052-1.894214-0.6886521.974150-1.3994122.567682-0.0500401.782055-0.2979122.366196-1.8885270.63526011-0.109256-1.3940540.565499-0.093785-1.8033090.662382-1.5282031.644028-0.5691330.4381010.7418771.9882142.5478231.9954650.4502343.1744240.4467680.8604242.1391301.53757912-1.553993-0.9834210.392842-1.4731861.5303871.894017-0.732786-1.601045-0.7403440.245303-0.3288283.0138831.1782961.2633330.2848240.7918742.402131-0.231270-1.0254110.17874813-0.7579651.7713060.805440-0.5091211.2122500.388750-0.6069592.352489-2.445346-0.1032230.4255561.7830190.6983361.8715302.3140230.424368-1.0027450.983784-0.0901330.90533714-0.1834820.232812-1.628186-0.0809810.377364-1.309922-0.2759641.609030-1.4403470.4279690.7925661.2377151.5182241.2966860.0728530.7632761.1005200.557057-0.5813102.15149615-0.151101-0.257150-0.478131-1.1700821.318685-0.1881660.1463752.895475-0.918949-0.3052611.6233501.656103-0.6004561.039260-1.9442020.8949111.4093961.722673-0.1720702.265543[16rowsx64columns]batch3:01234567895455565758596061626300.377847-0.3806131.9586400.224087-0.4202930.915635-1.0777481.255988-0.2231470.977568-1.2905321.4609631.365088-2.037483-2.2138411.039091-2.1296490.108403-0.3569962.23935610.5279610.3427870.0967460.8850160.7066992.8736560.1397320.497379-0.009022-0.147825-0.4099130.785146-0.1381662.0410000.2775001.578947-1.5351130.912230-0.3127350.54036521.054965-0.1344112.155045-0.1887240.651576-0.265663-0.7772630.5710801.5086611.0217180.7624582.297400-0.624743-0.9792122.0240081.2956330.2088250.953138-2.9626241.5869013-1.032970-0.8939180.029077-0.2320680.370793-1.4070921.0480660.9811230.3319071.2920720.7879281.2377321.5147461.2966950.0702440.7632811.0985640.557060-0.5827772.1514974-0.980037-1.0146051.875135-2.4596350.486067-0.9410921.2054901.2485311.8013830.5769830.1920971.784109-0.2010230.4050950.9820411.9276370.0085351.063376-1.4397872.9671855-0.369996-1.151058-0.1262220.7684310.107524-0.4810102.056029-0.8728151.522675-0.4409160.246007-1.0326840.5725650.9447440.790383-0.034063-1.704374-0.0533191.7395372.3815066-0.555136-0.284736-0.162689-1.542923-1.619371-2.0142240.957231-0.3381641.353500-2.0484360.180549-0.5986030.4271751.8450720.924364-0.013093-0.054108-0.082885-0.7192180.96055270.5488341.1304441.2074970.565839-1.814344-0.1115230.480270-1.7418231.451116-0.9776401.692325-0.708754-0.7475911.373189-0.224415-0.074035-0.3234352.001849-1.1025841.64465880.117209-0.9054900.2723360.9948480.6489510.354459-0.731171-1.641071-0.966286-0.8374980.2940061.0087741.3769442.9695550.9974522.0767080.6313581.0806000.0753841.81930290.557786-0.6293951.6067580.633762-1.190379-0.355466-2.132275-0.8877071.208793-0.7415050.7654102.297393-0.622529-0.9792162.0256681.2956310.2100700.953136-2.9616911.586900101.107697-2.0504591.3998691.271179-1.3915291.103020-0.910370-0.398901-0.803458-2.0813021.462017-0.1157300.1710520.5941180.5143881.5932230.064085-0.029184-0.0446211.20641511-1.7719330.4694750.9617300.0027981.3860890.250342-0.062900-0.569053-2.149857-0.519952-0.725692-0.727693-0.1786831.675822-0.4017121.1093310.980627-0.357667-0.4848530.20834012-1.5182131.899549-0.320427-0.929415-0.7010200.727833-2.7644980.6127560.041370-1.599998-0.1363141.0689950.6355010.7653690.2700070.319588-0.6529921.3226581.7242272.343042130.0949230.575470-0.852224-2.0985930.9985790.347285-0.4676880.773722-1.664829-0.412623-1.2742620.454381-1.1421071.853844-1.9125370.5443110.667555-1.1874681.2911082.27595614-0.1834820.232812-1.628186-0.0809810.377364-1.309922-0.2759641.609030-1.4403470.4279690.7925661.2377151.5182241.2966860.0728530.7632761.1005200.557057-0.5813102.151496152.053710-2.769740-0.1487960.983717-0.038190-0.6553601.826909-0.332533-1.036128-1.0014300.6743100.695848-0.1816351.051397-0.8848971.590696-1.3751170.596254-0.6513980.797715[16rowsx64columns]
最终输入嵌入将馈送到Transformer解码器模块进行训练。
这个最终结果矩阵称为位置输入嵌入,其形状为(batch_size,context_length,d_model)=[4,16,64]。
这都是关于位置编码的。
到目前为止,我们已经介绍了模型的输入编码和位置编码部分。让我们转到变压器块。
4TransformerBlockTransformer模块是由三层组成的堆栈:一个屏蔽的多头注意力机制、两个归一化层和一个前馈网络。
蒙面的多头注意力是一组自我注意,每个自我注意都称为一个头。因此,让我们先来看看自我注意力机制。
4.1Multi-HeadAttentionOverview多头注意力只是由几个单独的头堆叠在一起组成。所有磁头都接收到完全相同的输入,尽管它们在计算过程中使用自己特定的权重集。处理输入后,所有磁头的输出被连接起来,然后通过线性层。
下图提供了头部内过程的可视化表示,以及多头注意力模块中的详细信息。
为了进行证明计算,让我们从原始论文“注意力是你所需要的”中引入公式:
从公式中,我们首先需要三个矩阵:Q(查询)、K(键)和V(值)。要计算注意力分数,我们需要执行以下步骤:
将Q乘以K转置(表示为K^T)
除以K维数的平方根
应用SoftMax函数
乘以V
我们将一一介绍。
4.2PrepareQ,K,V计算注意力的第一步是获取Q、K和V矩阵,分别表示查询、键和值。这三个值将用于我们的注意力层来计算注意力概率(权重)。这些是通过将上一步中的位置输入嵌入矩阵(表示为X)应用于标记为Wq、Wk和Wv的三个不同的线性层来确定的(所有值都是随机分配的,首先可学习)。然后将每个线性层的输出拆分为多个磁头,表示为num_heads,这里我们选择4个磁头。
Wq、Wk、Wv是三个矩阵,维度为(d_model,d_model)=[64,64]。所有值都是随机分配的。这在神经网络中称为线性层或可训练参数。可训练参数是模型在训练期间将学习和自我更新的值。
为了获得我们的Q,K,V值,我们在输入嵌入矩阵X和三个矩阵Wq、Wk、Wv中的每一个之间进行矩阵乘法(再次,它们的初始值是随机分配的)。
Q=X*Wq
K=X*周
V=X*Wv
上述函数的计算(矩阵乘法)逻辑:
X的形状为(batch_size,context_length,d_model)=[4,16,64],我们将其分解为4个形状为[16,64]的子矩阵。而Wq、Wk、Wv的形状为(d_model,d_model)=[64,64]。我们可以对4个X的子矩阵中的每一个进行矩阵乘法,以Wq、Wk、Wv为单位。
如果回想一下线性代数,则只有当第一个矩阵中的列数等于第二个矩阵中的行数时,才有可能对两个矩阵进行乘法。在我们的例子中,X中的列数是64,Wq、Wk、Wv中的行数也是64。因此,乘法是可能的。
矩阵乘法得到4个形状为[16,64]的子矩阵的形状,可以组合表示为(batch_size,context_length,d_model)=[4,16,64]。
现在,我们的Q、K、V矩阵的形状为(batch_size,context_length,d_model)=[4,16,64]。接下来,我们需要将它们拆分为多个头。这就是为什么变压器架构将其命名为多头注意力的原因。
劈头只是意味着在d_model的64个维度中,我们将它们切割成多个头部,每个头部包含一定数量的维度。每个头部都将能够学习输入的某些模式或语义。
假设我们将num_heads也设置为4。这意味着我们将Q、K、V形状为[4,16,64]的矩阵拆分为多个子矩阵。
实际的拆分是通过将64的最后一个维度重塑为16的4个子维度来完成的。
每个Q、K、V矩阵从形状[4,16,64]转换为[4,16,4,16]。最后两个维度是头部。换句话说,它从以下转变而来:
[batch_size、context_length、d_model]
自:
[batch_size、context_length、num_heads、head_size]
要理解具有相同形状的Q、K和V矩阵[4,16,4,16],请考虑以下观点:
在管道中,有四个批次。每批由16个代币(单词)组成。对于每个标记,有4个头,每个头编码16个维度的语义信息。
4.3CalculateQ,KAttention现在我们已经有了Q、K和V这三个矩阵,让我们开始逐步计算单头注意力。
从变压器图中,Q和K矩阵首先相乘。
现在,如果我们丢弃Q和K矩阵中的batch_size,只保留最后三个维度,现在Q=K=V=[context_length,num_heads,head_size]=[16,4,16]。
我们需要在前两个维度上再做一个转置,使它们的形状为Q=K=V=[num_heads,context_length,head_size]=[4,16,16]。这是因为我们需要在最后两个维度上进行矩阵乘法运算。
Q*K^T=[4,16,16]*[4,16,16]=[4,16,16]
我们为什么要这样做?此处的转置是为了促进不同上下文之间的矩阵乘法。用图表解释更直接。最后两个维度表示为[16,16],可以可视化如下:
这个矩阵,其中每行和每列在我们的例句的上下文中代表一个标记(单词)。矩阵乘法是衡量上下文中每个单词与所有其他单词之间的相似性。该值越高,它们越相似。
让我提出一个注意力得分的头:
[0.2712,0.5608,-0.4975,,-0.4172,-0.2944,0.1899],[-0.0456,0.3352,-0.2611,,0.0419,1.0149,0.2020],[-0.0627,0.1498,-0.3736,,-0.3537,0.6299,0.3374],,,,,,,,,,,,,,,[-0.4166,-0.3364,-0.0458,,-0.2498,-0.1401,-0.0726],[0.4109,1.3533,-0.9120,,0.7061,-0.0945,0.2296],[-0.0602,0.2428,-0.3014,,-0.0209,-0.6606,-0.3170][16rowsx16columns]
这个16x16矩阵中的数字代表我们的例句“”的注意力分数。.Bymasteringtheartofidentifyingunderlyingmotivationsanddesires,weequipourselveswith
更容易看作一个情节:
横轴代表Q的头之一,纵轴表示K的头之一,彩色方块表示上下文中每个令牌和彼此令牌之间的相似性分数。颜色越深,相似度越高。
当然,上面显示的相似之处现在没有多大意义,因为这些只是来自随机分配的值。但是经过训练,相似性分数将是有意义的。
好了,现在让我们把批次维度batch_size带回Q*K注意力分数。最终结果的形状为[batch_size,num_heads,context_length,head_size],即[4,4,16,16]。
这是当前步骤的Q*K注意力分数。
4.4Scale量表部分很简单,我们只需要将Q*K^T注意力分数除以K维度的平方根即可。
在这里,我们的K维数等于Q的维数,d_model除以num_heads:64/4=16。
然后我们取16的平方根,即4。并将Q*K^T注意力得分除以4。
这样做的原因是为了防止Q*K^T注意力分数过大,这可能会导致softmax函数饱和,进而导致梯度消失。
4.5Mask在仅解码器转换器模型中,掩蔽的自我注意力本质上充当序列填充。
解码器只能查看以前的字符,而不能查看未来的字符。因此,未来的字符被屏蔽并用于计算注意力权重。
如果我们再次可视化情节,这很容易理解:
空格表示0分,被屏蔽了
多头注意力层中屏蔽的要点是防止解码器“看到未来”。在我们的例句中,解码器只允许看到当前单词和它之前的所有单词。
4.6Softmaxsoftmax步骤将数字更改为一种特殊的列表,其中整个列表加起来为1。它增加了高数字并减少了低数字,从而创造了明确的选择。
简而言之,softmax函数用于将线性层的输出转换为概率分布。
在现代深度学习框架(如PyTorch)中,softmax函数是一个内置函数,使用起来非常简单:
这行代码会将softmax应用于我们在上一步中计算的所有注意力分数,并产生介于0和1之间的概率分布。
让我们也提出应用softmax后同一头的注意力分数:
现在,所有概率分数均为正数,加起来为1。
4.7CalculateVAttention最后一步是将softmax输出乘以V矩阵。
请记住,我们的V矩阵还将其拆分为多个头,形状为(batch_size,num_heads,context_length,head_size)=[4,4,16,16]。
而上一个softmax步骤的输出为(batch_size,num_heads,context_length,head_size)=[4,4,16,16]。
在这里,我们对两个矩阵的最后两个维度执行另一个矩阵乘法。
softmax_output*V=[4,4,16,16]*[4,4,16,16]=[4,4,16,16]
结果的形状为[batch_size,num_heads,context_length,head_size]=[4,4,16,16]。
我们称此结果为A。
4.8ConcatenateandOutput我们多头注意力的最后一步是将所有头连接在一起,并将它们穿过线性层。
串联的理想是将来自所有头部的信息组合在一起。因此,我们需要将A矩阵从[batch_size,num_heads,context_length,head_size]=[4,4,16,16]重塑为[batch_size,context_length,num_heads,head_size]=[4,16,4,16]。原因是我们需要将最后两个维度放在一起,因此可以很容易地将它们(通过矩阵乘法)组合回大小。num_heads``head_size``d_model=64
这可以通过PyTorch的内置函数轻松完成:
A=(1,2)[4,16,64][batch_size,context_length,d_model]
正如你所看到的,经过一系列的计算,我们的结果矩阵A现在回到了与我们的输入嵌入矩阵X相同的形状,即[batch_size,context_length,d_model]=[4,16,64]。由于此输出结果将作为输入传递到下一层,因此必须保持输入和输出相同的形状。
但在将其传递到下一层之前,我们需要对它执行另一个线性变换。这是通过在串联矩阵A和Wo之间执行另一个矩阵乘法来完成的。
这个Wo被随机分配了形状[d_model,d_model],并将在训练期间更新。
输出=A*Wo=[4,16,64]*[64,64]=[4,16,64]
线性层的输出是单头注意力的输出,表示为输出。
祝贺!现在我们已经完成了蒙面的多头注意力部分!让我们开始变压器块的其余部分。这些都很简单,所以我会快速浏览它们。
5ResidualConnectionandLayerNormalization残差连接(有时称为跳过连接)是允许原始输入X绕过一个或多个层的连接。
这只是原始输入X和多头注意力层输出的总和。由于它们的形状相同,因此将它们相加很简单。
残差连接后,该过程进入层归一化。LayerNorm是一种对网络中每一层的输出进行规范化的技术。这是通过减去平均值并除以图层输出的标准差来完成的。此技术用于防止层的输出变得太大或太小,这可能导致网络变得不稳定。
残差连接和层归一化在“AttentionisAllYouNeed”的原始论文中表示。AddNorm
6Feed-ForwardNetwork一旦我们有了归一化的注意力权重(概率分数),它将通过一个位置前馈网络进行处理。
前馈网络(FFN)由两个线性层组成,它们之间具有ReLU激活函数。让我们看看python代码是如何实现的:
Applythefinallinearlayertogetthelogitslogits=(d_model,vocab_size)(output)
我们将这个线性层之后的输出称为logits。logits是形状为[batch_size,context_length,vocab_size]=[4,16,3771]的矩阵。
然后使用最终的softmax函数将线性层的logits转换为概率分布。
logits=(logits,dim=-1)
注意:在训练过程中,我们不需要在这里应用softmax函数,而是使用nn。CrossEntropy函数,因为它内置了softmax行为。
我们如何查看形状[4,16,3771]的结果对数?实际上,经过所有计算,这是一个非常简单的想法:
我们有4个批处理管道,每个管道包含该输入序列中的所有16个单词,每个单词映射到词汇表中其他每个单词的概率。
如果模型在训练中,我们更新这些概率参数,如果模型在推理中,我们只需选择概率最高的一个。那么一切都有意义了。
总结Transformer架构的复杂性可能具有挑战性。如果想要深入了解,还需要结合实际代码多做尝试,我会在接下来的时间里,结合代码来说明Transformer架构。
原创声明:本文为本人原创作品,首发于AIONES,如果转载,请保留本文链接,谢谢。





