Initial commit
This commit is contained in:
476
kaizenbot/utils.py
Normal file
476
kaizenbot/utils.py
Normal file
@ -0,0 +1,476 @@
|
||||
import asyncio
|
||||
import random
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from threading import Timer as _Timer
|
||||
from time import sleep
|
||||
|
||||
import discord
|
||||
from discord.ext import menus
|
||||
|
||||
from . import logger
|
||||
|
||||
|
||||
class AsyncTimer:
|
||||
|
||||
def __init__(self, start: float, callback, *args):
|
||||
self._callback = callback
|
||||
self._args = args
|
||||
self._start = start
|
||||
self._task = asyncio.ensure_future(self._job())
|
||||
|
||||
async def _job(self):
|
||||
await asyncio.sleep(self._start)
|
||||
await self._callback(*self._args)
|
||||
|
||||
def cancel(self):
|
||||
self._task.cancel()
|
||||
|
||||
|
||||
class AsyncIntervalTimer(AsyncTimer):
|
||||
|
||||
def __init__(self, first_start: float, interval: float, callback, *args):
|
||||
super().__init__(first_start, callback, *args)
|
||||
self._interval = interval
|
||||
|
||||
async def _job(self):
|
||||
await super()._job()
|
||||
while True:
|
||||
await asyncio.sleep(self._interval)
|
||||
await self._callback(*self._args)
|
||||
|
||||
def cancel(self):
|
||||
self._task.cancel()
|
||||
|
||||
|
||||
class IntervalTimer:
|
||||
def __init__(self, first_start: float, interval: float, func, *args):
|
||||
self.first_start = first_start
|
||||
self.interval = interval
|
||||
self.handlerFunction = func
|
||||
self.args = args
|
||||
self.running = False
|
||||
self.timer = _Timer(self.interval, self.run, args)
|
||||
|
||||
def run(self, *args):
|
||||
sleep(self.first_start)
|
||||
self.handlerFunction(*args)
|
||||
while self.running:
|
||||
sleep(self.interval)
|
||||
self.handlerFunction(*args)
|
||||
|
||||
def start(self):
|
||||
self.running = True
|
||||
self.timer.start()
|
||||
|
||||
def cancel(self):
|
||||
self.running = False
|
||||
pass
|
||||
|
||||
|
||||
class Embeds:
|
||||
|
||||
@staticmethod
|
||||
async def send_kaizen_infos(channel):
|
||||
file = discord.File(Path.cwd().joinpath('assets', 'kaizen-round.png'))
|
||||
|
||||
embed = discord.Embed(title='**Kaizen**', description='Folge Kaizen auf den folgenden Kanälen, um nichts mehr zu verpassen!', color=discord.Color(0xff0000))
|
||||
embed.set_thumbnail(url='attachment://kaizen-round.png')
|
||||
embed.add_field(name='**🎥Youtube Hauptkanal**', value='Abonniere Kaizen auf __**[Youtube](https://www.youtube.com/c/KaizenAnime)**__ um kein Anime Video mehr zu verpassen!', inline=False)
|
||||
embed.add_field(name='**📑Youtube Toplisten-Kanal**', value='Abonniere Kaizen\'s __**[Toplisten-Kanal](https://www.youtube.com/channel/UCoijG8JqKb1rRZofx5b-LCw)**__ um kein Toplisten-Video mehr zu verpassen!', inline=False)
|
||||
embed.add_field(name='**📯Youtube Stream-Clips & mehr**', value='Abonniere Kaizen\'s __**[Youtube Kanal](https://www.youtube.com/channel/UCodeTj8SJ-5HhJgC_Elr1Dw)**__ für Stream-Clips & mehr!', inline=False)
|
||||
embed.add_field(name='**📲Twitch**', value='Folge Kaizen auf __**[Twitch](https://www.twitch.tv/kaizenanime)**__ und verpasse keinen Stream mehr! '
|
||||
'Subbe Kaizen um eine exklusive Rolle auf dem Discord Server zu bekommen!', inline=False)
|
||||
embed.add_field(name='**📢Twitter**', value='Folge Kaizen auf __**[Twitter](https://twitter.com/Kaizen_Anime)**__ um aktuelle Informationen zu bekommen und in Videos / Streams mitzuwirken!', inline=False)
|
||||
embed.add_field(name='**📷Instagram**', value='Folge Kaizen auf __**[Instagram](https://www.instagram.com/kaizen.animeyt/)**__!', inline=False)
|
||||
await channel.send(embed=embed, file=file)
|
||||
|
||||
@staticmethod
|
||||
def error_embed(title: typing.Union[str, None] = None, description: typing.Union[str, None] = None) -> discord.Embed:
|
||||
embed = discord.Embed(color=discord.Color(0xff0000))
|
||||
if title:
|
||||
embed.title = title
|
||||
if description:
|
||||
embed.description = description
|
||||
return embed
|
||||
|
||||
@staticmethod
|
||||
def warn_embed(title: typing.Union[str, None] = None, description: typing.Union[str, None] = None) -> discord.Embed:
|
||||
embed = discord.Embed(color=discord.Color(0xff9055))
|
||||
if title:
|
||||
embed.title = title
|
||||
if description:
|
||||
embed.description = description
|
||||
return embed
|
||||
|
||||
@staticmethod
|
||||
def success_embed(title: typing.Union[str, None] = None, description: typing.Union[str, None] = None) -> discord.Embed:
|
||||
embed = discord.Embed(color=discord.Color(0x00ff00))
|
||||
if title:
|
||||
embed.title = title
|
||||
if description:
|
||||
embed.description = description
|
||||
return embed
|
||||
|
||||
|
||||
class MenuListPageSource(menus.ListPageSource):
|
||||
|
||||
def __init__(self, data):
|
||||
super().__init__(data, per_page=1)
|
||||
|
||||
async def format_page(self, menu, embeds):
|
||||
return embeds
|
||||
|
||||
|
||||
def random_sequence_not_in_string(string: str):
|
||||
sequence = '+'
|
||||
while sequence in string:
|
||||
choice = random.choice('+*~-:%&')
|
||||
sequence = choice + sequence + choice
|
||||
|
||||
return sequence
|
||||
|
||||
|
||||
def role_names(member: discord.Member) -> typing.List[str]:
|
||||
return [role.name for role in member.roles]
|
||||
|
||||
|
||||
# ADDED AFTERWARDS: I've stol- copied the following code from a tweepy (https://github.com/tweepy/tweepy) PR or gist (from where exactly I do not know anymore lul)
|
||||
# at the time when they didn't support async actions
|
||||
|
||||
# Tweepy
|
||||
# Copyright 2009-2021 Joshua Roesslein
|
||||
# See LICENSE for details.
|
||||
|
||||
import json
|
||||
from math import inf
|
||||
from platform import python_version
|
||||
|
||||
import aiohttp
|
||||
from oauthlib.oauth1 import Client as OAuthClient
|
||||
from yarl import URL
|
||||
|
||||
import tweepy
|
||||
from tweepy.error import TweepError
|
||||
from tweepy.models import Status
|
||||
|
||||
|
||||
class AsyncStream:
|
||||
"""Stream realtime Tweets asynchronously
|
||||
Parameters
|
||||
----------
|
||||
consumer_key: :class:`str`
|
||||
Consumer key
|
||||
consumer_secret: :class:`str`
|
||||
Consuemr secret
|
||||
access_token: :class:`str`
|
||||
Access token
|
||||
access_token_secret: :class:`str`
|
||||
Access token secret
|
||||
max_retries: Optional[:class:`int`]
|
||||
Number of times to attempt to (re)connect the stream.
|
||||
Defaults to infinite.
|
||||
proxy: Optional[:class:`str`]
|
||||
Proxy URL
|
||||
"""
|
||||
|
||||
def __init__(self, consumer_key, consumer_secret, access_token,
|
||||
access_token_secret, max_retries=inf, proxy=None):
|
||||
self.consumer_key = consumer_key
|
||||
self.consumer_secret = consumer_secret
|
||||
self.access_token = access_token
|
||||
self.access_token_secret = access_token_secret
|
||||
self.max_retries = max_retries
|
||||
self.proxy = proxy
|
||||
|
||||
self.session = None
|
||||
self.task = None
|
||||
self.user_agent = (
|
||||
f"Python/{python_version()} "
|
||||
f"aiohttp/{aiohttp.__version__} "
|
||||
f"Tweepy/{tweepy.__version__}"
|
||||
)
|
||||
|
||||
async def _connect(self, method, endpoint, params={}, headers=None,
|
||||
body=None):
|
||||
error_count = 0
|
||||
# https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/connecting
|
||||
stall_timeout = 90
|
||||
network_error_wait = network_error_wait_step = 0.25
|
||||
network_error_wait_max = 16
|
||||
http_error_wait = http_error_wait_start = 5
|
||||
http_error_wait_max = 320
|
||||
http_420_error_wait_start = 60
|
||||
|
||||
oauth_client = OAuthClient(self.consumer_key, self.consumer_secret,
|
||||
self.access_token, self.access_token_secret)
|
||||
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers={"User-Agent": self.user_agent},
|
||||
timeout=aiohttp.ClientTimeout(sock_read=stall_timeout)
|
||||
)
|
||||
|
||||
url = f"https://stream.twitter.com/1.1/{endpoint}.json"
|
||||
url = str(URL(url).with_query(sorted(params.items())))
|
||||
|
||||
try:
|
||||
while error_count <= self.max_retries:
|
||||
request_url, request_headers, request_body = oauth_client.sign(
|
||||
url, method, body, headers
|
||||
)
|
||||
try:
|
||||
async with self.session.request(
|
||||
method, request_url, headers=request_headers,
|
||||
data=request_body, proxy=self.proxy
|
||||
) as resp:
|
||||
if resp.status == 200:
|
||||
error_count = 0
|
||||
http_error_wait = http_error_wait_start
|
||||
network_error_wait = network_error_wait_step
|
||||
|
||||
await self.on_connect()
|
||||
|
||||
async for line in resp.content:
|
||||
line = line.strip()
|
||||
if line:
|
||||
await self.on_data(line)
|
||||
else:
|
||||
await self.on_keep_alive()
|
||||
|
||||
await self.on_closed(resp)
|
||||
else:
|
||||
await self.on_request_error(resp.status)
|
||||
|
||||
error_count += 1
|
||||
|
||||
if resp.status == 420:
|
||||
if http_error_wait < http_420_error_wait_start:
|
||||
http_error_wait = http_420_error_wait_start
|
||||
|
||||
await asyncio.sleep(http_error_wait)
|
||||
|
||||
http_error_wait *= 2
|
||||
if resp.status != 420:
|
||||
if http_error_wait > http_error_wait_max:
|
||||
http_error_wait = http_error_wait_max
|
||||
except (aiohttp.ClientConnectionError,
|
||||
aiohttp.ClientPayloadError) as e:
|
||||
await self.on_connection_error()
|
||||
|
||||
await asyncio.sleep(network_error_wait)
|
||||
|
||||
network_error_wait += network_error_wait_step
|
||||
if network_error_wait > network_error_wait_max:
|
||||
network_error_wait = network_error_wait_max
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except Exception as e:
|
||||
await self.on_exception(e)
|
||||
finally:
|
||||
await self.session.close()
|
||||
await self.on_disconnect()
|
||||
|
||||
async def filter(self, follow=None, track=None, locations=None,
|
||||
stall_warnings=False):
|
||||
"""This method is a coroutine.
|
||||
Filter realtime Tweets
|
||||
https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/api-reference/post-statuses-filter
|
||||
Parameters
|
||||
----------
|
||||
follow: Optional[List[Union[:class:`int`, :class:`str`]]]
|
||||
A list of user IDs, indicating the users to return statuses for in
|
||||
the stream. See https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/basic-stream-parameters
|
||||
for more information.
|
||||
track: Optional[List[:class:`str`]]
|
||||
Keywords to track. Phrases of keywords are specified by a list. See
|
||||
https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
|
||||
for more information.
|
||||
locations: Optional[List[:class:`float`]]
|
||||
Specifies a set of bounding boxes to track. See
|
||||
https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
|
||||
for more information.
|
||||
stall_warnings: Optional[:class:`bool`]
|
||||
Specifies whether stall warnings should be delivered. See
|
||||
https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
|
||||
for more information. Def
|
||||
logger = logging.getLogger('kaizen')aults to False.
|
||||
Returns :class:`asyncio.Task`
|
||||
"""
|
||||
if self.task is not None and not self.task.done():
|
||||
raise TweepError("Stream is already connected")
|
||||
|
||||
endpoint = "statuses/filter"
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
body = {}
|
||||
if follow is not None:
|
||||
body["follow"] = ','.join(map(str, follow))
|
||||
if track is not None:
|
||||
body["track"] = ','.join(map(str, track))
|
||||
if locations is not None:
|
||||
if len(locations) % 4:
|
||||
raise TweepError(
|
||||
"Number of location coordinates should be a multiple of 4"
|
||||
)
|
||||
body["locations"] = ','.join(
|
||||
f"{location:.4f}" for location in locations
|
||||
)
|
||||
if stall_warnings:
|
||||
body["stall_warnings"] = "true"
|
||||
|
||||
self.task = asyncio.create_task(
|
||||
self._connect("POST", endpoint, headers=headers, body=body or None)
|
||||
)
|
||||
return self.task
|
||||
|
||||
async def sample(self, stall_warnings=False):
|
||||
"""This method is a coroutine.
|
||||
Sample realtime Tweets
|
||||
https://developer.twitter.com/en/docs/twitter-api/v1/tweets/sample-realtime/api-reference/get-statuses-sample
|
||||
Parameters
|
||||
----------
|
||||
stall_warnings: Optional[:class:`bool`]
|
||||
Specifies whether stall warnings should be delivered. See
|
||||
https://developer.twitter.com/en/docs/tweets/filter-realtime/guides/basic-stream-parameters
|
||||
for more information. Defaults to False.
|
||||
Returns :class:`asyncio.Task`
|
||||
"""
|
||||
if self.task is not None and not self.task.done():
|
||||
raise TweepError("Stream is already connected")
|
||||
|
||||
endpoint = "statuses/sample"
|
||||
|
||||
params = {}
|
||||
if stall_warnings:
|
||||
params["stall_warnings"] = "true"
|
||||
|
||||
self.task = asyncio.create_task(
|
||||
self._connect("GET", endpoint, params=params)
|
||||
)
|
||||
return self.task
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect the stream"""
|
||||
if self.task is not None:
|
||||
self.task.cancel()
|
||||
|
||||
async def on_closed(self, resp):
|
||||
"""This method is a coroutine.
|
||||
This is called when the stream has been closed by Twitter.
|
||||
"""
|
||||
logger.error("Stream connection closed by Twitter")
|
||||
|
||||
async def on_connect(self):
|
||||
"""This method is a coroutine.
|
||||
This is called after successfully connecting to the streaming API.
|
||||
"""
|
||||
# logger.info("Stream connected")
|
||||
|
||||
async def on_connection_error(self):
|
||||
"""This method is a coroutine.
|
||||
This is called when the stream connection errors or times out.
|
||||
"""
|
||||
# logger.error("Stream connection has errored or timed out")
|
||||
|
||||
async def on_disconnect(self):
|
||||
"""This method is a coroutine.
|
||||
This is called when the stream has disconnected.
|
||||
"""
|
||||
# logger.info("Stream disconnected")
|
||||
|
||||
async def on_exception(self, exception):
|
||||
"""This method is a coroutine.
|
||||
This is called when an unhandled exception occurs.
|
||||
"""
|
||||
logger.exception("Stream encountered an exception")
|
||||
|
||||
async def on_keep_alive(self):
|
||||
"""This method is a coroutine.
|
||||
This is called when a keep-alive message is received.
|
||||
"""
|
||||
#logger.debug("Received keep-alive message")
|
||||
|
||||
async def on_request_error(self, status_code):
|
||||
"""This method is a coroutine.
|
||||
This is called when a non-200 HTTP status code is encountered.
|
||||
"""
|
||||
# logger.error("Stream encountered HTTP Error: %d", status_code)
|
||||
|
||||
async def on_data(self, raw_data):
|
||||
"""This method is a coroutine.
|
||||
This is called when raw data is received from the stream.
|
||||
This method handles sending the data to other methods, depending on the
|
||||
message type.
|
||||
https://developer.twitter.com/en/docs/twitter-api/v1/tweets/filter-realtime/guides/streaming-message-types
|
||||
"""
|
||||
data = json.loads(raw_data)
|
||||
|
||||
if "in_reply_to_status_id" in data:
|
||||
status = Status.parse(None, data)
|
||||
return await self.on_status(status)
|
||||
if "delete" in data:
|
||||
delete = data["delete"]["status"]
|
||||
return await self.on_delete(delete["id"], delete["user_id"])
|
||||
if "disconnect" in data:
|
||||
return await self.on_disconnect_message(data["disconnect"])
|
||||
if "limit" in data:
|
||||
return await self.on_limit(data["limit"]["track"])
|
||||
if "scrub_geo" in data:
|
||||
return await self.on_scrub_geo(data["scrub_geo"])
|
||||
if "status_withheld" in data:
|
||||
return await self.on_status_withheld(data["status_withheld"])
|
||||
if "user_withheld" in data:
|
||||
return await self.on_user_withheld(data["user_withheld"])
|
||||
if "warning" in data:
|
||||
return await self.on_warning(data["warning"])
|
||||
|
||||
logger.warning("Received unknown message type: %s", raw_data)
|
||||
|
||||
async def on_status(self, status):
|
||||
"""This method is a coroutine.
|
||||
This is called when a status is received.
|
||||
"""
|
||||
# logger.debug("Received status: %d", status.id)
|
||||
|
||||
async def on_delete(self, status_id, user_id):
|
||||
"""This method is a coroutine.
|
||||
This is called when a status deletion notice is received.
|
||||
"""
|
||||
# logger.debug("Received status deletion notice: %d", status_id)
|
||||
|
||||
async def on_disconnect_message(self, message):
|
||||
"""This method is a coroutine.
|
||||
This is called when a disconnect message is received.
|
||||
"""
|
||||
# logger.warning("Received disconnect message: %s", message)
|
||||
|
||||
async def on_limit(self, track):
|
||||
"""This method is a coroutine.
|
||||
This is called when a limit notice is received.
|
||||
"""
|
||||
# logger.debug("Received limit notice: %d", track)
|
||||
|
||||
async def on_scrub_geo(self, notice):
|
||||
"""This method is a coroutine.
|
||||
This is called when a location deletion notice is received.
|
||||
"""
|
||||
# logger.debug("Received location deletion notice: %s", notice)
|
||||
|
||||
async def on_status_withheld(self, notice):
|
||||
"""This method is a coroutine.
|
||||
This is called when a status withheld content notice is received.
|
||||
"""
|
||||
# logger.debug("Received status withheld content notice: %s", notice)
|
||||
|
||||
async def on_user_withheld(self, notice):
|
||||
"""This method is a coroutine.
|
||||
This is called when a user withheld content notice is received.
|
||||
"""
|
||||
# logger.debug("Received user withheld content notice: %s", notice)
|
||||
|
||||
async def on_warning(self, notice):
|
||||
"""This method is a coroutine.
|
||||
This is called when a stall warning message is received.
|
||||
"""
|
||||
# logger.warning("Received stall warning: %s", notice)
|
Reference in New Issue
Block a user