309 lines
12 KiB
Python
309 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from .mwt import MWT
|
|
import sys
|
|
import telegram
|
|
import telegram.ext
|
|
from telegram.ext import Updater
|
|
from telegram.ext import CommandHandler
|
|
from telegram.ext import MessageHandler
|
|
from telegram.ext.filters import Filters
|
|
from spellchecker import SpellChecker
|
|
import logging
|
|
from django.utils import timezone
|
|
import random
|
|
import datetime
|
|
from io import BytesIO, StringIO
|
|
from telegram.constants import MAX_MESSAGE_LENGTH
|
|
|
|
# Spell checkers keyed by language code; populated by start_bot().
CHECKERS = {}

# Placeholders for project objects. start_bot() rebinds them through
# function-level imports of the sibling ``models`` module.
Group: 'propergrammar.models.Group' = None
GroupDictionaryEntry: 'propergrammar.models.GroupDictionaryEntry' = None
models = None

# Root logger; assigned by start_bot() after logging is configured.
logger = None
|
|
|
|
|
|
@MWT(timeout=60*5)
def get_admin_ids(bot, chat_id):
    """Return the user IDs of the administrators of a chat.

    Results are memoized by the ``MWT`` decorator with ``timeout=60*5``.
    NOTE(review): presumably the timeout is in seconds, i.e. a five-minute
    cache -- confirm against ``.mwt``. (The previous docstring claimed one
    hour, which does not match the configured value.)

    Args:
        bot: Object exposing ``get_chat_administrators(chat_id)``.
        chat_id: Identifier of the chat whose administrators are listed.

    Returns:
        list: ``admin.user.id`` for every administrator returned by the bot.
    """
    administrators = bot.get_chat_administrators(chat_id)
    return [member.user.id for member in administrators]
|
|
|
|
|
|
def start(update: telegram.Update, context: telegram.ext.CallbackContext):
    """Handle the ``/start`` command.

    * Private chats receive an invitation hint.
    * In groups and supergroups, administrators receive the help message.

    NOTE(review): the previous body called ``is_group_admin(update, context)``
    for administrators, but no such name is defined or imported in this
    module, so an administrator issuing ``/start`` in a group hit a
    ``NameError``. That call is removed here. The original nesting of the
    trailing ``send_help_message`` call could not be recovered from the
    source; sending help to group administrators only is assumed -- confirm.

    Args:
        update: Incoming update; ignored when it carries no ``message``
            (for example an edited-message update).
        context: Callback context providing ``context.bot``.
    """
    message = update.message
    if message is None:
        # Command handlers also receive edited messages, where
        # ``update.message`` is None; there is nothing to answer then.
        return

    chat_type = message.chat.type
    if chat_type == 'private':
        context.bot.send_message(
            chat_id=message.chat_id,
            text="Invite me to your server!"
        )
    elif chat_type in ('group', 'supergroup'):
        admins = get_admin_ids(context.bot, message.chat_id)
        if message.from_user.id in admins:
            send_help_message(update, context)
|
|
|
|
|
|
def send_help_message(update: telegram.Update, context: telegram.ext.CallbackContext):
    """Send the list of dictionary-management commands to the originating chat.

    Args:
        update: Incoming update; ignored when it carries no ``message``
            (for example an edited ``/help`` command), which previously
            raised ``AttributeError``.
        context: Callback context providing ``context.bot``.
    """
    message = update.message
    if message is None:
        return

    help_text = '''Try using my services with the commands:
/add_group_dictionary
/show_group_dictionary
/remove_group_dictionary
/export_group_dictionary
(there is a clear_group_dictionary, but you'll have to type it yourself)'''
    context.bot.send_message(
        chat_id=message.chat_id,
        text=help_text
    )
