organize
This commit is contained in:
parent
b217ef460d
commit
76d8d2a2cc
6 changed files with 185 additions and 86 deletions
82
api/main.py
82
api/main.py
|
@ -1,82 +0,0 @@
|
||||||
from fastapi import FastAPI, Request
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
|
||||||
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
"microsoft/DialoGPT-large", padding_side='left')
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
"../train/cartman/models/output-medium")
|
|
||||||
|
|
||||||
|
|
||||||
class Packet(BaseModel):
|
|
||||||
message: str
|
|
||||||
max_new_tokens: int
|
|
||||||
num_beams: int
|
|
||||||
num_beam_groups: int
|
|
||||||
no_repeat_ngram_size: int
|
|
||||||
length_penalty: float
|
|
||||||
diversity_penalty: float
|
|
||||||
repetition_penalty: float
|
|
||||||
early_stopping: bool
|
|
||||||
|
|
||||||
|
|
||||||
def cartman_respond(packet: Packet) -> str:
|
|
||||||
input_ids = tokenizer(packet.message +
|
|
||||||
tokenizer.eos_token, return_tensors="pt").input_ids
|
|
||||||
|
|
||||||
outputs = model.generate(
|
|
||||||
input_ids,
|
|
||||||
pad_token_id=tokenizer.eos_token_id,
|
|
||||||
max_new_tokens=packet.max_new_tokens,
|
|
||||||
num_beams=packet.num_beams,
|
|
||||||
num_beam_groups=packet.num_beam_groups,
|
|
||||||
no_repeat_ngram_size=packet.no_repeat_ngram_size,
|
|
||||||
length_penalty=packet.length_penalty,
|
|
||||||
diversity_penalty=packet.diversity_penalty,
|
|
||||||
repetition_penalty=packet.repetition_penalty,
|
|
||||||
early_stopping=packet.early_stopping,
|
|
||||||
|
|
||||||
# do_sample = True,
|
|
||||||
# top_k = 100,
|
|
||||||
# top_p = 0.7,
|
|
||||||
# temperature = 0.8,
|
|
||||||
)
|
|
||||||
return tokenizer.decode(outputs[:, input_ids.shape[-1]:][0],
|
|
||||||
skip_special_tokens=True)
|
|
||||||
|
|
||||||
|
|
||||||
api = FastAPI()
|
|
||||||
|
|
||||||
api.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=['*'],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@api.post('/chat/')
|
|
||||||
async def getInformation(request: Request) -> dict[str, str]:
|
|
||||||
data = await request.json()
|
|
||||||
|
|
||||||
packet = Packet(
|
|
||||||
message=data.get('message'),
|
|
||||||
max_new_tokens=data.get('max_new_tokens'),
|
|
||||||
num_beams=data.get('num_beams'),
|
|
||||||
num_beam_groups=data.get('num_beam_groups'),
|
|
||||||
no_repeat_ngram_size=data.get('no_repeat_ngram_size'),
|
|
||||||
length_penalty=data.get('length_penalty'),
|
|
||||||
diversity_penalty=data.get('diversity_penalty'),
|
|
||||||
repetition_penalty=data.get('repetition_penalty'),
|
|
||||||
early_stopping=data.get('early_stopping'),
|
|
||||||
)
|
|
||||||
|
|
||||||
print(packet.message)
|
|
||||||
response = cartman_respond(packet)
|
|
||||||
print(response)
|
|
||||||
|
|
||||||
return {"Cartman": response}
|
|
2
api/run
2
api/run
|
@ -1,3 +1,3 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
uvicorn main:api --host 10.0.1.1 --reload
|
uvicorn cartman:api --host 10.0.1.1 --reload
|
||||||
|
|
44
api/src/bots/cartman.py
Normal file
44
api/src/bots/cartman.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
from ..models import Packet, BotResponse
|
||||||
|
|
||||||
|
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||||
|
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
"microsoft/DialoGPT-large", padding_side='left'
|
||||||
|
)
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
"../train/cartman/models/output-medium"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cartman(packet: Packet) -> BotResponse:
|
||||||
|
input_ids = tokenizer(
|
||||||
|
packet.message + tokenizer.eos_token,
|
||||||
|
return_tensors="pt"
|
||||||
|
).input_ids
|
||||||
|
|
||||||
|
outputs = model.generate(
|
||||||
|
input_ids,
|
||||||
|
pad_token_id=tokenizer.eos_token_id,
|
||||||
|
max_new_tokens=packet.max_new_tokens,
|
||||||
|
num_beams=packet.num_beams,
|
||||||
|
num_beam_groups=packet.num_beam_groups,
|
||||||
|
no_repeat_ngram_size=packet.no_repeat_ngram_size,
|
||||||
|
length_penalty=packet.length_penalty,
|
||||||
|
diversity_penalty=packet.diversity_penalty,
|
||||||
|
repetition_penalty=packet.repetition_penalty,
|
||||||
|
early_stopping=packet.early_stopping,
|
||||||
|
|
||||||
|
# do_sample = True,
|
||||||
|
# top_k = 100,
|
||||||
|
# top_p = 0.7,
|
||||||
|
# temperature = 0.8,
|
||||||
|
)
|
||||||
|
|
||||||
|
return BotResponse(
|
||||||
|
name='Cartman',
|
||||||
|
message=tokenizer.decode(
|
||||||
|
outputs[:, input_ids.shape[-1]:][0],
|
||||||
|
skip_special_tokens=True
|
||||||
|
)
|
||||||
|
)
|
19
api/src/models.py
Normal file
19
api/src/models.py
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Packet(BaseModel):
|
||||||
|
bot_name: str
|
||||||
|
message: str
|
||||||
|
max_new_tokens: int
|
||||||
|
num_beams: int
|
||||||
|
num_beam_groups: int
|
||||||
|
no_repeat_ngram_size: int
|
||||||
|
length_penalty: float
|
||||||
|
diversity_penalty: float
|
||||||
|
repetition_penalty: float
|
||||||
|
early_stopping: bool
|
||||||
|
|
||||||
|
|
||||||
|
class BotResponse(BaseModel):
|
||||||
|
name: str
|
||||||
|
message: str
|
117
api/test/test.html
Normal file
117
api/test/test.html
Normal file
|
@ -0,0 +1,117 @@
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<title>Chat</title>
|
||||||
|
<style>
|
||||||
|
input,
|
||||||
|
textarea {
|
||||||
|
border: 1px solid;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#history,
|
||||||
|
#message {
|
||||||
|
width: 600px;
|
||||||
|
}
|
||||||
|
|
||||||
|
textarea {
|
||||||
|
height: 20rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type=text] {
|
||||||
|
font-size: 18;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type=submit],
|
||||||
|
input[type=reset] {
|
||||||
|
font-size: 16;
|
||||||
|
padding: .5em;
|
||||||
|
}
|
||||||
|
|
||||||
|
input[type=submit]:hover {
|
||||||
|
background-color: var(--color10);
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
<script>
|
||||||
|
function onSubmit(message, history) {
|
||||||
|
let options = {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type':
|
||||||
|
'application/json;charset=utf-8'
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
'bot_name': 'cartman',
|
||||||
|
'message': message.value,
|
||||||
|
'max_new_tokens': max_new_tokens.value,
|
||||||
|
'num_beams': num_beams.value,
|
||||||
|
'num_beam_groups': num_beam_groups.value,
|
||||||
|
'no_repeat_ngram_size': no_repeat_ngram_size.value,
|
||||||
|
'length_penalty': length_penalty.value,
|
||||||
|
'diversity_penalty': diversity_penalty.value,
|
||||||
|
'repetition_penalty': repetition_penalty.value,
|
||||||
|
'early_stopping': true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
history.value = history.value + 'You: ' + message.value + '\n';
|
||||||
|
message.value = "";
|
||||||
|
history.scrollTop = history.scrollHeight
|
||||||
|
|
||||||
|
let fetchRes = fetch('http://localhost:8000/chat', options);
|
||||||
|
fetchRes.then(res =>
|
||||||
|
res.json()).then(d => {
|
||||||
|
history.value = `${history.value}${d.name}: ${d.message}\n`;
|
||||||
|
history.scrollTop = history.scrollHeight
|
||||||
|
})
|
||||||
|
};
|
||||||
|
</script>
|
||||||
|
<div>
|
||||||
|
<form id="chatbox" onsubmit="onSubmit(message, history);return false">
|
||||||
|
<textarea id="history" readonly="true" wrap="soft"></textarea>
|
||||||
|
<p>
|
||||||
|
<input type="text" id="message" autocomplete="off">
|
||||||
|
</p>
|
||||||
|
<input type="submit" value="Send">
|
||||||
|
<br />
|
||||||
|
<h3>Knobs to spin:</h3>
|
||||||
|
<table>
|
||||||
|
<tr>
|
||||||
|
<td><label for="max_new_tokens">Max new tokens:</label></td>
|
||||||
|
<td><input id="max_new_tokens" type="number" value="200"></td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><label for="num_beams">Num beams:</label>
|
||||||
|
<td><input id="num_beams" type="number" value="8"> (must be divisible by num_beam_groups)
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><label for="num_beam_groups">Num beam groups:</label>
|
||||||
|
<td><input id="num_beam_groups" type="number" value="4">
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><label for="no_repeat_ngram_size">No repeat ngram size:</label>
|
||||||
|
<td><input id="no_repeat_ngram_size" type="number" value="3">
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><label for="length_penalty">Length penalty:</label>
|
||||||
|
<td><input id="length_penalty" type="number" step="0.1" value="1.4">
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><label for="diversity_penalty">Diversity penalty:</label>
|
||||||
|
<td><input id="diversity_penalty" type="number" step="0.1" value="0">
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td><label for="repetition_penalty">Repetition penalty</label>
|
||||||
|
<td><input id="repetition_penalty" type="number" step="0.1" value="2.1">
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
<input type="reset">
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
|
@ -1,5 +1,4 @@
|
||||||
import requests
|
import requests
|
||||||
import json
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
user_input: str = input('>> ')
|
user_input: str = input('>> ')
|
||||||
|
@ -19,6 +18,8 @@ while True:
|
||||||
}
|
}
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
'http://127.0.0.1:8000/chat/', json=packet)
|
'http://127.0.0.1:8000/chat/',
|
||||||
|
json=packet,
|
||||||
|
).json()
|
||||||
|
|
||||||
print(response.json())
|
print(f"{response.get('name')}: {response.get('message')}")
|
||||||
|
|
Loading…
Add table
Reference in a new issue