在其他任何事情之前,您可能需要申请访问TPU研究云(TRC)。再加上谷歌云免费试用,你可以免费在这里做任何事情。一旦你进入TRC,你需要创建一个项目,然后用新项目的名称填写发送给你的表格。使用脚本create_finetune_tfrecords.py将数据准备为tfrecords;我可能会做一个单独的指南。您可能需要做的另一件事是fork mesh-transformer-jax代码,以便更容易添加和修改配置文件。
- 安装Google Cloud SDK,我们稍后需要它。
- 如果你还没有完成一个项目并通过TRC激活TPU访问(或者如果你计划自掏腰包),免费申请。
- TPU使用谷歌云存储桶进行存储,现在就创建一个吧。确保它位于TPU VM所在的区域;TRC的电子邮件将告诉您可以在哪个地区使用免费TPU。
- 为了微调模型,您需要完全预训练的权重。下载地址
现在云上有了一个存储桶,电脑上有了权重,您需要分两步将权重上传到存储桶:
- 解压缩并提取GPT-J-6B/step_383500.tar.zstd,这样您就剩下包含碎片检查点的未压缩文件夹。
- 打开Google Cloud SDK并运行以下命令,根据需要替换路径名:gsutil-m cp -R LOCAL_path_TO/step_383500 gs://YOUR-BUCKET。如果成功,控制台将显示正在上载的文件。注:我从加州上传到荷兰,花了大约12个小时;希望你的地理位置比我好!我最初也犯了上传仍然打包的.tar的错误。不要这样做,TPU虚拟机没有足够的本地存储空间供您解压缩。为了避免重新上传,我不得不在Colab中解压缩。
你也会想上传你的数据的tfrecords,说实在的,没有人会想通过网络界面上传近70GB的权重。
请注意,稍后可以通过在VM的文本编辑器中编辑基本repo来完成准备索引和配置文件的步骤6和7。相反,对您自己的repo分支进行以下更改更有效:
- 在数据文件夹中,创建一个新的文件foo.train.index,将foo替换为您希望引用数据集的任何内容。对于存储桶中要训练的每个tfrecord,将路径添加为索引中的一行。创建foo.val.index,并对验证数据集执行同样的操作(如果有)。有关示例,请参见现有文件。
- 复制配置文件6B_roto_256.json,将其重命名为适合您项目的名称。打开它并进行以下编辑:
- tpu_size:从256更改为8
- bucket:换成你的bucket
- model_dir:更改到要保存checkpoints的目录
- train_set和val_set:从最后一步更改为索引文件
- eval_haness_tasks:如果不打算使用eval线束,则可以删除
- val_every&ckpt_every&keep_every:用法应该直观。但不要将foo_every值设置为0,否则会出现被零除的错误。如果没有val_set,只需将val_every设置为高于total_steps的值
- val_batches:这应该等于val数据集中的序列数。您可以在create_finetune_tfrecords.py生成的.tfsrecords文件的末尾找到此数字
- name:更改为模型的名称
- warmup_steps、lr、val_batches等:请参阅本指南末尾的学习率注释部分
- 将更改推送到GitHub存储库。
- 遵循本指南,直到连接到您的云TPU VM。
此时,您应该可以远程访问TPU VM!
- 在新的VM终端中,键入git clone https://github.com/kingoflolz/mesh-transformer-jax(或者,最好是在推送配置文件和索引文件之后,您自己的fork)。
- 使用cd mesh-transformer-jax移动到新目录,并运行pip install-r requirements.txt。由于requirements.txt文件没有固定微调所需的确切jax版本,因此运行pip安装jax==0.2.12,您就完成了。
- 最后,运行python3 device_train.py --config=YOUR_config.json --t