|
|
|
|
|
|
def split_long_messages(message: str, max_length=None) -> list:
    """Split ``message`` into segments no longer than ``max_length``.

    A message that already fits is returned unchanged as a single-item list.
    Otherwise the text is split on line boundaries and consecutive lines are
    packed together (each followed by ``'\\n'``) while they fit. A single
    line that is too long on its own is cut into ``max_length``-sized pieces.

    Unlike the previous implementation, no empty segment is ever produced
    while splitting (an oversized first line, or an oversized line whose
    length left an empty remainder, used to yield ``''``).

    Args:
        message: Text to split.
        max_length: Maximum segment length; defaults to
            ``MAX_MESSAGE_LENGTH`` when ``None``.

    Returns:
        list: The segments, in order.

    Raises:
        ValueError: If ``max_length`` is smaller than 1.
    """
    if max_length is None:
        max_length = MAX_MESSAGE_LENGTH
    if max_length < 1:
        raise ValueError('max_length must be at least 1')

    if len(message) <= max_length:
        return [message]

    segments = []
    buffer = ''
    for line in message.splitlines():
        if len(buffer) + len(line) + 1 <= max_length:
            buffer += line + '\n'
            continue

        # The line does not fit next to what is buffered: flush first.
        if buffer:
            segments.append(buffer)
            buffer = ''

        # Hard-split a line that cannot fit even on its own (with newline).
        while len(line) + 1 > max_length:
            segments.append(line[:max_length])
            line = line[max_length:]

        # Whatever is left (possibly nothing) starts the next buffer.
        if line:
            buffer = line + '\n'

    if buffer:
        segments.append(buffer)
    return segments
|
|
|
|
|
|
def get_group(chat):
    """Return the stored ``Group`` for ``chat``, creating or renaming it as needed.

    The group is looked up by ``chat.id``. Its name is ``chat.title`` when
    present; otherwise it is built from ``chat.first_name`` and
    ``chat.last_name``. Missing name parts are skipped, so the stored name
    never contains the literal text ``"None"`` (the previous f-string
    produced e.g. ``"John None"``). If no name part is available at all, the
    chat id is used as the name.

    Args:
        chat: Chat object exposing ``id``, ``title``, ``first_name`` and
            ``last_name``.

    Returns:
        The ``Group`` instance, saved if it was created or its name changed.
    """
    title = chat.title
    if title is None:
        name_parts = [part for part in (chat.first_name, chat.last_name) if part]
        title = ' '.join(name_parts) or str(chat.id)

    group = Group.objects.filter(chat_id=chat.id).first()
    if group is None:
        group = Group(chat_id=chat.id, name=title)
        group.save()
    elif group.name != title:
        # Keep the stored name in sync with the chat's current name.
        group.name = title
        group.save()
    return group
|
|
|
|
|
|
def cmd_null(update: telegram.Update, context: telegram.ext.CallbackContext):
    """Do nothing; a no-op callback that ignores both arguments."""
    return None
|
|
|
|
|
|
def remove_links(text, entities):
    """Return ``text`` with URL, e-mail and bot-command entities removed.

    Telegram reports entity ``offset``/``length`` in UTF-16 code units, not
    in Python code points. The text is therefore processed in its UTF-16
    encoding so that characters outside the Basic Multilingual Plane (such
    as most emoji) in front of an entity do not shift the removed range.
    The previous implementation indexed code points and could delete the
    wrong characters or raise ``IndexError``.

    Args:
        text: Message text. Anything that is not a ``str`` (for example
            ``None``) is returned unchanged.
        entities: Iterable of objects exposing ``type``, ``offset`` and
            ``length``; ``None`` is treated as empty.

    Returns:
        The text without the matching entities, or ``text`` itself when it
        is not a string.
    """
    if not isinstance(text, str):
        return text

    spans = sorted(
        (entity.offset, entity.offset + entity.length)
        for entity in (entities or ())
        if entity.type in ('url', 'email', 'bot_command')
    )
    if not spans:
        return text

    # Two bytes per UTF-16 code unit; 'surrogatepass' keeps the round trip
    # lossless even for unusual input.
    encoded = text.encode('utf-16-le', errors='surrogatepass')
    unit_count = len(encoded) // 2

    kept = bytearray()
    cursor = 0  # first code unit not yet copied or dropped
    for start, end in spans:
        # Clamp to the text and to what was already consumed so that
        # overlapping or out-of-range entities cannot break the slicing.
        start = min(max(start, cursor), unit_count)
        end = min(max(end, start), unit_count)
        kept += encoded[2 * cursor:2 * start]
        cursor = end
    kept += encoded[2 * cursor:]
    return kept.decode('utf-16-le', errors='surrogatepass')
