diff --git a/api/main.py b/api/main.py deleted file mode 100755 index 37540d3..0000000 --- a/api/main.py +++ /dev/null @@ -1,82 +0,0 @@ -from fastapi import FastAPI, Request -from pydantic import BaseModel -from fastapi.middleware.cors import CORSMiddleware - -from transformers.models.auto.tokenization_auto import AutoTokenizer -from transformers.models.auto.modeling_auto import AutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained( - "microsoft/DialoGPT-large", padding_side='left') -model = AutoModelForCausalLM.from_pretrained( - "../train/cartman/models/output-medium") - - -class Packet(BaseModel): - message: str - max_new_tokens: int - num_beams: int - num_beam_groups: int - no_repeat_ngram_size: int - length_penalty: float - diversity_penalty: float - repetition_penalty: float - early_stopping: bool - - -def cartman_respond(packet: Packet) -> str: - input_ids = tokenizer(packet.message + - tokenizer.eos_token, return_tensors="pt").input_ids - - outputs = model.generate( - input_ids, - pad_token_id=tokenizer.eos_token_id, - max_new_tokens=packet.max_new_tokens, - num_beams=packet.num_beams, - num_beam_groups=packet.num_beam_groups, - no_repeat_ngram_size=packet.no_repeat_ngram_size, - length_penalty=packet.length_penalty, - diversity_penalty=packet.diversity_penalty, - repetition_penalty=packet.repetition_penalty, - early_stopping=packet.early_stopping, - - # do_sample = True, - # top_k = 100, - # top_p = 0.7, - # temperature = 0.8, - ) - return tokenizer.decode(outputs[:, input_ids.shape[-1]:][0], - skip_special_tokens=True) - - -api = FastAPI() - -api.add_middleware( - CORSMiddleware, - allow_origins=['*'], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -@api.post('/chat/') -async def getInformation(request: Request) -> dict[str, str]: - data = await request.json() - - packet = Packet( - message=data.get('message'), - max_new_tokens=data.get('max_new_tokens'), - num_beams=data.get('num_beams'), - num_beam_groups=data.get('num_beam_groups'), - no_repeat_ngram_size=data.get('no_repeat_ngram_size'), - length_penalty=data.get('length_penalty'), - diversity_penalty=data.get('diversity_penalty'), - repetition_penalty=data.get('repetition_penalty'), - early_stopping=data.get('early_stopping'), - ) - - print(packet.message) - response = cartman_respond(packet) - print(response) - - return {"Cartman": response} diff --git a/api/run b/api/run index b433d51..4d9181b 100755 --- a/api/run +++ b/api/run @@ -1,3 +1,3 @@ #!/bin/bash -uvicorn main:api --host 10.0.1.1 --reload +uvicorn cartman:api --host 10.0.1.1 --reload diff --git a/api/src/bots/cartman.py b/api/src/bots/cartman.py new file mode 100644 index 0000000..a4e499d --- /dev/null +++ b/api/src/bots/cartman.py @@ -0,0 +1,44 @@ +from ..models import Packet, BotResponse + +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers.models.auto.modeling_auto import AutoModelForCausalLM + +tokenizer = AutoTokenizer.from_pretrained( + "microsoft/DialoGPT-large", padding_side='left' +) +model = AutoModelForCausalLM.from_pretrained( + "../train/cartman/models/output-medium" +) + + +def cartman(packet: Packet) -> BotResponse: + input_ids = tokenizer( + packet.message + tokenizer.eos_token, + return_tensors="pt" + ).input_ids + + outputs = model.generate( + input_ids, + pad_token_id=tokenizer.eos_token_id, + max_new_tokens=packet.max_new_tokens, + num_beams=packet.num_beams, + num_beam_groups=packet.num_beam_groups, + no_repeat_ngram_size=packet.no_repeat_ngram_size, + length_penalty=packet.length_penalty, + diversity_penalty=packet.diversity_penalty, + repetition_penalty=packet.repetition_penalty, + early_stopping=packet.early_stopping, + + # do_sample = True, + # top_k = 100, + # top_p = 0.7, + # temperature = 0.8, + ) + + return BotResponse( + name='Cartman', + message=tokenizer.decode( + outputs[:, input_ids.shape[-1]:][0], + skip_special_tokens=True + ) + ) diff --git a/api/src/models.py b/api/src/models.py new file mode 100644 index 0000000..4ea7c70 --- /dev/null +++ b/api/src/models.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + + +class Packet(BaseModel): + bot_name: str + message: str + max_new_tokens: int + num_beams: int + num_beam_groups: int + no_repeat_ngram_size: int + length_penalty: float + diversity_penalty: float + repetition_penalty: float + early_stopping: bool + + +class BotResponse(BaseModel): + name: str + message: str diff --git a/api/test/test.html b/api/test/test.html new file mode 100644 index 0000000..cc125bd --- /dev/null +++ b/api/test/test.html @@ -0,0 +1,117 @@ + + + +
+