diff --git a/lulzbot.py b/lulzbot.py
index de755c1..799bc87 100644
--- a/lulzbot.py
+++ b/lulzbot.py
@@ -625,6 +625,24 @@ async def on_message(message):
     if message.author == client.user:
         return
 
+    if message.channel.name == 'cartman':
+        async with message.channel.typing():
+            new_user_input_ids = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt')
+            bot_output = new_user_input_ids
+            bot_input_ids = torch.cat([new_user_input_ids, bot_output])
+            bot_output = model.generate(
+                bot_input_ids, max_length= 200,
+                pad_token_id=tokenizer.eos_token_id,
+                no_repeat_ngram_size=3,
+                do_sample=True,
+                top_k=100,
+                top_p=0.7,
+                temperature=.8
+            )
+
+            await message.channel.send('{}'.format(tokenizer.decode(bot_output[:,bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
+        return
+
     if message.channel.name == 'shitposting':
         if user_message.lower().count('musk') > 0:
             await message.channel.send(user_tweet('elonmusk'))
@@ -714,21 +732,4 @@ I was finally there\n\
 To sit on my throne as the Prince of Bel-Air')
         return
 
-    if message.channel.name == 'cartman':
-        new_user_input_ids = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt')
-        bot_output = new_user_input_ids
-        bot_input_ids = torch.cat([new_user_input_ids, bot_output])
-        bot_output = model.generate(
-            bot_input_ids, max_length= 200,
-            pad_token_id=tokenizer.eos_token_id,
-            no_repeat_ngram_size=3,
-            do_sample=True,
-            top_k=100,
-            top_p=0.7,
-            temperature=.8
-        )
-
-        await message.channel.send('{}'.format(tokenizer.decode(bot_output[:,bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))
-        return
-
 client.run(TOKEN)
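
Note (not part of the diff): a minimal sketch of the single-turn generation flow used by the 'cartman' handler, assuming a DialoGPT-style tokenizer and model are loaded elsewhere in lulzbot.py. The checkpoint name and the standalone script layout below are assumptions; the handler's torch.cat of the prompt with itself is omitted here for clarity.

    # Sketch only, not lulzbot.py itself: single-turn reply generation with an
    # assumed DialoGPT-style checkpoint; sampling parameters mirror the handler.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-medium')  # assumed checkpoint
    model = AutoModelForCausalLM.from_pretrained('microsoft/DialoGPT-medium')

    user_message = 'Hello there'
    # Encode the user's message, terminated with the EOS token the model expects.
    input_ids = tokenizer.encode(user_message + tokenizer.eos_token, return_tensors='pt')
    # Sample a continuation; tokens past the prompt length are the bot's reply.
    output_ids = model.generate(
        input_ids,
        max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,
        do_sample=True,
        top_k=100,
        top_p=0.7,
        temperature=0.8,
    )
    reply = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    print(reply)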