|
|
|
|
|
|
def handle_message(update: telegram.Update, context: telegram.ext.CallbackContext):
    """Spell-check an incoming message and reply with suggested corrections.

    Steps:

    1. Join the message text and caption, with URL/e-mail/bot-command
       entities removed by ``remove_links``.
    2. Rank every language in ``models.VALID_LANGUAGES_LST`` by the ratio of
       known words to extracted words and pick the best; when the best
       ratio is 0 the language falls back to ``'en'``.
    3. Collect the words unknown to the main checker and to every other
       configured checker, minus the group's own dictionary words
       (compared case-insensitively).
    4. Reply to the message with one line per remaining word, split into
       several messages when too long.

    Fixes over the previous version:

    * Updates without ``update.message`` (edited messages, channel posts)
      are ignored instead of raising ``AttributeError``.
    * Text from which no checker extracts any word is ignored instead of
      raising ``IndexError`` on an empty ranking.
    * Group dictionary words are really excluded; the old
      ``set.difference_update(word)`` removed the word's individual
      characters from the set rather than the word itself.
    * A ``None`` result from ``candidates()`` no longer raises, and
      candidates are listed in sorted order.

    Args:
        update: Incoming update.
        context: Callback context providing ``context.bot``.
    """
    message = update.message
    if message is None:
        return
    if message.chat.type not in ('group', 'supergroup', 'private'):
        return

    fragments = [
        remove_links(message.text, message.entities),
        remove_links(message.caption, message.caption_entities),
    ]
    to_check = ' '.join(
        fragment for fragment in fragments if fragment is not None
    ).strip()
    if not to_check:
        return

    group = get_group(message.chat)
    languages = models.VALID_LANGUAGES_LST
    ignored = {entry.word.lower() for entry in group.words.all()}

    # Rank languages by the share of extracted words each checker knows.
    language_rank = []
    for language in languages:
        checker = CHECKERS[language]
        words = checker.split_words(to_check)
        if len(words) > 0:
            language_rank.append(
                (len(checker.known(words)) / len(words), language)
            )
    if not language_rank:
        # No checker extracted a word, so there is nothing to check.
        return

    # Highest (confidence, language) tuple, as sort() + reverse() + [0] gave.
    language_confidence, language_main = max(language_rank)
    if language_confidence == 0:
        language_main = 'en'
    checker_main = CHECKERS[language_main]

    # A word is reported only if no configured language knows it.
    unknown = checker_main.unknown(checker_main.split_words(to_check))
    for language in languages:
        unknown = CHECKERS[language].unknown(unknown)
    unknown = sorted(
        word for word in set(unknown) if word.lower() not in ignored
    )
    if not unknown:
        return

    formatted_suggestions = []
    for typo in unknown:
        typo_fixed = checker_main.correction(typo)
        # candidates() may yield None when it has nothing to offer.
        typo_fixes = set(checker_main.candidates(typo) or ()) - {typo_fixed}
        formatted_suggestions.append(
            f'{typo} → {typo_fixed}; {", ".join(sorted(typo_fixes))}'
        )

    lecture = models.LANGUAGES_SCOLDING_DICT[language_main].format(
        models.VALID_LANGUAGES_DICT[language_main]
    ) + '\n' + "\n".join(formatted_suggestions)
    for message_segment in split_long_messages(lecture):
        context.bot.send_message(
            chat_id=message.chat_id,
            reply_to_message_id=message.message_id,
            text=message_segment
        )
