del
This commit is contained in:
parent
d2c8d5dca2
commit
5f3c23ac69
2 changed files with 0 additions and 56 deletions
53
main.py
53
main.py
|
@ -1,53 +0,0 @@
|
||||||
from fastapi import FastAPI, Request
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
api = FastAPI()
|
|
||||||
|
|
||||||
api.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=['*'],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=['*'],
|
|
||||||
allow_headers=['*']
|
|
||||||
)
|
|
||||||
|
|
||||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
||||||
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
|
|
||||||
model = AutoModelForCausalLM.from_pretrained("../southpark/output-medium")
|
|
||||||
|
|
||||||
def cartman_respond(packet):
|
|
||||||
input_ids = tokenizer(str(packet.get('message')) + tokenizer.eos_token, return_tensors="pt").input_ids
|
|
||||||
outputs = model.generate(
|
|
||||||
input_ids,
|
|
||||||
pad_token_id=tokenizer.eos_token_id,
|
|
||||||
max_new_tokens = int(packet.get('max_new_tokens')),
|
|
||||||
num_beams = int(packet.get('num_beams')),
|
|
||||||
num_beam_groups = int(packet.get('num_beam_groups')),
|
|
||||||
no_repeat_ngram_size = int(packet.get('no_repeat_ngram_size')),
|
|
||||||
length_penalty = float(packet.get('length_penalty')),
|
|
||||||
diversity_penalty = float(packet.get('diversity_penalty')),
|
|
||||||
repetition_penalty = float(packet.get('repetition_penalty')),
|
|
||||||
early_stopping = bool(packet.get('early_stopping')),
|
|
||||||
|
|
||||||
# do_sample = True,
|
|
||||||
# top_k = 100,
|
|
||||||
# top_p = 0.7,
|
|
||||||
# temperature = 0.8,
|
|
||||||
)
|
|
||||||
return tokenizer.decode(outputs[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
|
|
||||||
|
|
||||||
@api.post('/chat/')
|
|
||||||
async def getInformation(data : Request):
|
|
||||||
packet = await data.json()
|
|
||||||
print(packet)
|
|
||||||
message = str(packet.get('message'))
|
|
||||||
print(message)
|
|
||||||
response = cartman_respond(packet)
|
|
||||||
print(response)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"Cartman" : response
|
|
||||||
}
|
|
||||||
|
|
3
run
3
run
|
@ -1,3 +0,0 @@
|
||||||
#!/bin/zsh
|
|
||||||
|
|
||||||
uvicorn main:api --host 10.0.1.1 --reload
|
|
Loading…
Add table
Reference in a new issue