commit aac0893f25fc5a497f7572e70694d65835e3e732 Author: Adam <24621027+WhiteDopeOnPunk@users.noreply.github.com> Date: Thu Oct 13 21:53:26 2022 -0400 ye diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..bee8a64 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ diff --git a/main.py b/main.py new file mode 100644 index 0000000..d017f52 --- /dev/null +++ b/main.py @@ -0,0 +1,29 @@ +from fastapi import FastAPI +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers.models.auto.modeling_auto import AutoModelForCausalLM +import torch + +api = FastAPI() + +tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-large') +cartman = AutoModelForCausalLM.from_pretrained('../southpark/output-medium') + +def cartman_speak(user_message): + new_user_input_ids = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt') + bot_output = new_user_input_ids + bot_input_ids = torch.cat([new_user_input_ids, bot_output]) + bot_output = cartman.generate( + bot_input_ids, max_length= 200, + pad_token_id=tokenizer.eos_token_id, + no_repeat_ngram_size=3, + do_sample=True, + top_k=100, + top_p=0.7, + temperature=.8 + ) + return '{}'.format(tokenizer.decode(bot_output[:,bot_input_ids.shape[-1]:][0], skip_special_tokens=True)) + +@api.get("/cartman/{user_message}") +def read_item(user_message: str): + return {"Cartman": cartman_speak(user_message)} +