should be done better but it works

This commit is contained in:
Adam 2023-02-04 22:19:41 -05:00
parent 85f2d57a17
commit 6663e8a366

25
main.py
View file

@ -17,19 +17,19 @@ from transformers.models.auto.modeling_auto import AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large") tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("../southpark/output-medium") model = AutoModelForCausalLM.from_pretrained("../southpark/output-medium")
def cartman_respond(input_text): def cartman_respond(packet):
input_ids = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt").input_ids input_ids = tokenizer(str(packet.get('message')) + tokenizer.eos_token, return_tensors="pt").input_ids
outputs = model.generate( outputs = model.generate(
input_ids, input_ids,
pad_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id,
max_new_tokens = 200, max_new_tokens = int(packet.get('max_new_tokens')),
num_beams = 8, num_beams = int(packet.get('num_beams')),
num_beam_groups = 4, num_beam_groups = int(packet.get('num_beam_groups')),
no_repeat_ngram_size=3, no_repeat_ngram_size = int(packet.get('no_repeat_ngram_size')),
length_penalty = 1.4, length_penalty = float(packet.get('length_penalty')),
diversity_penalty = 0, diversity_penalty = float(packet.get('diversity_penalty')),
repetition_penalty = 2.1, repetition_penalty = float(packet.get('repetition_penalty')),
early_stopping = True, early_stopping = bool(packet.get('early_stopping')),
# do_sample = True, # do_sample = True,
# top_k = 100, # top_k = 100,
@ -41,9 +41,10 @@ def cartman_respond(input_text):
@api.post('/chat/') @api.post('/chat/')
async def getInformation(data : Request): async def getInformation(data : Request):
packet = await data.json() packet = await data.json()
message = packet.get('Message') print(packet)
message = str(packet.get('message'))
print(message) print(message)
response = cartman_respond(message) response = cartman_respond(packet)
print(response) print(response)
return { return {