|
|
|
|
|
|
def cmd_agd(update: telegram.Update, context: telegram.ext.CallbackContext):
    """Handle ``/add_group_dictionary``: add words to the group's dictionary.

    Only acts in groups and supergroups, and only for chat administrators.
    Every whitespace-separated word of the message (with URL, e-mail and
    bot-command entities removed) is stored unless an entry with the same
    lowercase form already exists. A per-word report is sent as a reply.

    Updates without ``update.message`` (for example an edited command) are
    ignored; they previously raised ``AttributeError``.

    Args:
        update: Incoming update.
        context: Callback context providing ``context.bot``.
    """
    message = update.message
    if message is None:
        return
    if message.chat.type not in ('group', 'supergroup'):
        return
    if message.from_user.id not in get_admin_ids(context.bot, message.chat_id):
        return

    group = get_group(message.chat)
    # A set gives constant-time duplicate checks while words are added.
    known = {entry.word.lower() for entry in group.words.all()}
    # str.split() without arguments already splits on line breaks as well.
    new_words = remove_links(message.text, message.entities).split()

    report = [f'Adding {len(new_words)} new words...\n']
    for word in new_words:
        lowered = word.lower()
        if lowered in known:
            outcome = 'already on list\n'
        else:
            GroupDictionaryEntry(group=group, word=word).save()
            known.add(lowered)
            outcome = 'OK\n'
        report.append(f'{word}... {outcome}')
    report.append(f'New word count: {len(group.words.all())}')

    for segment in split_long_messages(''.join(report)):
        context.bot.send_message(
            chat_id=message.chat_id,
            reply_to_message_id=message.message_id,
            text=segment
        )
|
|
|
|
|
|
def cmd_rgd(update: telegram.Update, context: telegram.ext.CallbackContext):
    """Handle ``/remove_group_dictionary``: remove words from the group's dictionary.

    Only acts in groups and supergroups, and only for chat administrators.
    For every whitespace-separated word of the message (with URL, e-mail and
    bot-command entities removed) the first entry matching it
    case-insensitively is deleted. A per-word report is sent as a reply.

    Updates without ``update.message`` (for example an edited command) are
    ignored; they previously raised ``AttributeError``.

    Args:
        update: Incoming update.
        context: Callback context providing ``context.bot``.
    """
    message = update.message
    if message is None:
        return
    if message.chat.type not in ('group', 'supergroup'):
        return
    if message.from_user.id not in get_admin_ids(context.bot, message.chat_id):
        return

    group = get_group(message.chat)
    # str.split() without arguments already splits on line breaks as well.
    old_words = remove_links(message.text, message.entities).split()

    report = [f'Removing {len(old_words)} words...\n']
    for word in old_words:
        entry = group.words.all().filter(word__iexact=word).first()
        if entry is None:
            outcome = 'not found\n'
        else:
            entry.delete()
            outcome = 'REMOVED\n'
        report.append(f'{word}... {outcome}')
    report.append(f'New word count: {len(group.words.all())}')

    for segment in split_long_messages(''.join(report)):
        context.bot.send_message(
            chat_id=message.chat_id,
            reply_to_message_id=message.message_id,
            text=segment
        )
|
|
|
|
|
|
def cmd_sgd(update: telegram.Update, context: telegram.ext.CallbackContext):
    """Handle ``/show_group_dictionary``: list the group's dictionary words.

    Only acts in groups and supergroups, and only for chat administrators.
    The sorted word list is sent as a reply, split into several messages
    when it is too long for one.

    Updates without ``update.message`` (for example an edited command) are
    ignored; they previously raised ``AttributeError``.

    Args:
        update: Incoming update.
        context: Callback context providing ``context.bot``.
    """
    message = update.message
    if message is None:
        return
    if message.chat.type not in ('group', 'supergroup'):
        return
    if message.from_user.id not in get_admin_ids(context.bot, message.chat_id):
        return

    group = get_group(message.chat)
    wordlist = sorted(entry.word for entry in group.words.all())
    msg = 'Here are all your %d entries you have on your group\'s dictionary:\n%s' % (
        len(wordlist), '\n'.join(wordlist)
    )
    for fragment in split_long_messages(msg):
        context.bot.send_message(
            chat_id=message.chat_id,
            reply_to_message_id=message.message_id,
            text=fragment
        )
