should be done better but it works
This commit is contained in:
parent
85f2d57a17
commit
6663e8a366
1 changed files with 13 additions and 12 deletions
25
main.py
25
main.py
|
@ -17,19 +17,19 @@ from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
||||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
|
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
|
||||||
model = AutoModelForCausalLM.from_pretrained("../southpark/output-medium")
|
model = AutoModelForCausalLM.from_pretrained("../southpark/output-medium")
|
||||||
|
|
||||||
def cartman_respond(input_text):
|
def cartman_respond(packet):
|
||||||
input_ids = tokenizer(input_text + tokenizer.eos_token, return_tensors="pt").input_ids
|
input_ids = tokenizer(str(packet.get('message')) + tokenizer.eos_token, return_tensors="pt").input_ids
|
||||||
outputs = model.generate(
|
outputs = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
pad_token_id=tokenizer.eos_token_id,
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
max_new_tokens = 200,
|
max_new_tokens = int(packet.get('max_new_tokens')),
|
||||||
num_beams = 8,
|
num_beams = int(packet.get('num_beams')),
|
||||||
num_beam_groups = 4,
|
num_beam_groups = int(packet.get('num_beam_groups')),
|
||||||
no_repeat_ngram_size=3,
|
no_repeat_ngram_size = int(packet.get('no_repeat_ngram_size')),
|
||||||
length_penalty = 1.4,
|
length_penalty = float(packet.get('length_penalty')),
|
||||||
diversity_penalty = 0,
|
diversity_penalty = float(packet.get('diversity_penalty')),
|
||||||
repetition_penalty = 2.1,
|
repetition_penalty = float(packet.get('repetition_penalty')),
|
||||||
early_stopping = True,
|
early_stopping = bool(packet.get('early_stopping')),
|
||||||
|
|
||||||
# do_sample = True,
|
# do_sample = True,
|
||||||
# top_k = 100,
|
# top_k = 100,
|
||||||
|
@ -41,9 +41,10 @@ def cartman_respond(input_text):
|
||||||
@api.post('/chat/')
|
@api.post('/chat/')
|
||||||
async def getInformation(data : Request):
|
async def getInformation(data : Request):
|
||||||
packet = await data.json()
|
packet = await data.json()
|
||||||
message = packet.get('Message')
|
print(packet)
|
||||||
|
message = str(packet.get('message'))
|
||||||
print(message)
|
print(message)
|
||||||
response = cartman_respond(message)
|
response = cartman_respond(packet)
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
Loading…
Add table
Reference in a new issue