This commit is contained in:
Adam 2022-10-13 21:53:26 -04:00
commit aac0893f25
2 changed files with 30 additions and 0 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
__pycache__

29
main.py Normal file
View file

@ -0,0 +1,29 @@
from fastapi import FastAPI
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
import torch
api = FastAPI()
tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-large')
cartman = AutoModelForCausalLM.from_pretrained('../southpark/output-medium')
def cartman_speak(user_message):
new_user_input_ids = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt')
bot_output = new_user_input_ids
bot_input_ids = torch.cat([new_user_input_ids, bot_output])
bot_output = cartman.generate(
bot_input_ids, max_length= 200,
pad_token_id=tokenizer.eos_token_id,
no_repeat_ngram_size=3,
do_sample=True,
top_k=100,
top_p=0.7,
temperature=.8
)
return '{}'.format(tokenizer.decode(bot_output[:,bot_input_ids.shape[-1]:][0], skip_special_tokens=True))
@api.get("/cartman/{user_message}")
def read_item(user_message: str):
return {"Cartman": cartman_speak(user_message)}