cartman/train/train.ipynb

2762 lines
92 KiB
Text
Raw Permalink Normal View History

2023-02-08 10:22:57 -05:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "VTze-VbeU1c0"
},
"source": [
"# Fine-tune a DialoGPT model\n",
"\n",
"Adapted from a notebook adapted from the notebook in [this Medium post](https://towardsdatascience.com/make-your-own-rick-sanchez-bot-with-transformers-and-dialogpt-fine-tuning-f85e6d1f4e30?gi=e4a72d1510f0)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y17kuzFNUSrZ"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dnv5kT-mLsB-",
"tags": []
},
"outputs": [],
"source": [
"# all the imports\n",
"\n",
"import glob\n",
"import hashlib\n",
"import logging\n",
"import os\n",
"import pickle\n",
"import random\n",
"import re\n",
"import shutil\n",
"from pathlib import Path\n",
"from typing import Dict, List, Tuple\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler\n",
"from torch.utils.data.distributed import DistributedSampler\n",
"from tqdm.notebook import tqdm, trange\n",
"\n",
"from transformers import (\n",
"    MODEL_WITH_LM_HEAD_MAPPING,\n",
"    WEIGHTS_NAME,\n",
"    AdamW,\n",
"    AutoConfig,\n",
"    PreTrainedModel,\n",
"    PreTrainedTokenizer,\n",
"    get_linear_schedule_with_warmup,\n",
")\n",
"\n",
"\n",
"try:\n",
"    from torch.utils.tensorboard import SummaryWriter\n",
"except ImportError:\n",
"    from tensorboardX import SummaryWriter"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BmrbGB8aUmBm"
},
"source": [
"## Load Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-18T05:27:13.388169Z",
"iopub.status.busy": "2022-10-18T05:27:13.387823Z",
"iopub.status.idle": "2022-10-18T05:27:13.451208Z",
"shell.execute_reply": "2022-10-18T05:27:13.450527Z",
"shell.execute_reply.started": "2022-10-18T05:27:13.388152Z"
},
"id": "RXdJTSVwWGHj",
"tags": []
},
"outputs": [],
"source": [
"# Read the raw transcript; later cells use its `name` and `line` columns.\n",
"data = pd.read_csv('data/train.csv')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 238
},
"execution": {
"iopub.execute_input": "2022-10-18T05:27:13.452257Z",
"iopub.status.busy": "2022-10-18T05:27:13.451996Z",
"iopub.status.idle": "2022-10-18T05:27:13.465447Z",
"shell.execute_reply": "2022-10-18T05:27:13.464448Z",
"shell.execute_reply.started": "2022-10-18T05:27:13.452239Z"
},
"id": "h6kGx-9eG7qA",
"outputId": "bd2efe43-1e50-4716-81a2-bf15a3dd03bd",
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>name</th>\n",
" <th>line</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>14152</th>\n",
" <td>Cartman</td>\n",
" <td>I... told your mom you got an F on that socia...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10447</th>\n",
" <td>Kyle</td>\n",
" <td>It's that ring. Somehow, putting on that ring ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>65079</th>\n",
" <td>Kyle</td>\n",
" <td>Hey Look!</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14862</th>\n",
" <td>Benedict XVI</td>\n",
" <td>Tom, Tom! The gingers are claiming they have M...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6060</th>\n",
" <td>Squirrelly Squirrel</td>\n",
" <td>Now come on y'all. We can't waste time arguing...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>63693</th>\n",
" <td>Cartman</td>\n",
" <td>How could you be so stupid!</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" name line\n",
"14152 Cartman I... told your mom you got an F on that socia...\n",
"10447 Kyle It's that ring. Somehow, putting on that ring ...\n",
"65079 Kyle Hey Look!\n",
"14862 Benedict XVI Tom, Tom! The gingers are claiming they have M...\n",
"6060 Squirrelly Squirrel Now come on y'all. We can't waste time arguing...\n",
"63693 Cartman How could you be so stupid!"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Spot-check six random rows of the raw transcript.\n",
"data.sample(6)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-18T05:27:13.466439Z",
"iopub.status.busy": "2022-10-18T05:27:13.466236Z",
"iopub.status.idle": "2022-10-18T05:27:13.469713Z",
"shell.execute_reply": "2022-10-18T05:27:13.468950Z",
"shell.execute_reply.started": "2022-10-18T05:27:13.466421Z"
},
"id": "PG8v6--qWUwj",
"tags": []
},
"outputs": [],
"source": [
"# Value of the `name` column whose lines become the responses the model is trained on.\n",
"CHARACTER_NAME = 'Cartman'"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-18T05:27:13.470543Z",
"iopub.status.busy": "2022-10-18T05:27:13.470349Z",
"iopub.status.idle": "2022-10-18T05:27:13.643993Z",
"shell.execute_reply": "2022-10-18T05:27:13.643224Z",
"shell.execute_reply.started": "2022-10-18T05:27:13.470528Z"
},
"id": "GZUcEMd2WLDT",
"tags": []
},
"outputs": [],
"source": [
"# Number of preceding lines kept as context for each response.\n",
"n = 1\n",
"assert n >= 1, \"need at least one context line to fill the 'context' column\"\n",
"\n",
"# One record per CHARACTER_NAME line: [response, previous line, ..., n-th previous line].\n",
"# The label arithmetic below assumes `data` has the default 0..N-1 index from read_csv.\n",
"contexted = []\n",
"for idx in data[data.name == CHARACTER_NAME].index:\n",
"    if idx < n:\n",
"        # Fewer than n lines precede this one, so the context window cannot be filled.\n",
"        continue\n",
"    # Walk backwards from the response (idx) to the oldest context line (idx - n).\n",
"    contexted.append([data.line[j] for j in range(idx, idx - n - 1, -1)])\n",
"\n",
"# 'context' is the line just before the response; 'context/0', 'context/1', ... go further back.\n",
"columns = ['response', 'context'] + ['context/' + str(k) for k in range(n - 1)]\n",
"\n",
"df = pd.DataFrame.from_records(contexted, columns=columns)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 446
},
"execution": {
"iopub.execute_input": "2022-10-18T05:27:13.644878Z",
"iopub.status.busy": "2022-10-18T05:27:13.644713Z",
"iopub.status.idle": "2022-10-18T05:27:13.652020Z",
"shell.execute_reply": "2022-10-18T05:27:13.651291Z",
"shell.execute_reply.started": "2022-10-18T05:27:13.644862Z"
},
"id": "4T5OlNZHUxij",
"outputId": "895603a6-ca02-4301-c4b0-5bccbee8a3b8",
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>response</th>\n",
" <th>context</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>13726</th>\n",
" <td>Yeah, we're gonna use him to help raise money ...</td>\n",
" <td>Yahahah!!!</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6269</th>\n",
" <td>My password is uloveboobs!</td>\n",
" <td>How did you know that?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2703</th>\n",
" <td>God, you are such a pussy, Stan! You're such a...</td>\n",
" <td>Dude, there's nothing I can do about it.</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5788</th>\n",
" <td>Wow, that makes sense. Don't think anyone can ...</td>\n",
" <td>It's a complicated political issue, my son. An...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9795</th>\n",
" <td>Ugh. Uunnh. Ey, give me that! Ahhhh...</td>\n",
" <td>It'll make your itches go away.</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5086</th>\n",
" <td>Heh! You're crazy! It can't be done!</td>\n",
" <td>Alright Eric, here's the deal: This school can...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" response \\\n",
"13726 Yeah, we're gonna use him to help raise money ... \n",
"6269 My password is uloveboobs! \n",
"2703 God, you are such a pussy, Stan! You're such a... \n",
"5788 Wow, that makes sense. Don't think anyone can ... \n",
"9795 Ugh. Uunnh. Ey, give me that! Ahhhh... \n",
"5086 Heh! You're crazy! It can't be done! \n",
"\n",
" context \n",
"13726 Yahahah!!! \n",
"6269 How did you know that? \n",
"2703 Dude, there's nothing I can do about it. \n",
"5788 It's a complicated political issue, my son. An... \n",
"9795 It'll make your itches go away. \n",
"5086 Alright Eric, here's the deal: This school can... "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Spot-check six random rows of the response/context frame.\n",
"df.sample(6)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 380
},
"execution": {
"iopub.execute_input": "2022-10-18T05:27:13.654567Z",
"iopub.status.busy": "2022-10-18T05:27:13.653983Z",
"iopub.status.idle": "2022-10-18T05:27:13.669974Z",
"shell.execute_reply": "2022-10-18T05:27:13.669272Z",
"shell.execute_reply.started": "2022-10-18T05:27:13.654539Z"
},
"id": "NGy0MxMQVIAP",
"outputId": "08b7f0eb-6a38-4b83-efdc-e53778d7547a",
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>response</th>\n",
" <th>context</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>8751</th>\n",
" <td>Gimme that cake!</td>\n",
" <td>Who the hell are you?!</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10518</th>\n",
" <td>Aww, screw you guys anyway!</td>\n",
" <td>Cartman, will you shut the hell up and get som...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7483</th>\n",
" <td>You wanna take it out for a spin?</td>\n",
" <td>It can fly like a quarter mile away from whoev...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16857</th>\n",
" <td>Well, I guess we can go back to playing laundr...</td>\n",
" <td>Son of a bitch stupid FBI!</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6644</th>\n",
" <td>Hey Token? How are you doing?</td>\n",
" <td>No! Nonono! Nooo!</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" response \\\n",
"8751 Gimme that cake! \n",
"10518 Aww, screw you guys anyway! \n",
"7483 You wanna take it out for a spin? \n",
"16857 Well, I guess we can go back to playing laundr... \n",
"6644 Hey Token? How are you doing? \n",
"\n",
" context \n",
"8751 Who the hell are you?! \n",
"10518 Cartman, will you shut the hell up and get som... \n",
"7483 It can fly like a quarter mile away from whoev... \n",
"16857 Son of a bitch stupid FBI! \n",
"6644 No! Nonono! Nooo! "
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Hold out 10% of the pairs for validation. The fixed random_state (same value as\n",
"# Args.seed) keeps train/validation membership identical across kernel restarts.\n",
"trn_df, val_df = train_test_split(df, test_size=0.1, random_state=42)\n",
"trn_df.head()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-18T05:27:13.671526Z",
"iopub.status.busy": "2022-10-18T05:27:13.671152Z",
"iopub.status.idle": "2022-10-18T05:27:13.682715Z",
"shell.execute_reply": "2022-10-18T05:27:13.681858Z",
"shell.execute_reply.started": "2022-10-18T05:27:13.671496Z"
},
"id": "aEeJQlAKWtiJ",
"tags": []
},
"outputs": [],
"source": [
"# create dataset suitable for our model\n",
"def construct_conv(row, tokenizer, eos=True):\n",
"    \"\"\"Flatten one row of utterances into a single list of token ids.\n",
"\n",
"    Each element of `row` is encoded with `tokenizer.encode` and terminated with\n",
"    `tokenizer.eos_token_id`. The encoded utterances are concatenated in reverse\n",
"    row order, so the first element of `row` ends up last in the result.\n",
"    The `eos` argument is accepted but not used.\n",
"    \"\"\"\n",
"    encoded_turns = []\n",
"    for utterance in row:\n",
"        encoded_turns.append(tokenizer.encode(utterance) + [tokenizer.eos_token_id])\n",
"\n",
"    token_ids = []\n",
"    for turn in reversed(encoded_turns):\n",
"        token_ids.extend(turn)\n",
"    return token_ids\n",
"\n",
"class ConversationDataset(Dataset):\n",
"    \"\"\"Tokenized conversations, one flat list of token ids per row of `df`.\n",
"\n",
"    Examples are built with `construct_conv` and pickled under `args.cache_dir`.\n",
"    The cache file name includes a fingerprint of `df`, so different frames\n",
"    (for instance the train and validation splits) never read each other's cache.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):\n",
"        \"\"\"Load the examples for `df` from cache, or tokenize `df` and write the cache.\n",
"\n",
"        Args:\n",
"            tokenizer: tokenizer handed to `construct_conv`.\n",
"            args: settings object; reads `cache_dir`, `model_type` and `overwrite_cache`.\n",
"            df: frame whose rows are conversations (one utterance per column).\n",
"            block_size: only used, after subtracting the tokenizer's special-token\n",
"                overhead, to name the cache file; examples are not truncated here.\n",
"        \"\"\"\n",
"        block_size = block_size - (tokenizer.model_max_length - tokenizer.max_len_single_sentence)\n",
"\n",
"        # Fingerprint the frame (values and index) so that two different frames never\n",
"        # map to the same cache file. Without this, a run with overwrite_cache=False\n",
"        # would hand the cached training examples back for the validation frame.\n",
"        row_hashes = pd.util.hash_pandas_object(df, index=True).to_numpy()\n",
"        fingerprint = hashlib.sha1(row_hashes.tobytes()).hexdigest()[:16]\n",
"\n",
"        directory = args.cache_dir\n",
"        if directory:\n",
"            # The cache is written below, so the directory has to exist.\n",
"            os.makedirs(directory, exist_ok=True)\n",
"        cached_features_file = os.path.join(\n",
"            directory, args.model_type + \"_cached_lm_\" + str(block_size) + \"_\" + fingerprint\n",
"        )\n",
"\n",
"        if os.path.exists(cached_features_file) and not args.overwrite_cache:\n",
"            logger.info(\"Loading features from cached file %s\", cached_features_file)\n",
"            # NOTE(security): pickle.load can execute arbitrary code from the file;\n",
"            # only load cache files that this notebook wrote itself.\n",
"            with open(cached_features_file, \"rb\") as handle:\n",
"                self.examples = pickle.load(handle)\n",
"        else:\n",
"            logger.info(\"Creating features from dataset file at %s\", directory)\n",
"\n",
"            self.examples = [construct_conv(row, tokenizer) for _, row in df.iterrows()]\n",
"\n",
"            logger.info(\"Saving features into cached file %s\", cached_features_file)\n",
"            with open(cached_features_file, \"wb\") as handle:\n",
"                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of tokenized conversations.\"\"\"\n",
"        return len(self.examples)\n",
"\n",
"    def __getitem__(self, item):\n",
"        \"\"\"Return conversation `item` as a 1-D `torch.long` tensor of token ids.\"\"\"\n",
"        return torch.tensor(self.examples[item], dtype=torch.long)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-18T05:27:13.684175Z",
"iopub.status.busy": "2022-10-18T05:27:13.683859Z",
"iopub.status.idle": "2022-10-18T05:27:13.693632Z",
"shell.execute_reply": "2022-10-18T05:27:13.692937Z",
"shell.execute_reply.started": "2022-10-18T05:27:13.684147Z"
},
"id": "-3iHwoKlWyrs",
"tags": []
},
"outputs": [],
"source": [
"# Caching and storing of data/checkpoints\n",
"\n",
"def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):\n",
"    \"\"\"Build the ConversationDataset for one split: `df_val` when `evaluate` is truthy, else `df_trn`.\"\"\"\n",
"    if evaluate:\n",
"        selected_df = df_val\n",
"    else:\n",
"        selected_df = df_trn\n",
"    return ConversationDataset(tokenizer, args, selected_df)\n",
"\n",
"\n",
"def set_seed(args):\n",
"    \"\"\"Seed Python's `random`, NumPy and PyTorch with `args.seed`.\n",
"\n",
"    All CUDA devices are seeded as well when `args.n_gpu` is positive.\n",
"    \"\"\"\n",
"    seed = args.seed\n",
"    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):\n",
"        seed_fn(seed)\n",
"    if args.n_gpu > 0:\n",
"        torch.cuda.manual_seed_all(seed)\n",
"\n",
"\n",
"def _sorted_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> List[str]:\n",
"    \"\"\"List the `<checkpoint_prefix>-*` paths under `args.output_dir`, oldest first.\n",
"\n",
"    With `use_mtime` the paths are ordered by modification time. Otherwise they are\n",
"    ordered by the integer that follows `<checkpoint_prefix>-`, and paths without\n",
"    such an integer are left out.\n",
"    \"\"\"\n",
"    pattern = os.path.join(args.output_dir, \"{}-*\".format(checkpoint_prefix))\n",
"    step_regex = \".*{}-([0-9]+)\".format(checkpoint_prefix)\n",
"\n",
"    keyed_paths = []\n",
"    for candidate in glob.glob(pattern):\n",
"        if use_mtime:\n",
"            sort_key = os.path.getmtime(candidate)\n",
"        else:\n",
"            found = re.match(step_regex, candidate)\n",
"            if not (found and found.groups()):\n",
"                continue\n",
"            sort_key = int(found.groups()[0])\n",
"        keyed_paths.append((sort_key, candidate))\n",
"\n",
"    keyed_paths.sort()\n",
"    return [candidate for _, candidate in keyed_paths]\n",
"\n",
"\n",
"def _rotate_checkpoints(args, checkpoint_prefix=\"checkpoint\", use_mtime=False) -> None:\n",
"    \"\"\"Delete the oldest checkpoints so that at most `args.save_total_limit` remain.\n",
"\n",
"    Nothing is deleted when the limit is unset (None or 0) or not positive.\n",
"    Ordering comes from `_sorted_checkpoints`, so the first entries are removed.\n",
"    \"\"\"\n",
"    limit = args.save_total_limit\n",
"    if not limit or limit <= 0:\n",
"        return\n",
"\n",
"    existing = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)\n",
"    surplus = len(existing) - limit\n",
"    if surplus <= 0:\n",
"        return\n",
"\n",
"    for stale_dir in existing[:surplus]:\n",
"        logger.info(\"Deleting older checkpoint [{}] due to args.save_total_limit\".format(stale_dir))\n",
"        shutil.rmtree(stale_dir)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EEDdTJTqUwZJ"
},
"source": [
"## Build Model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"execution": {
"iopub.execute_input": "2022-10-18T05:27:13.694533Z",
"iopub.status.busy": "2022-10-18T05:27:13.694353Z",
"iopub.status.idle": "2022-10-18T05:27:21.231192Z",
"shell.execute_reply": "2022-10-18T05:27:21.230658Z",
"shell.execute_reply.started": "2022-10-18T05:27:13.694517Z"
},
"id": "r2cE0fY5UHpz",
"outputId": "e4f382cd-57d9-49b7-9da4-4b44fe57df5b",
"tags": []
},
"outputs": [],
"source": [
"from transformers import AutoModelWithLMHead, AutoModelForCausalLM, AutoTokenizer\n",
"import torch\n",
"\n",
"# Load the pretrained DialoGPT-large tokenizer and causal-LM weights into notebook-level names.\n",
"# NOTE(review): Args and main() below use microsoft/DialoGPT-medium and main() builds its own\n",
"# tokenizer/model, so these -large objects look unused by training -- confirm they are needed.\n",
"tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-large\")\n",
"model = AutoModelForCausalLM.from_pretrained(\"microsoft/DialoGPT-large\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-18T05:27:21.232063Z",
"iopub.status.busy": "2022-10-18T05:27:21.231898Z",
"iopub.status.idle": "2022-10-18T05:27:21.259427Z",
"shell.execute_reply": "2022-10-18T05:27:21.258757Z",
"shell.execute_reply.started": "2022-10-18T05:27:21.232048Z"
},
"id": "ra2vsRp-UMXo",
"tags": []
},
"outputs": [],
"source": [
"\"\"\"\n",
"Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).\n",
"GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned\n",
"using a masked language modeling (MLM) loss.\n",
"\"\"\"\n",
"\n",
"# Configs\n",
"# Shared logger; the dataset, checkpoint, train and evaluate helpers all log through it.\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"# Keys of transformers' LM-head model mapping, and the `model_type` string of each key.\n",
"MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())\n",
"MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-18T05:27:21.260380Z",
"iopub.status.busy": "2022-10-18T05:27:21.260189Z",
"iopub.status.idle": "2022-10-18T05:27:21.265817Z",
"shell.execute_reply": "2022-10-18T05:27:21.265203Z",
"shell.execute_reply.started": "2022-10-18T05:27:21.260362Z"
},
"id": "2OnASqJjUNJa",
"tags": []
},
"outputs": [],
"source": [
"# Args to allow for easy conversion of python script to notebook\n",
"class Args():\n",
"    \"\"\"Plain attribute container holding the settings the training helpers read from `args`.\n",
"\n",
"    `main()`, `train()` and `evaluate()` attach further attributes at run time:\n",
"    `n_gpu`, `device`, `train_batch_size` and `eval_batch_size`.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        # Directory that receives checkpoints, the final model and eval_results.txt.\n",
"        self.output_dir = 'models/output-medium-oneline'\n",
"        # Becomes part of the dataset cache file name.\n",
"        self.model_type = 'gpt2'\n",
"        # Identifiers passed to from_pretrained() in main() for weights, config and tokenizer.\n",
"        self.model_name_or_path = 'microsoft/DialoGPT-medium'\n",
"        self.config_name = 'microsoft/DialoGPT-medium'\n",
"        self.tokenizer_name = 'microsoft/DialoGPT-medium'\n",
"        # Directory for from_pretrained() downloads and the pickled tokenized examples.\n",
"        self.cache_dir = 'cached'\n",
"        # NOTE(review): not forwarded to ConversationDataset, which uses its own default -- confirm.\n",
"        self.block_size = 512\n",
"        self.do_train = True\n",
"        self.do_eval = True\n",
"        # When True, train() also calls evaluate() every `logging_steps` optimizer steps.\n",
"        self.evaluate_during_training = False\n",
"        # Per-device batch sizes; train()/evaluate() multiply them by max(1, n_gpu).\n",
"        self.per_gpu_train_batch_size = 4\n",
"        self.per_gpu_eval_batch_size = 4\n",
"        # Number of batches accumulated before each optimizer step.\n",
"        self.gradient_accumulation_steps = 1\n",
"        # AdamW hyper-parameters and the gradient-clipping norm used in train().\n",
"        self.learning_rate = 5e-5\n",
"        self.weight_decay = 0.0\n",
"        self.adam_epsilon = 1e-8\n",
"        self.max_grad_norm = 1.0\n",
"        self.num_train_epochs = 4\n",
"        # A positive value caps the total optimizer steps and overrides num_train_epochs.\n",
"        self.max_steps = -1\n",
"        self.warmup_steps = 0\n",
"        # Log to TensorBoard / write a checkpoint every this many optimizer steps.\n",
"        self.logging_steps = 1000\n",
"        self.save_steps = 3500\n",
"        # None disables checkpoint rotation in _rotate_checkpoints().\n",
"        self.save_total_limit = None\n",
"        self.eval_all_checkpoints = False\n",
"        # NOTE(review): the visible part of main() always selects torch.device(\"cuda\") -- confirm this flag is honoured.\n",
"        self.no_cuda = False\n",
"        self.overwrite_output_dir = True\n",
"        # True forces ConversationDataset to re-tokenize instead of loading its pickle cache.\n",
"        self.overwrite_cache = True\n",
"        # True makes main() resume from the newest checkpoint found in output_dir.\n",
"        self.should_continue = False\n",
"        self.seed = 42\n",
"        # -1 selects the non-distributed code paths in train().\n",
"        self.local_rank = -1\n",
"        # fp16 enables apex mixed precision in train(); fp16_opt_level is passed to amp.initialize().\n",
"        self.fp16 = False\n",
"        self.fp16_opt_level = 'O1'\n",
"\n",
"# Notebook-level instance; main() creates its own Args() rather than reading this one.\n",
"args = Args()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Q1dTFXxW9NE"
},
"source": [
"## Train and Evaluate"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-18T05:27:21.266896Z",
"iopub.status.busy": "2022-10-18T05:27:21.266734Z",
"iopub.status.idle": "2022-10-18T05:27:21.302284Z",
"shell.execute_reply": "2022-10-18T05:27:21.301566Z",
"shell.execute_reply.started": "2022-10-18T05:27:21.266881Z"
},
"id": "PaarIDZrW81h",
"tags": []
},
"outputs": [],
"source": [
"def train(\n",
"    args,\n",
"    train_dataset,\n",
"    model: PreTrainedModel,\n",
"    tokenizer: PreTrainedTokenizer,\n",
"    df_trn=None,\n",
"    df_val=None,\n",
") -> Tuple[int, float]:\n",
"    \"\"\"Fine-tune `model` on `train_dataset`, logging and checkpointing along the way.\n",
"\n",
"    Args:\n",
"        args: settings object (see `Args`); `train_batch_size` is written onto it here.\n",
"        train_dataset: dataset whose items `pad_sequence` can batch (main() passes a\n",
"            ConversationDataset).\n",
"        model: model to fine-tune; called as `model(inputs, labels=labels)`.\n",
"        tokenizer: supplies the padding value for batching and is saved with each checkpoint.\n",
"        df_trn, df_val: frames forwarded to `evaluate()`; only needed when\n",
"            `args.evaluate_during_training` is set.\n",
"\n",
"    Returns:\n",
"        `(global_step, average_loss)`, where `average_loss` is the accumulated training\n",
"        loss divided by `global_step` (0.0 when no optimizer step was taken).\n",
"\n",
"    Raises:\n",
"        ValueError: if evaluation during training is requested without `df_val`.\n",
"        ImportError: if `args.fp16` is set and apex is not installed.\n",
"    \"\"\"\n",
"    # evaluate() needs the validation frame; fail before training starts rather than\n",
"    # at the first logging step.\n",
"    if args.local_rank == -1 and args.evaluate_during_training and df_val is None:\n",
"        raise ValueError(\"evaluate_during_training requires df_val to be passed to train().\")\n",
"\n",
"    if args.local_rank in [-1, 0]:\n",
"        tb_writer = SummaryWriter()\n",
"\n",
"    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)\n",
"\n",
"    def collate(examples: List[torch.Tensor]):\n",
"        # Right-pad every example to the longest one in the batch.\n",
"        if tokenizer._pad_token is None:\n",
"            return pad_sequence(examples, batch_first=True)\n",
"        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
"\n",
"    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)\n",
"    train_dataloader = DataLoader(\n",
"        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate, drop_last = True\n",
"    )\n",
"\n",
"    if args.max_steps > 0:\n",
"        t_total = args.max_steps\n",
"        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1\n",
"    else:\n",
"        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs\n",
"\n",
"    model = model.module if hasattr(model, \"module\") else model  # Take care of distributed/parallel training\n",
"    model.resize_token_embeddings(len(tokenizer))\n",
"    # add_special_tokens_(model, tokenizer)\n",
"\n",
"\n",
"    # Prepare optimizer and schedule (linear warmup and decay)\n",
"    no_decay = [\"bias\", \"LayerNorm.weight\"]\n",
"    optimizer_grouped_parameters = [\n",
"        {\n",
"            \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n",
"            \"weight_decay\": args.weight_decay,\n",
"        },\n",
"        {\"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], \"weight_decay\": 0.0},\n",
"    ]\n",
"    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)\n",
"    scheduler = get_linear_schedule_with_warmup(\n",
"        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total\n",
"    )\n",
"\n",
"    # Check if saved optimizer or scheduler states exist\n",
"    if (\n",
"        args.model_name_or_path\n",
"        and os.path.isfile(os.path.join(args.model_name_or_path, \"optimizer.pt\"))\n",
"        and os.path.isfile(os.path.join(args.model_name_or_path, \"scheduler.pt\"))\n",
"    ):\n",
"        # Load in optimizer and scheduler states\n",
"        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"optimizer.pt\")))\n",
"        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, \"scheduler.pt\")))\n",
"\n",
"    if args.fp16:\n",
"        try:\n",
"            from apex import amp\n",
"        except ImportError:\n",
"            raise ImportError(\"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.\")\n",
"        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)\n",
"\n",
"    # multi-gpu training (should be after apex fp16 initialization)\n",
"    if args.n_gpu > 1:\n",
"        model = torch.nn.DataParallel(model)\n",
"\n",
"    # Distributed training (should be after apex fp16 initialization)\n",
"    if args.local_rank != -1:\n",
"        model = torch.nn.parallel.DistributedDataParallel(\n",
"            model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True\n",
"        )\n",
"\n",
"    # Train!\n",
"    logger.info(\"***** Running training *****\")\n",
"    logger.info(\" Num examples = %d\", len(train_dataset))\n",
"    logger.info(\" Num Epochs = %d\", args.num_train_epochs)\n",
"    logger.info(\" Instantaneous batch size per GPU = %d\", args.per_gpu_train_batch_size)\n",
"    logger.info(\n",
"        \" Total train batch size (w. parallel, distributed & accumulation) = %d\",\n",
"        args.train_batch_size\n",
"        * args.gradient_accumulation_steps\n",
"        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),\n",
"    )\n",
"    logger.info(\" Gradient Accumulation steps = %d\", args.gradient_accumulation_steps)\n",
"    logger.info(\" Total optimization steps = %d\", t_total)\n",
"\n",
"    global_step = 0\n",
"    epochs_trained = 0\n",
"    steps_trained_in_current_epoch = 0\n",
"    # Check if continuing training from a checkpoint\n",
"    if args.model_name_or_path and os.path.exists(args.model_name_or_path):\n",
"        try:\n",
"            # set global_step to global_step of last saved checkpoint from model path\n",
"            checkpoint_suffix = args.model_name_or_path.split(\"-\")[-1].split(\"/\")[0]\n",
"            global_step = int(checkpoint_suffix)\n",
"            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)\n",
"            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)\n",
"\n",
"            logger.info(\" Continuing training from checkpoint, will skip to saved global_step\")\n",
"            logger.info(\" Continuing training from epoch %d\", epochs_trained)\n",
"            logger.info(\" Continuing training from global step %d\", global_step)\n",
"            logger.info(\" Will skip the first %d steps in the first epoch\", steps_trained_in_current_epoch)\n",
"        except ValueError:\n",
"            # The path does not end in \"-<step>\", so this is a fresh fine-tuning run.\n",
"            logger.info(\" Starting fine-tuning.\")\n",
"\n",
"    tr_loss, logging_loss = 0.0, 0.0\n",
"\n",
"    model.zero_grad()\n",
"    train_iterator = trange(\n",
"        epochs_trained, int(args.num_train_epochs), desc=\"Epoch\", disable=args.local_rank not in [-1, 0]\n",
"    )\n",
"    set_seed(args)  # Added here for reproducibility\n",
"    for _ in train_iterator:\n",
"        epoch_iterator = tqdm(train_dataloader, desc=\"Iteration\", disable=args.local_rank not in [-1, 0])\n",
"        for step, batch in enumerate(epoch_iterator):\n",
"\n",
"            # Skip past any already trained steps if resuming training\n",
"            if steps_trained_in_current_epoch > 0:\n",
"                steps_trained_in_current_epoch -= 1\n",
"                continue\n",
"\n",
"            # Causal LM objective: the batch is both input and target.\n",
"            # NOTE(review): padded positions are not masked out of `labels` -- confirm this is intended.\n",
"            inputs, labels = (batch, batch)\n",
"            # Skip batches padded beyond 1024 tokens (presumably the model's context limit -- confirm).\n",
"            if inputs.shape[1] > 1024:\n",
"                continue\n",
"            inputs = inputs.to(args.device)\n",
"            labels = labels.to(args.device)\n",
"            model.train()\n",
"            outputs = model(inputs, labels=labels)\n",
"            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)\n",
"\n",
"            if args.n_gpu > 1:\n",
"                loss = loss.mean()  # mean() to average on multi-gpu parallel training\n",
"            if args.gradient_accumulation_steps > 1:\n",
"                loss = loss / args.gradient_accumulation_steps\n",
"\n",
"            if args.fp16:\n",
"                with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
"                    scaled_loss.backward()\n",
"            else:\n",
"                loss.backward()\n",
"\n",
"            tr_loss += loss.item()\n",
"            if (step + 1) % args.gradient_accumulation_steps == 0:\n",
"                if args.fp16:\n",
"                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)\n",
"                else:\n",
"                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)\n",
"                optimizer.step()\n",
"                scheduler.step()  # Update learning rate schedule\n",
"                model.zero_grad()\n",
"                global_step += 1\n",
"\n",
"                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:\n",
"                    # Log metrics\n",
"                    if (\n",
"                        args.local_rank == -1 and args.evaluate_during_training\n",
"                    ):  # Only evaluate when single GPU otherwise metrics may not average well\n",
"                        results = evaluate(args, model, tokenizer, df_trn, df_val)\n",
"                        for key, value in results.items():\n",
"                            tb_writer.add_scalar(\"eval_{}\".format(key), value, global_step)\n",
"                    # get_last_lr() reports the rate set by the most recent scheduler.step().\n",
"                    tb_writer.add_scalar(\"lr\", scheduler.get_last_lr()[0], global_step)\n",
"                    tb_writer.add_scalar(\"loss\", (tr_loss - logging_loss) / args.logging_steps, global_step)\n",
"                    logging_loss = tr_loss\n",
"\n",
"                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:\n",
"                    checkpoint_prefix = \"checkpoint\"\n",
"                    # Save model checkpoint\n",
"                    output_dir = os.path.join(args.output_dir, \"{}-{}\".format(checkpoint_prefix, global_step))\n",
"                    os.makedirs(output_dir, exist_ok=True)\n",
"                    model_to_save = (\n",
"                        model.module if hasattr(model, \"module\") else model\n",
"                    )  # Take care of distributed/parallel training\n",
"                    model_to_save.save_pretrained(output_dir)\n",
"                    tokenizer.save_pretrained(output_dir)\n",
"\n",
"                    torch.save(args, os.path.join(output_dir, \"training_args.bin\"))\n",
"                    logger.info(\"Saving model checkpoint to %s\", output_dir)\n",
"\n",
"                    _rotate_checkpoints(args, checkpoint_prefix)\n",
"\n",
"                    torch.save(optimizer.state_dict(), os.path.join(output_dir, \"optimizer.pt\"))\n",
"                    torch.save(scheduler.state_dict(), os.path.join(output_dir, \"scheduler.pt\"))\n",
"                    logger.info(\"Saving optimizer and scheduler states to %s\", output_dir)\n",
"\n",
"            if args.max_steps > 0 and global_step > args.max_steps:\n",
"                epoch_iterator.close()\n",
"                break\n",
"        if args.max_steps > 0 and global_step > args.max_steps:\n",
"            train_iterator.close()\n",
"            break\n",
"\n",
"    if args.local_rank in [-1, 0]:\n",
"        tb_writer.close()\n",
"\n",
"    # Guard against a run in which no optimizer step happened (e.g. an empty dataloader).\n",
"    average_loss = tr_loss / global_step if global_step > 0 else 0.0\n",
"    return global_step, average_loss\n",
"\n",
"# Evaluation of some model\n",
"\n",
"def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn, df_val, prefix=\"\") -> Dict:\n",
"    \"\"\"Compute the perplexity of `model` on the validation frame and write it to disk.\n",
"\n",
"    Args:\n",
"        args: settings object (see `Args`); `eval_batch_size` is written onto it here.\n",
"        model: model to score; called as `model(inputs, labels=labels)`.\n",
"        tokenizer: used to build the dataset and to pick the padding value.\n",
"        df_trn, df_val: frames handed to `load_and_cache_examples(..., evaluate=True)`.\n",
"        prefix: optional sub-directory of `args.output_dir` for `eval_results.txt`.\n",
"\n",
"    Returns:\n",
"        `{\"perplexity\": tensor}`, the exponential of the mean per-batch loss.\n",
"\n",
"    Raises:\n",
"        ValueError: if the dataloader yields no batch (with `drop_last=True` this\n",
"            happens when there are fewer examples than `eval_batch_size`).\n",
"    \"\"\"\n",
"    eval_output_dir = args.output_dir\n",
"\n",
"    eval_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=True)\n",
"    os.makedirs(eval_output_dir, exist_ok=True)\n",
"    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)\n",
"    # Note that DistributedSampler samples randomly\n",
"\n",
"    def collate(examples: List[torch.Tensor]):\n",
"        # Right-pad every example to the longest one in the batch.\n",
"        if tokenizer._pad_token is None:\n",
"            return pad_sequence(examples, batch_first=True)\n",
"        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)\n",
"\n",
"    eval_sampler = SequentialSampler(eval_dataset)\n",
"    eval_dataloader = DataLoader(\n",
"        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate, drop_last = True\n",
"    )\n",
"\n",
"    # multi-gpu evaluate; train() may pass a model that is already wrapped\n",
"    if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):\n",
"        model = torch.nn.DataParallel(model)\n",
"\n",
"    # Eval!\n",
"    logger.info(\"***** Running evaluation {} *****\".format(prefix))\n",
"    logger.info(\" Num examples = %d\", len(eval_dataset))\n",
"    logger.info(\" Batch size = %d\", args.eval_batch_size)\n",
"    eval_loss = 0.0\n",
"    nb_eval_steps = 0\n",
"    model.eval()\n",
"\n",
"    for batch in tqdm(eval_dataloader, desc=\"Evaluating\"):\n",
"        inputs, labels = (batch, batch)\n",
"        inputs = inputs.to(args.device)\n",
"        labels = labels.to(args.device)\n",
"\n",
"        with torch.no_grad():\n",
"            outputs = model(inputs, labels=labels)\n",
"            lm_loss = outputs[0]\n",
"            eval_loss += lm_loss.mean().item()\n",
"        nb_eval_steps += 1\n",
"\n",
"    if nb_eval_steps == 0:\n",
"        raise ValueError(\n",
"            \"No evaluation batches were produced; the validation set is smaller than eval_batch_size.\"\n",
"        )\n",
"\n",
"    eval_loss = eval_loss / nb_eval_steps\n",
"    perplexity = torch.exp(torch.tensor(eval_loss))\n",
"\n",
"    result = {\"perplexity\": perplexity}\n",
"\n",
"    output_eval_file = os.path.join(eval_output_dir, prefix, \"eval_results.txt\")\n",
"    # A non-empty prefix adds a sub-directory that may not exist yet.\n",
"    os.makedirs(os.path.dirname(output_eval_file), exist_ok=True)\n",
"    with open(output_eval_file, \"w\") as writer:\n",
"        logger.info(\"***** Eval results {} *****\".format(prefix))\n",
"        for key in sorted(result.keys()):\n",
"            logger.info(\" %s = %s\", key, str(result[key]))\n",
"            writer.write(\"%s = %s\\n\" % (key, str(result[key])))\n",
"\n",
"    return result"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"execution": {
"iopub.execute_input": "2022-10-18T05:27:21.303584Z",
"iopub.status.busy": "2022-10-18T05:27:21.303400Z",
"iopub.status.idle": "2022-10-18T05:27:21.314912Z",
"shell.execute_reply": "2022-10-18T05:27:21.314269Z",
"shell.execute_reply.started": "2022-10-18T05:27:21.303568Z"
},
"id": "SCnGAJWbXD9C",
"tags": []
},
"outputs": [],
"source": [
"# Main runner\n",
"\n",
"def main(df_trn, df_val):\n",
"    \"\"\"Fine-tune and/or evaluate the model using the settings of a fresh ``Args()``.\n",
"\n",
"    Args:\n",
"        df_trn: training data, passed through to ``load_and_cache_examples`` and ``evaluate``.\n",
"        df_val: validation data, passed through to the same two helpers.\n",
"\n",
"    Returns:\n",
"        dict of evaluation results keyed ``<metric>_<global_step>`` (the step part is\n",
"        empty when a single checkpoint is evaluated); empty if evaluation is skipped.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``should_continue`` is set but no checkpoint is found, or if the\n",
"            output directory is non-empty while training without ``overwrite_output_dir``\n",
"            or ``should_continue``.\n",
"    \"\"\"\n",
"    args = Args()\n",
"\n",
"    if args.should_continue:\n",
"        sorted_checkpoints = _sorted_checkpoints(args)\n",
"        if len(sorted_checkpoints) == 0:\n",
"            raise ValueError(\"Used --should_continue but no checkpoint was found in --output_dir.\")\n",
"        # Resume from the last entry of the sorted checkpoint list\n",
"        args.model_name_or_path = sorted_checkpoints[-1]\n",
"\n",
"    if (\n",
"        os.path.exists(args.output_dir)\n",
"        and os.listdir(args.output_dir)\n",
"        and args.do_train\n",
"        and not args.overwrite_output_dir\n",
"        and not args.should_continue\n",
"    ):\n",
"        raise ValueError(\n",
"            \"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.\".format(\n",
"                args.output_dir\n",
"            )\n",
"        )\n",
"\n",
"    # Setup device: prefer CUDA, but fall back to CPU so that model.to() below\n",
"    # does not fail on machines without a usable GPU\n",
"    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"    args.n_gpu = torch.cuda.device_count()\n",
"    args.device = device\n",
"\n",
"    # Setup logging\n",
"    logging.basicConfig(\n",
"        format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n",
"        datefmt=\"%m/%d/%Y %H:%M:%S\",\n",
"        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,\n",
"    )\n",
"    logger.warning(\n",
"        \"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s\",\n",
"        args.local_rank,\n",
"        device,\n",
"        args.n_gpu,\n",
"        bool(args.local_rank != -1),\n",
"        args.fp16,\n",
"    )\n",
"\n",
"    # Set seed\n",
"    set_seed(args)\n",
"\n",
"    config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)\n",
"    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)\n",
"    model = AutoModelForCausalLM.from_pretrained(\n",
"        args.model_name_or_path,\n",
"        from_tf=False,\n",
"        config=config,\n",
"        cache_dir=args.cache_dir,\n",
"    )\n",
"    model.to(args.device)\n",
"\n",
"    logger.info(\"Training/evaluation parameters %s\", args)\n",
"\n",
"    # Training, followed by saving the result so it can be reloaded with from_pretrained()\n",
"    if args.do_train:\n",
"        train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)\n",
"\n",
"        global_step, tr_loss = train(args, train_dataset, model, tokenizer)\n",
"        logger.info(\" global_step = %s, average loss = %s\", global_step, tr_loss)\n",
"\n",
"        # Create output directory if needed\n",
"        os.makedirs(args.output_dir, exist_ok=True)\n",
"\n",
"        logger.info(\"Saving model checkpoint to %s\", args.output_dir)\n",
"        # Save a trained model, configuration and tokenizer using `save_pretrained()`.\n",
"        # They can then be reloaded using `from_pretrained()`\n",
"        model_to_save = (\n",
"            model.module if hasattr(model, \"module\") else model\n",
"        )  # Take care of distributed/parallel training\n",
"        model_to_save.save_pretrained(args.output_dir)\n",
"        tokenizer.save_pretrained(args.output_dir)\n",
"\n",
"        # Good practice: save your training arguments together with the trained model\n",
"        torch.save(args, os.path.join(args.output_dir, \"training_args.bin\"))\n",
"\n",
"        # Load a trained model and vocabulary that you have fine-tuned\n",
"        model = AutoModelForCausalLM.from_pretrained(args.output_dir)\n",
"        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)\n",
"        model.to(args.device)\n",
"\n",
"    # Evaluation\n",
"    results = {}\n",
"    if args.do_eval and args.local_rank in [-1, 0]:\n",
"        checkpoints = [args.output_dir]\n",
"        if args.eval_all_checkpoints:\n",
"            # Every directory below output_dir that contains a weights file\n",
"            checkpoints = list(\n",
"                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + \"/**/\" + WEIGHTS_NAME, recursive=True))\n",
"            )\n",
"            logging.getLogger(\"transformers.modeling_utils\").setLevel(logging.WARN)  # Reduce logging\n",
"        logger.info(\"Evaluate the following checkpoints: %s\", checkpoints)\n",
"        for checkpoint in checkpoints:\n",
"            # The step suffix is only used when several checkpoints are evaluated\n",
"            global_step = checkpoint.split(\"-\")[-1] if len(checkpoints) > 1 else \"\"\n",
"            prefix = checkpoint.split(\"/\")[-1] if checkpoint.find(\"checkpoint\") != -1 else \"\"\n",
"\n",
"            model = AutoModelForCausalLM.from_pretrained(checkpoint)\n",
"            model.to(args.device)\n",
"            result = evaluate(args, model, tokenizer, df_trn, df_val, prefix=prefix)\n",
"            result = dict((k + \"_{}\".format(global_step), v) for k, v in result.items())\n",
"            results.update(result)\n",
"\n",
"    return results"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7NWvkdR-XHeB"
},
"source": [
"## Train it"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 780,
"referenced_widgets": [
"1d7f4c82687540f1ad69eb54ac3c25b4",
"e7b9f3fc77a24259a87ef0dc735dfecb",
"f3bf54733c2d4d9daa1cc9a7746ccb14",
"aa40eb6346b54e7dac98e0b068cd4927",
"021b771a270f479aa3b9e2b5f17e3d97",
"450b0e7fd7a347c7beb78b7d72f64385",
"9391d7abf6ed4400903995f56d7a1260",
"ea6b919964d24c2f9de1c64c9cefaf23",
"2fa1fa2407384cb98d79a912de2d5b8f",
"dc27e2caf1ea4a4ab9ae3708fb06952f",
"e38fb98fd7b3413392dc39c93a107a35",
"855ca0a6125a4d698416214a9425ad98",
"4699416338ae40a5b6abf19e45089aec",
"43fdb31d3f314624ba07a15718b0c8f3",
"de252cd193114c40ad5f5e9622b7abc7",
"5e48b617cc3f41c3945efc28fc5e0c75",
"68a9dc52819c48fb97259f318f9b5c6a",
"b4e00059cf3a49929978ed780aae8358",
"0ff5f4e3506b493a98d72008a467f35f",
"77b97fa3271b48ac9f93665a102b4fd1",
"a937f1dfeee5432ba31b3016fd30e9e2",
"3c6d446f491c48fcae03e0034bfaaae9",
"a193bb3a0b5b4cbba587e2460075a445",
"75f8aebc30304fe198b5a2898a53a92d",
"8b8a7c771d234f6c9d758a1f07f75a90",
"c6518c4a721745bf97ee682f2ebe4635",
"29cffa2b4f234e12802344eb53838641",
"96243b7b227f465f83a289481680b925",
"8c016a54f0a24fcdacf369baa9d24f1e",
"7fe5b457ca0f417f90a20d235e9cec07",
"fdffb26b99c24c978580f1cf97359fea",
"8e3f1740c82f47949eefc2eb53052eae",
"9cccd43f6acc4e25b4876fd0ae7a2ad6",
"175e94deab7f4d20b99b419bea33583b",
"41f26f7210e540479814e5d68de13ddb",
"cf5cd281fa3b453093e210650bf81e9e",
"e1fbe239c2394cbf973ac5b95e1e1491",
"810ac22adad344b7bf8b556ded990122",
"8b3a41c1900b45ebb9c56601deca0e84",
"002f56aac3d64b33a0e799c0baf1e6b9",
"a0f2a9a279734aa5bf146f0a5b33c43b",
"850b5411122e4d608511fe26818bea68",
"0663fb4bd85f4d87a7d61910b995be14",
"cb7f52610fcf49bda46a14b296ff5bb5",
"0ca29b4a62e04d9c937189ea19b25de8",
"f871b83632974e0088bae65e78efaf28",
"4cacf7fc20754a7ca7fe08c8ec187a81",
"8bcc625c0f284398bbd287fe45021b17"
]
},
"id": "e61zo2JtXGNX",
"outputId": "22d4916e-7169-44b5-f9d8-79b9c43fab2e",
"tags": []
},
"outputs": [],
"source": [
"main(trn_df, val_df)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YRpQ_n2zXQj-"
},
"source": [
"## Test it"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"execution": {
"iopub.status.busy": "2022-10-18T05:27:32.954246Z",
"iopub.status.idle": "2022-10-18T05:27:32.954450Z",
"shell.execute_reply": "2022-10-18T05:27:32.954353Z",
"shell.execute_reply.started": "2022-10-18T05:27:32.954342Z"
},
"id": "HGw3qgfaXQHX",
"outputId": "93e84cfd-9718-42e5-bd11-418112c91d71",
"tags": []
},
"outputs": [],
"source": [
"# Sources for the chat demo: tokenizer identifier and model path\n",
"CHAT_TOKENIZER_NAME = 'microsoft/DialoGPT-large'\n",
"CHAT_MODEL_DIR = 'models/output-medium-oneline'\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(CHAT_TOKENIZER_NAME)\n",
"model = AutoModelForCausalLM.from_pretrained(CHAT_MODEL_DIR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"execution": {
"iopub.status.busy": "2022-10-18T05:27:32.955300Z",
"iopub.status.idle": "2022-10-18T05:27:32.955497Z",
"shell.execute_reply": "2022-10-18T05:27:32.955404Z",
"shell.execute_reply.started": "2022-10-18T05:27:32.955394Z"
},
"id": "lAWsiAvNXbxd",
"outputId": "0fd2541e-ee68-4976-b098-8483efe38d5e",
"tags": []
},
"outputs": [],
"source": [
"# Chat with the model for 4 turns\n",
"for step in range(4):\n",
"    # read one user line, append the eos_token and encode it as a PyTorch tensor\n",
"    new_user_input_ids = tokenizer.encode(\n",
"        input(\">> User:\") + tokenizer.eos_token,\n",
"        return_tensors='pt',\n",
"    )\n",
"\n",
"    # from the second turn on, the new input is appended to the running chat history\n",
"    if step > 0:\n",
"        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)\n",
"    else:\n",
"        bot_input_ids = new_user_input_ids\n",
"\n",
"    # generate a response by sampling; max_length=200 bounds the returned sequence\n",
"    generation_options = dict(\n",
"        max_length=200,\n",
"        pad_token_id=tokenizer.eos_token_id,\n",
"        no_repeat_ngram_size=3,\n",
"        do_sample=True,\n",
"        top_k=100,\n",
"        top_p=0.7,\n",
"        temperature=0.8,\n",
"    )\n",
"    chat_history_ids = model.generate(bot_input_ids, **generation_options)\n",
"\n",
"    # pretty print only the tokens that follow the prompt in the generated sequence\n",
"    prompt_length = bot_input_ids.shape[-1]\n",
"    reply_ids = chat_history_ids[:, prompt_length:][0]\n",
"    print(\"Cartman: {}\".format(tokenizer.decode(reply_ids, skip_special_tokens=True)))"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "model_train_upload_workflow.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"002f56aac3d64b33a0e799c0baf1e6b9": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"021b771a270f479aa3b9e2b5f17e3d97": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"0663fb4bd85f4d87a7d61910b995be14": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Evaluating: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_f871b83632974e0088bae65e78efaf28",
"max": 21,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_0ca29b4a62e04d9c937189ea19b25de8",
"value": 21
}
},
"0ca29b4a62e04d9c937189ea19b25de8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"0ff5f4e3506b493a98d72008a467f35f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Iteration: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_3c6d446f491c48fcae03e0034bfaaae9",
"max": 195,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_a937f1dfeee5432ba31b3016fd30e9e2",
"value": 195
}
},
"175e94deab7f4d20b99b419bea33583b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1d7f4c82687540f1ad69eb54ac3c25b4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_f3bf54733c2d4d9daa1cc9a7746ccb14",
"IPY_MODEL_aa40eb6346b54e7dac98e0b068cd4927"
],
"layout": "IPY_MODEL_e7b9f3fc77a24259a87ef0dc735dfecb"
}
},
"29cffa2b4f234e12802344eb53838641": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Iteration: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_7fe5b457ca0f417f90a20d235e9cec07",
"max": 195,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_8c016a54f0a24fcdacf369baa9d24f1e",
"value": 195
}
},
"2fa1fa2407384cb98d79a912de2d5b8f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_e38fb98fd7b3413392dc39c93a107a35",
"IPY_MODEL_855ca0a6125a4d698416214a9425ad98"
],
"layout": "IPY_MODEL_dc27e2caf1ea4a4ab9ae3708fb06952f"
}
},
"3c6d446f491c48fcae03e0034bfaaae9": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"41f26f7210e540479814e5d68de13ddb": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Iteration: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_810ac22adad344b7bf8b556ded990122",
"max": 195,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_e1fbe239c2394cbf973ac5b95e1e1491",
"value": 195
}
},
"43fdb31d3f314624ba07a15718b0c8f3": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"450b0e7fd7a347c7beb78b7d72f64385": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4699416338ae40a5b6abf19e45089aec": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"4cacf7fc20754a7ca7fe08c8ec187a81": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"5e48b617cc3f41c3945efc28fc5e0c75": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"68a9dc52819c48fb97259f318f9b5c6a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_0ff5f4e3506b493a98d72008a467f35f",
"IPY_MODEL_77b97fa3271b48ac9f93665a102b4fd1"
],
"layout": "IPY_MODEL_b4e00059cf3a49929978ed780aae8358"
}
},
"75f8aebc30304fe198b5a2898a53a92d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"77b97fa3271b48ac9f93665a102b4fd1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_75f8aebc30304fe198b5a2898a53a92d",
"placeholder": "",
"style": "IPY_MODEL_a193bb3a0b5b4cbba587e2460075a445",
"value": " 195/195 [00:35&lt;00:00, 5.45it/s]"
}
},
"7fe5b457ca0f417f90a20d235e9cec07": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"810ac22adad344b7bf8b556ded990122": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"850b5411122e4d608511fe26818bea68": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"855ca0a6125a4d698416214a9425ad98": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_5e48b617cc3f41c3945efc28fc5e0c75",
"placeholder": "",
"style": "IPY_MODEL_de252cd193114c40ad5f5e9622b7abc7",
"value": " 195/195 [00:44&lt;00:00, 4.39it/s]"
}
},
"8b3a41c1900b45ebb9c56601deca0e84": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"8b8a7c771d234f6c9d758a1f07f75a90": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_29cffa2b4f234e12802344eb53838641",
"IPY_MODEL_96243b7b227f465f83a289481680b925"
],
"layout": "IPY_MODEL_c6518c4a721745bf97ee682f2ebe4635"
}
},
"8bcc625c0f284398bbd287fe45021b17": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"8c016a54f0a24fcdacf369baa9d24f1e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"8e3f1740c82f47949eefc2eb53052eae": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9391d7abf6ed4400903995f56d7a1260": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"96243b7b227f465f83a289481680b925": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8e3f1740c82f47949eefc2eb53052eae",
"placeholder": "",
"style": "IPY_MODEL_fdffb26b99c24c978580f1cf97359fea",
"value": " 195/195 [01:17&lt;00:00, 2.53it/s]"
}
},
"9cccd43f6acc4e25b4876fd0ae7a2ad6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_41f26f7210e540479814e5d68de13ddb",
"IPY_MODEL_cf5cd281fa3b453093e210650bf81e9e"
],
"layout": "IPY_MODEL_175e94deab7f4d20b99b419bea33583b"
}
},
"a0f2a9a279734aa5bf146f0a5b33c43b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_0663fb4bd85f4d87a7d61910b995be14",
"IPY_MODEL_cb7f52610fcf49bda46a14b296ff5bb5"
],
"layout": "IPY_MODEL_850b5411122e4d608511fe26818bea68"
}
},
"a193bb3a0b5b4cbba587e2460075a445": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"a937f1dfeee5432ba31b3016fd30e9e2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"aa40eb6346b54e7dac98e0b068cd4927": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ea6b919964d24c2f9de1c64c9cefaf23",
"placeholder": "",
"style": "IPY_MODEL_9391d7abf6ed4400903995f56d7a1260",
"value": " 4/4 [02:23&lt;00:00, 36.00s/it]"
}
},
"b4e00059cf3a49929978ed780aae8358": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c6518c4a721745bf97ee682f2ebe4635": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"cb7f52610fcf49bda46a14b296ff5bb5": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8bcc625c0f284398bbd287fe45021b17",
"placeholder": "",
"style": "IPY_MODEL_4cacf7fc20754a7ca7fe08c8ec187a81",
"value": " 21/21 [00:01&lt;00:00, 10.78it/s]"
}
},
"cf5cd281fa3b453093e210650bf81e9e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_002f56aac3d64b33a0e799c0baf1e6b9",
"placeholder": "",
"style": "IPY_MODEL_8b3a41c1900b45ebb9c56601deca0e84",
"value": " 195/195 [00:40&lt;00:00, 4.84it/s]"
}
},
"dc27e2caf1ea4a4ab9ae3708fb06952f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"de252cd193114c40ad5f5e9622b7abc7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"e1fbe239c2394cbf973ac5b95e1e1491": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"e38fb98fd7b3413392dc39c93a107a35": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Iteration: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_43fdb31d3f314624ba07a15718b0c8f3",
"max": 195,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_4699416338ae40a5b6abf19e45089aec",
"value": 195
}
},
"e7b9f3fc77a24259a87ef0dc735dfecb": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"ea6b919964d24c2f9de1c64c9cefaf23": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f3bf54733c2d4d9daa1cc9a7746ccb14": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "Epoch: 100%",
"description_tooltip": null,
"layout": "IPY_MODEL_450b0e7fd7a347c7beb78b7d72f64385",
"max": 4,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_021b771a270f479aa3b9e2b5f17e3d97",
"value": 4
}
},
"f871b83632974e0088bae65e78efaf28": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"fdffb26b99c24c978580f1cf97359fea": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}