organize
parent b217ef460d
commit 76d8d2a2cc
6 changed files with 185 additions and 86 deletions
api/main.py (82 changes)
@@ -1,82 +0,0 @@
-from fastapi import FastAPI, Request
-from pydantic import BaseModel
-from fastapi.middleware.cors import CORSMiddleware
-
-from transformers.models.auto.tokenization_auto import AutoTokenizer
-from transformers.models.auto.modeling_auto import AutoModelForCausalLM
-
-tokenizer = AutoTokenizer.from_pretrained(
-    "microsoft/DialoGPT-large", padding_side='left')
-model = AutoModelForCausalLM.from_pretrained(
-    "../train/cartman/models/output-medium")
-
-
-class Packet(BaseModel):
-    message: str
-    max_new_tokens: int
-    num_beams: int
-    num_beam_groups: int
-    no_repeat_ngram_size: int
-    length_penalty: float
-    diversity_penalty: float
-    repetition_penalty: float
-    early_stopping: bool
-
-
-def cartman_respond(packet: Packet) -> str:
-    input_ids = tokenizer(packet.message +
-                          tokenizer.eos_token, return_tensors="pt").input_ids
-
-    outputs = model.generate(
-        input_ids,
-        pad_token_id=tokenizer.eos_token_id,
-        max_new_tokens=packet.max_new_tokens,
-        num_beams=packet.num_beams,
-        num_beam_groups=packet.num_beam_groups,
-        no_repeat_ngram_size=packet.no_repeat_ngram_size,
-        length_penalty=packet.length_penalty,
-        diversity_penalty=packet.diversity_penalty,
-        repetition_penalty=packet.repetition_penalty,
-        early_stopping=packet.early_stopping,
-
-        # do_sample = True,
-        # top_k = 100,
-        # top_p = 0.7,
-        # temperature = 0.8,
-    )
-    return tokenizer.decode(outputs[:, input_ids.shape[-1]:][0],
-                            skip_special_tokens=True)
-
-
-api = FastAPI()
-
-api.add_middleware(
-    CORSMiddleware,
-    allow_origins=['*'],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-@api.post('/chat/')
-async def getInformation(request: Request) -> dict[str, str]:
-    data = await request.json()
-
-    packet = Packet(
-        message=data.get('message'),
-        max_new_tokens=data.get('max_new_tokens'),
-        num_beams=data.get('num_beams'),
-        num_beam_groups=data.get('num_beam_groups'),
-        no_repeat_ngram_size=data.get('no_repeat_ngram_size'),
-        length_penalty=data.get('length_penalty'),
-        diversity_penalty=data.get('diversity_penalty'),
-        repetition_penalty=data.get('repetition_penalty'),
-        early_stopping=data.get('early_stopping'),
-    )
-
-    print(packet.message)
-    response = cartman_respond(packet)
-    print(response)
-
-    return {"Cartman": response}
api/run (2 changes)
@@ -1,3 +1,3 @@
 #!/bin/bash
 
-uvicorn main:api --host 10.0.1.1 --reload
+uvicorn cartman:api --host 10.0.1.1 --reload
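The run script now points uvicorn at `cartman:api`, but the module that builds that FastAPI app is not part of this diff. As a rough orientation only, here is a minimal sketch of what such an app module could look like, wiring a /chat/ endpoint to the Packet/BotResponse models and the cartman bot added below; the module location, import paths, and the dispatch-by-bot_name registry are assumptions, not code from this commit.

# Hypothetical app module (not included in this commit): exposes /chat/ and
# dispatches on Packet.bot_name to bot functions such as src/bots/cartman.py.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from src.models import Packet, BotResponse
from src.bots.cartman import cartman

api = FastAPI()

# Same permissive CORS setup as the deleted api/main.py.
api.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Registry of available bots, keyed by the bot_name field sent by clients.
BOTS = {'cartman': cartman}


@api.post('/chat/', response_model=BotResponse)
async def chat(packet: Packet) -> BotResponse:
    # FastAPI validates the JSON body against Packet before calling the handler.
    return BOTS[packet.bot_name](packet)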
api/src/bots/cartman.py (44 changes, new file)
@@ -0,0 +1,44 @@
+from ..models import Packet, BotResponse
+
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers.models.auto.modeling_auto import AutoModelForCausalLM
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "microsoft/DialoGPT-large", padding_side='left'
+)
+model = AutoModelForCausalLM.from_pretrained(
+    "../train/cartman/models/output-medium"
+)
+
+
+def cartman(packet: Packet) -> BotResponse:
+    input_ids = tokenizer(
+        packet.message + tokenizer.eos_token,
+        return_tensors="pt"
+    ).input_ids
+
+    outputs = model.generate(
+        input_ids,
+        pad_token_id=tokenizer.eos_token_id,
+        max_new_tokens=packet.max_new_tokens,
+        num_beams=packet.num_beams,
+        num_beam_groups=packet.num_beam_groups,
+        no_repeat_ngram_size=packet.no_repeat_ngram_size,
+        length_penalty=packet.length_penalty,
+        diversity_penalty=packet.diversity_penalty,
+        repetition_penalty=packet.repetition_penalty,
+        early_stopping=packet.early_stopping,
+
+        # do_sample = True,
+        # top_k = 100,
+        # top_p = 0.7,
+        # temperature = 0.8,
+    )
+
+    return BotResponse(
+        name='Cartman',
+        message=tokenizer.decode(
+            outputs[:, input_ids.shape[-1]:][0],
+            skip_special_tokens=True
+        )
+    )
api/src/models.py (19 changes, new file)
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+
+
+class Packet(BaseModel):
+    bot_name: str
+    message: str
+    max_new_tokens: int
+    num_beams: int
+    num_beam_groups: int
+    no_repeat_ngram_size: int
+    length_penalty: float
+    diversity_penalty: float
+    repetition_penalty: float
+    early_stopping: bool
+
+
+class BotResponse(BaseModel):
+    name: str
+    message: str
api/test/test.html (117 changes, new file)
@@ -0,0 +1,117 @@
+<!DOCTYPE html>
+<html>
+
+<head>
+    <title>Chat</title>
+    <style>
+        input,
+        textarea {
+            border: 1px solid;
+            border-radius: 4px;
+        }
+
+        #history,
+        #message {
+            width: 600px;
+        }
+
+        textarea {
+            height: 20rem;
+        }
+
+        input[type=text] {
+            font-size: 18;
+        }
+
+        input[type=submit],
+        input[type=reset] {
+            font-size: 16;
+            padding: .5em;
+        }
+
+        input[type=submit]:hover {
+            background-color: var(--color10);
+        }
+    </style>
+</head>
+
+<body>
+    <script>
+        function onSubmit(message, history) {
+            let options = {
+                method: 'POST',
+                headers: {
+                    'Content-Type':
+                        'application/json;charset=utf-8'
+                },
+                body: JSON.stringify({
+                    'bot_name': 'cartman',
+                    'message': message.value,
+                    'max_new_tokens': max_new_tokens.value,
+                    'num_beams': num_beams.value,
+                    'num_beam_groups': num_beam_groups.value,
+                    'no_repeat_ngram_size': no_repeat_ngram_size.value,
+                    'length_penalty': length_penalty.value,
+                    'diversity_penalty': diversity_penalty.value,
+                    'repetition_penalty': repetition_penalty.value,
+                    'early_stopping': true
+                })
+            }
+
+            history.value = history.value + 'You: ' + message.value + '\n';
+            message.value = "";
+            history.scrollTop = history.scrollHeight
+
+            let fetchRes = fetch('http://localhost:8000/chat', options);
+            fetchRes.then(res =>
+                res.json()).then(d => {
+                    history.value = `${history.value}${d.name}: ${d.message}\n`;
+                    history.scrollTop = history.scrollHeight
+                })
+        };
+    </script>
+    <div>
+        <form id="chatbox" onsubmit="onSubmit(message, history);return false">
+            <textarea id="history" readonly="true" wrap="soft"></textarea>
+            <p>
+                <input type="text" id="message" autocomplete="off">
+            </p>
+            <input type="submit" value="Send">
+            <br />
+            <h3>Knobs to spin:</h3>
+            <table>
+                <tr>
+                    <td><label for="max_new_tokens">Max new tokens:</label></td>
+                    <td><input id="max_new_tokens" type="number" value="200"></td>
+                </tr>
+                <tr>
+                    <td><label for="num_beams">Num beams:</label>
+                    <td><input id="num_beams" type="number" value="8"> (must be divisible by num_beam_groups)
+                </tr>
+                <tr>
+                    <td><label for="num_beam_groups">Num beam groups:</label>
+                    <td><input id="num_beam_groups" type="number" value="4">
+                </tr>
+                <tr>
+                    <td><label for="no_repeat_ngram_size">No repeat ngram size:</label>
+                    <td><input id="no_repeat_ngram_size" type="number" value="3">
+                </tr>
+                <tr>
+                    <td><label for="length_penalty">Length penalty:</label>
+                    <td><input id="length_penalty" type="number" step="0.1" value="1.4">
+                </tr>
+                <tr>
+                    <td><label for="diversity_penalty">Diversity penalty:</label>
+                    <td><input id="diversity_penalty" type="number" step="0.1" value="0">
+                </tr>
+                <tr>
+                    <td><label for="repetition_penalty">Repetition penalty</label>
+                    <td><input id="repetition_penalty" type="number" step="0.1" value="2.1">
+                </tr>
+            </table>
+            <input type="reset">
+        </form>
+    </div>
+</body>
+
+</html>
(sixth changed file, a requests-based test client; filename not shown)
@@ -1,5 +1,4 @@
 import requests
-import json
 
 while True:
     user_input: str = input('>> ')
@@ -19,6 +18,8 @@ while True:
     }
 
     response = requests.post(
-        'http://127.0.0.1:8000/chat/', json=packet)
+        'http://127.0.0.1:8000/chat/',
+        json=packet,
+    ).json()
 
-    print(response.json())
+    print(f"{response.get('name')}: {response.get('message')}")