diff --git a/main.py b/main.py index cb571df..067f0e3 100644 --- a/main.py +++ b/main.py @@ -17,19 +17,19 @@ from transformers.models.auto.modeling_auto import AutoModelForCausalLM tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large") model = AutoModelForCausalLM.from_pretrained("../southpark/output-medium") -def cartman_respond(input_text): - input_ids = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt").input_ids +def cartman_respond(packet): + input_ids = tokenizer(str(packet.get('message')) + tokenizer.eos_token, return_tensors="pt").input_ids outputs = model.generate( input_ids, pad_token_id=tokenizer.eos_token_id, - max_new_tokens = 200, - num_beams = 8, - num_beam_groups = 4, - no_repeat_ngram_size=3, - length_penalty = 1.4, - diversity_penalty = 0, - repetition_penalty = 2.1, - early_stopping = True, + max_new_tokens = int(packet.get('max_new_tokens')), + num_beams = int(packet.get('num_beams')), + num_beam_groups = int(packet.get('num_beam_groups')), + no_repeat_ngram_size = int(packet.get('no_repeat_ngram_size')), + length_penalty = float(packet.get('length_penalty')), + diversity_penalty = float(packet.get('diversity_penalty')), + repetition_penalty = float(packet.get('repetition_penalty')), + early_stopping = bool(packet.get('early_stopping')), # do_sample = True, # top_k = 100, @@ -41,9 +41,10 @@ def cartman_respond(input_text): @api.post('/chat/') async def getInformation(data : Request): packet = await data.json() - message = packet.get('Message') + print(packet) + message = str(packet.get('message')) print(message) - response = cartman_respond(message) + response = cartman_respond(packet) print(response) return {