This commit is contained in:
Adam 2023-02-09 15:09:46 -05:00
parent b217ef460d
commit 76d8d2a2cc
6 changed files with 185 additions and 86 deletions

View file

@@ -1,82 +0,0 @@
from fastapi import FastAPI, Request
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
# Module-level model setup, loaded once at import time.
# Left padding keeps the prompt flush against the generation window, which
# DialoGPT-style decoder-only models require when padding is used.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/DialoGPT-large", padding_side='left')
# Fine-tuned "Cartman" weights from the local training run.
# NOTE(review): relative path — assumes the server is started from the api
# directory; confirm against the launch script.
model = AutoModelForCausalLM.from_pretrained(
    "../train/cartman/models/output-medium")
class Packet(BaseModel):
    """Request payload for the /chat/ endpoint.

    `message` is the user's prompt; every other field is forwarded verbatim
    as a keyword argument to `model.generate()` (see cartman_respond), so
    their semantics are those of the Hugging Face `generate()` API.
    """
    message: str
    max_new_tokens: int
    num_beams: int
    num_beam_groups: int
    no_repeat_ngram_size: int
    length_penalty: float
    diversity_penalty: float
    repetition_penalty: float
    early_stopping: bool
def cartman_respond(packet: Packet) -> str:
    """Generate a Cartman reply to ``packet.message``.

    The prompt is the user message terminated by the tokenizer's EOS token;
    only the newly generated continuation (tokens past the prompt) is
    decoded and returned.
    """
    prompt = packet.message + tokenizer.eos_token
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids

    generation_kwargs = {
        "pad_token_id": tokenizer.eos_token_id,
        "max_new_tokens": packet.max_new_tokens,
        "num_beams": packet.num_beams,
        "num_beam_groups": packet.num_beam_groups,
        "no_repeat_ngram_size": packet.no_repeat_ngram_size,
        "length_penalty": packet.length_penalty,
        "diversity_penalty": packet.diversity_penalty,
        "repetition_penalty": packet.repetition_penalty,
        "early_stopping": packet.early_stopping,
        # Sampling alternative kept from the original for experimentation:
        # do_sample=True, top_k=100, top_p=0.7, temperature=0.8
    }
    outputs = model.generate(prompt_ids, **generation_kwargs)

    # Drop the prompt tokens; decode only the model's continuation.
    reply_ids = outputs[:, prompt_ids.shape[-1]:][0]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
# FastAPI application with fully-open CORS so a browser page served from any
# other origin (e.g. the local test HTML file) can call the API directly.
# NOTE(review): '*' origins with allow_credentials=True is fine for local
# testing but should be tightened before any public deployment.
api = FastAPI()
api.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@api.post('/chat/')
async def getInformation(request: Request) -> dict[str, str]:
    """POST /chat/: parse the JSON body into a Packet and generate a reply.

    Returns a one-entry dict mapping the bot name ("Cartman") to the
    generated response text.
    """
    data = await request.json()
    # Let pydantic validate and coerce the payload directly instead of
    # pulling each key out with data.get(): the old form silently turned a
    # missing key into None, which then failed validation with a confusing
    # "none is not an allowed value" error. Unknown keys (e.g. a client's
    # extra 'bot_name') are ignored by default.
    packet = Packet(**data)

    print(packet.message)
    response = cartman_respond(packet)
    print(response)
    return {"Cartman": response}

View file

@@ -1,3 +1,3 @@
 #!/bin/bash

-uvicorn main:api --host 10.0.1.1 --reload
+uvicorn cartman:api --host 10.0.1.1 --reload

44
api/src/bots/cartman.py Normal file
View file

@@ -0,0 +1,44 @@
from ..models import Packet, BotResponse
from transformers.models.auto.tokenization_auto import AutoTokenizer
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
# Module-level model setup, loaded once when the bot module is imported.
# Left padding keeps the prompt flush against the generation window, as
# decoder-only DialoGPT-style models expect.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/DialoGPT-large", padding_side='left'
)
# Fine-tuned "Cartman" weights from the local training output.
# NOTE(review): relative path — depends on the server's working directory;
# confirm against the launch script.
model = AutoModelForCausalLM.from_pretrained(
    "../train/cartman/models/output-medium"
)
def cartman(packet: Packet) -> BotResponse:
    """Run the fine-tuned Cartman model on ``packet.message``.

    Encodes the user message (terminated by EOS), generates a continuation
    with the knobs carried in the packet, and wraps the decoded reply in a
    BotResponse.
    """
    encoded = tokenizer(
        packet.message + tokenizer.eos_token,
        return_tensors="pt",
    )
    context = encoded.input_ids

    generated = model.generate(
        context,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=packet.max_new_tokens,
        num_beams=packet.num_beams,
        num_beam_groups=packet.num_beam_groups,
        no_repeat_ngram_size=packet.no_repeat_ngram_size,
        length_penalty=packet.length_penalty,
        diversity_penalty=packet.diversity_penalty,
        repetition_penalty=packet.repetition_penalty,
        early_stopping=packet.early_stopping,
        # Sampling knobs kept for experimentation:
        # do_sample=True, top_k=100, top_p=0.7, temperature=0.8
    )

    # Everything past the prompt length is the model's reply.
    new_tokens = generated[:, context.shape[-1]:][0]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return BotResponse(name='Cartman', message=reply)

19
api/src/models.py Normal file
View file

