diff --git a/pgbot b/pgbot index 40f9b82..67cd595 100755 --- a/pgbot +++ b/pgbot @@ -25,8 +25,6 @@ def init(args: list) -> tuple: 'bot_session', config['api_id'], config['api_hash']).start(bot_token=config['bot_token']) - # db_conn = pgbotlib.dbstuff.DBConn( - # f'dbname={config['db_name']} user={config['db_user']}') db_conn = pgbotlib.dbstuff.DBConn(config['db_spec']) return config, db_conn, client @@ -35,9 +33,10 @@ def init(args: list) -> tuple: def main(): config, db_conn, client = init(sys.argv[1:]) - responder = pgbotlib.response.Responder(config, client, db_conn) + namegen = pgbotlib.misc.NameGenerator(config, db_conn) + responder = pgbotlib.response.Responder(config, client, db_conn, namegen) commander = pgbotlib.commands.Commander(config, client, config['admins'], - db_conn, responder) + db_conn, namegen, responder) sched_thread = threading.Thread( target=pgbotlib.sched.spawn_scheduler, @@ -47,9 +46,6 @@ def main(): @client.on(telethon.events.NewMessage()) async def handle_new_message(event): - chat = await event.get_chat() - result = await client.get_messages(chat.id, ids=[event.message.reply_to.reply_to_msg_id]) - print(result) if event.message.text.startswith('/'): await commander.action(event) else: diff --git a/pgbotlib/api.py b/pgbotlib/api.py index 848f3c0..20ea1ec 100644 --- a/pgbotlib/api.py +++ b/pgbotlib/api.py @@ -3,6 +3,7 @@ import json import random import re +import typing import requests import bs4 @@ -27,16 +28,17 @@ class ApiWrapper: # this is the entry point for the api calls # if you add another api, make sure there is a match here - def call(self, api: str, data: str | None, message: str) -> str: - match api: - case 'img_url': return self.format_img(data) - case 'gif': return self.get_gif() - case 'kmp': return self.get_kmp() - case 'fga': return self.get_fga() - case 'fakenews': return self.get_fakenews() - case 'anek': return self.get_anek() - case 'y_search': return self.y_search(message) - case _: return self.FAILED + # this could have used match - case statement, but python 3.9 + def call(self, api: str, data: typing.Union[str, None], + message: str) -> str: + if api == 'img_url': return self.format_img(data) + elif api == 'gif': return self.get_gif() + elif api == 'kmp': return self.get_kmp() + elif api == 'fga': return self.get_fga() + elif api == 'fakenews': return self.get_fakenews() + elif api == 'anek': return self.get_anek() + elif api == 'y_search': return self.y_search(message) + return self.FAILED def __sanitize_search(self, message: str) -> str: """Removes one of each of the search tokens from the query diff --git a/pgbotlib/commands.py b/pgbotlib/commands.py index 360b938..e55b2aa 100644 --- a/pgbotlib/commands.py +++ b/pgbotlib/commands.py @@ -4,22 +4,31 @@ import telethon import pgbotlib.api import pgbotlib.dbstuff +import pgbotlib.misc import pgbotlib.response +# TODO: quote via response? +# chat = await event.get_chat() +# result = await client.get_messages(chat.id, ids=[event.message.reply_to.reply_to_msg_id]) +# print(result) class Commander: - T_START = frozenset(['start_cmd']) - T_STOP = frozenset(['stop_cmd']) + T_START = frozenset(['cmd_start']) + T_START_E = frozenset(['cmd_start_enabled']) + T_STOP = frozenset(['cmd_stop']) + T_STOP_D = frozenset(['cmd_stop_d']) def __init__(self, config: dict, client: telethon.TelegramClient, admins: list, db_conn: pgbotlib.dbstuff.DBConn, + namegen: pgbotlib.misc.NameGenerator, responder: pgbotlib.response.Responder) -> None: self.config = config self.client = client self.admins = admins self.db_conn = db_conn + self.namegen = namegen self.responder = responder self.available_tokens = [ str(token) for token, _ in self.responder.tokens] @@ -37,25 +46,56 @@ class Commander: values = (','.join(sorted(input_tokenset)), phrase.strip()) return self.db_conn.update(query, values) + def __add_user(self, caller: int, userspec: str) -> bool: + if caller not in self.admins: + print('fuck off!') + return None + user_id, names = userspec.strip().split(' ', 1) + for name in names.strip().split(','): + query = 'INSERT INTO names (tg_id, name) values(%s,%s)' + values = (user_id, name) + self.db_conn.update(query, values) + return True + + def __start_response(self) -> str: + if self.responder.enabled(): + return self.responder.get_response(self.T_START_E) + return self.responder.get_response(self.T_START) + + def __stop_response(self) -> str: + if self.responder.enabled(): + return self.responder.get_response(self.T_STOP) + return self.responder.get_response(self.T_STOP_D) + + def __list_users(self, users: list) -> str: + userlist = [f'{user.id}: {self.namegen.get_tg_name(user)}' + for user in users] + return '\n'.join(userlist) + async def action(self, event: telethon.events.common.EventBuilder) -> None: command = event.message.text sender = await event.get_sender() response = None - match command: - case command if command.startswith('/add '): - if self.__add_entry(sender.id, command[5:]): - response = 'success' - else: - response = 'failure' - case '/list': - response = ', '.join(self.available_tokens) - case '/start': - self.responder.enable() - response = self.responder.get_response(self.T_START) - case '/stop': - self.responder.disable() - response = self.responder.get_response(self.T_STOP) + if command.startswith('/add '): + if self.__add_entry(sender.id, command[5:]): + response = 'success' + else: + response = 'failure' + elif command.startswith('/adduser '): + self.__add_user(sender.id, command[9:]) + elif command == '/list': + response = ', '.join(self.available_tokens) + elif command == '/users': + users = await self.client.get_participants( + entity=event.message.peer_id) + response = self.__list_users(users) + elif command == '/start': + response = self.__start_response() + self.responder.enable() + elif command == '/stop': + response = self.__stop_response() + self.responder.disable() if response: await self.client.send_message(event.message.peer_id, response) return None diff --git a/pgbotlib/misc.py b/pgbotlib/misc.py index cc51a94..47eda4a 100644 --- a/pgbotlib/misc.py +++ b/pgbotlib/misc.py @@ -1,6 +1,5 @@ import telethon import pgbotlib.dbstuff -import pgbotlib.response class NameGenerator: diff --git a/pgbotlib/response.py b/pgbotlib/response.py index a46559f..df0960b 100644 --- a/pgbotlib/response.py +++ b/pgbotlib/response.py @@ -4,6 +4,7 @@ import telethon import yaml import pgbotlib.api import pgbotlib.dbstuff +import pgbotlib.misc def get_token(token_name: str, token_regex: list) -> tuple: @@ -22,10 +23,11 @@ def get_tokens(path: str) -> list: class Responder: def __init__(self, config: dict, client: telethon.TelegramClient, - db_connection: pgbotlib.dbstuff.DBConn) -> None: + db_connection: pgbotlib.dbstuff.DBConn, + namegen: pgbotlib.misc.NameGenerator) -> None: # apiregex matches "{apiname}optional data" # message itself is also passed to the api call method - self.started = True + self.enabled = True self.apiregex = re.compile(r'^\{(\w+)\}(.+)?$') self.namegen = pgbotlib.misc.NameGenerator(config, db_connection) self.tokens = get_tokens(config['response_tokens']) @@ -54,10 +56,13 @@ class Responder: "SELECT response FROM responses WHERE tokens = %s", (key,)) def enable(self) -> None: - self.started = True + self.enabled = True def disable(self) -> None: - self.started = False + self.enabled = False + + def enabled(self) -> bool: + return self.enabled def get_response(self, tokens: frozenset) -> str: counter = 0 @@ -93,7 +98,7 @@ class Responder: async def respond(self, event: telethon.events.common.EventBuilder) -> None: - if not self.started: + if not self.enabled: return None message = event.message.text.lower() tokens = self.__tokenize(message)