|
|
|
|
|
|
def cmd_egd(update: telegram.Update, context: telegram.ext.CallbackContext):
    """Handle ``/export_group_dictionary``: send the dictionary as a text file.

    Only acts in groups and supergroups, and only for chat administrators.
    The sorted words are written one per line (UTF-8, with a trailing
    newline) into an in-memory file whose name contains the chat id and the
    message date, and the file is sent as a reply.

    Updates without ``update.message`` (for example an edited command) are
    ignored; they previously raised ``AttributeError``.

    Args:
        update: Incoming update.
        context: Callback context providing ``context.bot``.
    """
    message = update.message
    if message is None:
        return
    if message.chat.type not in ('group', 'supergroup'):
        return
    if message.from_user.id not in get_admin_ids(context.bot, message.chat_id):
        return

    group = get_group(message.chat)
    wordlist = sorted(entry.word for entry in group.words.all())
    payload = '\n'.join(wordlist) + '\n'

    bio = BytesIO(payload.encode('UTF-8'))
    # Spaces and colons of the date are replaced to keep the name file-friendly.
    timestamp = str(message.date).replace(" ", "_").replace(":", "-")
    bio.name = f'DictExport_{message.chat_id}_{timestamp}.txt'
    context.bot.send_document(
        chat_id=message.chat_id,
        reply_to_message_id=message.message_id,
        document=bio
    )
|
|
|
|
|
|
def cmd_cgd(update: telegram.Update, context: telegram.ext.CallbackContext):
    """Handle ``/clear_group_dictionary``: delete every word of the group's dictionary.

    Only acts in groups and supergroups, and only for chat administrators.
    A confirmation is sent as a reply once the entries are deleted.

    Updates without ``update.message`` (for example an edited command) are
    ignored; they previously raised ``AttributeError``.

    Args:
        update: Incoming update.
        context: Callback context providing ``context.bot``.
    """
    message = update.message
    if message is None:
        return
    if message.chat.type not in ('group', 'supergroup'):
        return
    if message.from_user.id not in get_admin_ids(context.bot, message.chat_id):
        return

    group = get_group(message.chat)
    group.words.all().delete()
    context.bot.send_message(
        chat_id=message.chat_id,
        reply_to_message_id=message.message_id,
        text='Erased all entries from local dictionary.'
    )
|
|
|
|
|
|
def start_bot(token):
    """Configure the module and start polling Telegram for updates.

    Binds the module-level ``Group``, ``GroupDictionaryEntry`` and ``models``
    names through function-level imports, creates one ``SpellChecker`` per
    entry of ``models.VALID_LANGUAGES_LST``, sets up logging, registers the
    command and message handlers and starts polling.

    Args:
        token: Bot token handed to ``Updater``.
    """
    global Group, GroupDictionaryEntry, logger, CHECKERS, models

    # Imported here rather than at module import time; the results replace
    # the module-level placeholders declared global above.
    from .models import Group
    from .models import GroupDictionaryEntry
    from . import models

    CHECKERS.update(
        (language, SpellChecker(language))
        for language in models.VALID_LANGUAGES_LST
    )

    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.DEBUG,
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    updater = Updater(token, use_context=True)
    dispatcher = updater.dispatcher

    # Registration order is preserved from the original statement sequence.
    command_table = (
        ('start', start),
        ('add_group_dictionary', cmd_agd),
        ('remove_group_dictionary', cmd_rgd),
        ('show_group_dictionary', cmd_sgd),
        ('export_group_dictionary', cmd_egd),
        ('clear_group_dictionary', cmd_cgd),
        ('help', send_help_message),
    )
    for command, callback in command_table:
        dispatcher.add_handler(CommandHandler(command, callback))
    dispatcher.add_handler(MessageHandler(Filters.all, handle_message))

    updater.start_polling()
|