@@ -0,0 +1,19 @@
from pydantic import BaseModel
class Packet(BaseModel):
    """Request payload for the chat endpoint.

    `message` is the user's prompt and the remaining numeric/bool fields are
    forwarded verbatim to Hugging Face `model.generate()` (see
    bots/cartman.py), so their semantics are those of the `generate()` API.
    `bot_name` presumably selects which bot handles the message — verify
    against the endpoint that consumes this model (not visible here).
    """
    bot_name: str
    message: str
    max_new_tokens: int
    num_beams: int
    num_beam_groups: int
    no_repeat_ngram_size: int
    length_penalty: float
    diversity_penalty: float
    repetition_penalty: float
    early_stopping: bool
class BotResponse(BaseModel):
    """A bot's reply: display name (e.g. "Cartman") plus the generated text."""
    name: str
    message: str

117
api/test/test.html Normal file
View file

@@ -0,0 +1,117 @@
<!DOCTYPE html>
<html>

<head>
    <title>Chat</title>
    <style>
        input,
        textarea {
            border: 1px solid;
            border-radius: 4px;
        }

        #history,
        #message {
            width: 600px;
        }

        textarea {
            height: 20rem;
        }

        input[type=text] {
            /* was "font-size: 18" — a unitless nonzero length is invalid
               CSS, so the declaration was silently dropped */
            font-size: 18px;
        }

        input[type=submit],
        input[type=reset] {
            font-size: 16px;
            padding: .5em;
        }

        input[type=submit]:hover {
            /* NOTE(review): --color10 is not defined on this page; confirm
               an external stylesheet supplies it, otherwise this hover rule
               is a no-op. */
            background-color: var(--color10);
        }
    </style>
</head>
<body>
<script>
// Send the chat form to the API and append both sides of the exchange
// to the history textarea.
function onSubmit(message, history) {
    let options = {
        method: 'POST',
        headers: {
            'Content-Type':
                'application/json;charset=utf-8'
        },
        body: JSON.stringify({
            'bot_name': 'cartman',
            'message': message.value,
            // input.value is always a string; the API's Packet model
            // declares ints/floats, so convert explicitly instead of
            // relying on server-side coercion.
            'max_new_tokens': Number(max_new_tokens.value),
            'num_beams': Number(num_beams.value),
            'num_beam_groups': Number(num_beam_groups.value),
            'no_repeat_ngram_size': Number(no_repeat_ngram_size.value),
            'length_penalty': Number(length_penalty.value),
            'diversity_penalty': Number(diversity_penalty.value),
            'repetition_penalty': Number(repetition_penalty.value),
            'early_stopping': true
        })
    }

    // Echo the user's line and clear the input before the request returns.
    history.value = history.value + 'You: ' + message.value + '\n';
    message.value = "";
    history.scrollTop = history.scrollHeight

    // Trailing slash matches the '/chat/' route used elsewhere in the
    // project; without it a cross-origin 307 redirect can break CORS.
    let fetchRes = fetch('http://localhost:8000/chat/', options);
    fetchRes.then(res =>
        res.json()).then(d => {
            history.value = `${history.value}${d.name}: ${d.message}\n`;
            history.scrollTop = history.scrollHeight
        })
};
</script>
<div>
    <form id="chatbox" onsubmit="onSubmit(message, history);return false">
        <textarea id="history" readonly="true" wrap="soft"></textarea>
        <p>
            <input type="text" id="message" autocomplete="off">
        </p>
        <input type="submit" value="Send">
        <br />
        <h3>Knobs to spin:</h3>
        <!-- Cells close explicitly; the original relied on implicit
             td/tr closing after the first row. -->
        <table>
            <tr>
                <td><label for="max_new_tokens">Max new tokens:</label></td>
                <td><input id="max_new_tokens" type="number" value="200"></td>
            </tr>
            <tr>
                <td><label for="num_beams">Num beams:</label></td>
                <td><input id="num_beams" type="number" value="8"> (must be divisible by num_beam_groups)</td>
            </tr>
            <tr>
                <td><label for="num_beam_groups">Num beam groups:</label></td>
                <td><input id="num_beam_groups" type="number" value="4"></td>
            </tr>
            <tr>
                <td><label for="no_repeat_ngram_size">No repeat ngram size:</label></td>
                <td><input id="no_repeat_ngram_size" type="number" value="3"></td>
            </tr>
            <tr>
                <td><label for="length_penalty">Length penalty:</label></td>
                <td><input id="length_penalty" type="number" step="0.1" value="1.4"></td>
            </tr>
            <tr>
                <td><label for="diversity_penalty">Diversity penalty:</label></td>
                <td><input id="diversity_penalty" type="number" step="0.1" value="0"></td>
            </tr>
            <tr>
                <!-- was "Repetition penalty" — missing colon, inconsistent
                     with every other label -->
                <td><label for="repetition_penalty">Repetition penalty:</label></td>
                <td><input id="repetition_penalty" type="number" step="0.1" value="2.1"></td>
            </tr>
        </table>
        <input type="reset">
    </form>
</div>
</body>

</html>

View file

@@ -1,5 +1,4 @@
 import requests
-import json

 while True:
     user_input: str = input('>> ')
@@ -19,6 +18,8 @@ while True:
     }

     response = requests.post(
-        'http://127.0.0.1:8000/chat/', json=packet)
+        'http://127.0.0.1:8000/chat/',
+        json=packet,
+    ).json()

-    print(response.json())
+    print(f"{response.get('name')}: {response.get('message')}")