
Fixed multiple bugs in the Scrapper and improved the speed of the reddit cog.

tags/v2.0.0
Roxie Gibson 5 years ago
commit 0760c39161
1 changed file with 42 additions and 27 deletions:
  1. roxbot/cogs/reddit.py (+42, -27)

roxbot/cogs/reddit.py

@@ -23,9 +23,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
+import time
 
 import random
 import fnmatch
 from html import unescape
 from bs4 import BeautifulSoup

@@ -44,34 +45,43 @@ class Scrapper:
 
     @staticmethod
     async def _imgur_removed(url):
-        page = await roxbot.http.get_page(url)
+        try:
+            page = await roxbot.http.get_page(url)
+        except UnicodeDecodeError:
+            return False # This is if it is an image with a weird url
         soup = BeautifulSoup(page, 'html.parser')
         if "404 Not Found" in soup.title.string:
             return True
+        try:
+            return bool("removed.png" in soup.img["src"])
+        except TypeError: # This should protect roxbot in case bs4 returns nothing.
+            return False
 
     async def imgur_get(self, url):
-        if url.split(".")[-1] in ("png", "jpg", "jpeg", "gif", "gifv"):
-            return url
+        extensions = ("png", "jpg", "jpeg", "gif", "gifv", "mp4", "webm", "webp")
+        for ext in extensions:
+            if fnmatch.fnmatch(url.split(".")[-1], ext+"*"):
+                return url
         else:
             if await self._imgur_removed(url):
                 return False
 
             if not roxbot.imgur_token:
                 return False
 
             base_endpoint = "https://api.imgur.com/3/"
             endpoint_album = base_endpoint + "album/{}/images.json".format(url.split("/")[-1])
             endpoint_image = base_endpoint + "image/{}.json".format(url.split("/")[-1])
 
-            resp = await roxbot.http.api_request(endpoint_image,
-                headers={"Authorization": "Client-ID {}".format(roxbot.imgur_token)})
-            if bool(resp["success"]) is True:
-                return resp["data"]["link"]
-            else:
-                resp = await roxbot.http.api_request(endpoint_album,headers={"Authorization": "Client-ID {}".format(roxbot.imgur_token)})
-                return resp["data"][0]["link"]
+            try:
+                resp = await roxbot.http.api_request(endpoint_image, headers={"Authorization": "Client-ID {}".format(roxbot.imgur_token)})
+                if bool(resp["success"]) is True:
+                    return resp["data"]["link"]
+                else:
+                    resp = await roxbot.http.api_request(endpoint_album, headers={"Authorization": "Client-ID {}".format(roxbot.imgur_token)})
+                    return resp["data"][0]["link"]
+            except TypeError as e:
+                raise e
 
     async def parse_url(self, url):
         if url.split(".")[-1] in ("png", "jpg", "jpeg", "gif", "gifv", "webm", "mp4", "webp"):
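The new extension check above swaps an exact tuple-membership test for fnmatch patterns, so direct links with trailing junk after the extension (a gifv?1 style suffix, for example) are still treated as images. A minimal standalone sketch of that matching logic; the helper name is hypothetical:

import fnmatch

EXTENSIONS = ("png", "jpg", "jpeg", "gif", "gifv", "mp4", "webm", "webp")

def looks_like_direct_media(url):
    # The last dot-separated chunk only has to *start* with a known
    # extension, so suffixes like "gifv?1" still count as direct links.
    last_chunk = url.split(".")[-1]
    return any(fnmatch.fnmatch(last_chunk, ext + "*") for ext in EXTENSIONS)

print(looks_like_direct_media("https://i.imgur.com/abc.jpg"))     # True
print(looks_like_direct_media("https://i.imgur.com/abc.gifv?1"))  # True; the old tuple check missed this
print(looks_like_direct_media("https://imgur.com/gallery/abc"))   # False; falls through to the API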
@@ -108,23 +118,24 @@ class Scrapper:
         if not self.post_cache.get(cache_id, False):
             self.post_cache[cache_id] = [("", "")]
 
-    async def random(self, posts, subreddit, cache_id, nsfw_allowed, loop_amount=20):
+    def add_to_cache(self, to_cache, cache_id):
+        self.post_cache[cache_id].append(to_cache)
+
+    def cache_clean_up(self, cache_id):
+        if len(self.post_cache[cache_id]) >= self.cache_limit:
+            self.post_cache[cache_id].pop(0)
+
+    async def random(self, posts, cache_id, nsfw_allowed, loop_amount=20):
         """Function to pick a random post of a given list of reddit posts. Using the internal cache.
         Returns:
             None for failing to get a url that could be posted.
             A dict with the key success and the value False for failing the NSFW check
            or the post dict if getting the post is successful
         """
-        # This returns False if the subreddit is set the NSFW and the channel isn't.
-        sub_search = await roxbot.http.api_request("https://reddit.com/subreddits/search.json?q={}".format(subreddit))
-        for listing in sub_search["data"]["children"]:
-            if listing["data"]["id"] == posts[0]["data"]["subreddit_id"].strip("t5_"):
-                if listing["data"]["over18"] and not nsfw_allowed:
-                    return False
-
         # Loop to get the post randomly and make sure it hasn't been posted before
         url = None
         choice = None
+
         for x in range(loop_amount):
             choice = random.choice(posts)
             url = await self.parse_url(choice["data"]["url"])
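add_to_cache and cache_clean_up factor the cache bookkeeping out of random() so each call site stays one line. A self-contained sketch of how the helpers fit together; the class name, the __init__, and the cache_limit value are assumptions, since the diff does not show them:

class PostCache:
    # Hypothetical standalone version of the Scrapper cache: one list of
    # (post_id, url) pairs per channel/guild id, trimmed FIFO-style.

    def __init__(self, cache_limit=10):  # the limit value is an assumption
        self.post_cache = {}
        self.cache_limit = cache_limit

    def setup_cache(self, cache_id):
        # Seed with a dummy pair so membership scans always have a list to walk.
        if not self.post_cache.get(cache_id, False):
            self.post_cache[cache_id] = [("", "")]

    def add_to_cache(self, to_cache, cache_id):
        self.post_cache[cache_id].append(to_cache)

    def cache_clean_up(self, cache_id):
        # Drop the oldest entry once the limit is reached.
        if len(self.post_cache[cache_id]) >= self.cache_limit:
            self.post_cache[cache_id].pop(0)

cache = PostCache(cache_limit=3)
cache.setup_cache(1234)
for post in [("id1", "u1"), ("id2", "u2"), ("id3", "u3")]:
    cache.add_to_cache(post, 1234)
    cache.cache_clean_up(1234)
print(cache.post_cache[1234])  # [('id2', 'u2'), ('id3', 'u3')] - oldest dropped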
@@ -134,10 +145,12 @@ class Scrapper:
                 url = False # Reject post and move to next loop
             else:
                 # Check cache for post
+                in_cache = False
                 for cache in self.post_cache[cache_id]:
                     if url in cache or choice["data"]["id"] in cache:
-                        continue
-                    break
+                        in_cache = True
+                if not in_cache:
+                    break
 
         # This is for either a False (NSFW post not allowed) or a None for none.
         if url is None:
@@ -145,11 +158,9 @@ class Scrapper:
         elif url is False:
             return {"success": False}
         # Cache post
-        post = (choice["data"]["id"], url)
-        self.post_cache[cache_id].append(post)
+        self.add_to_cache((choice["data"]["id"], url), cache_id)
         # If too many posts in cache, remove oldest value.
-        if len(self.post_cache[cache_id]) >= self.cache_limit:
-            self.post_cache[cache_id].pop(0)
+        self.cache_clean_up(cache_id)
         return choice["data"]
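The flag-based scan two hunks up is the headline bug fix: the old continue/break pair stopped comparing after the first cache entry, so repeat posts could slip through. A sketch of the corrected retry loop in isolation; names here are illustrative rather than taken from the cog:

import random

def pick_uncached(posts, cached_pairs, loop_amount=20):
    # Retry up to loop_amount times for a post whose id and url are both uncached.
    for _ in range(loop_amount):
        choice = random.choice(posts)
        in_cache = False
        for pair in cached_pairs:  # scan *every* cached (id, url) pair
            if choice["id"] in pair or choice["url"] in pair:
                in_cache = True
        if not in_cache:
            return choice
    return None  # every attempt hit the cache

posts = [{"id": "a1", "url": "u1"}, {"id": "b2", "url": "u2"}]
print(pick_uncached(posts, [("a1", "u1")]))  # the b2 post; a1 is cached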


@@ -167,6 +178,7 @@ class Reddit:
         Example:
             {command_prefix}subreddit pics
         """
+        start = time.time()
         subreddit = subreddit.lower()
         if isinstance(ctx.channel, discord.DMChannel):
             cache_id = ctx.author.id
@@ -183,7 +195,10 @@ class Reddit:
else:
nsfw_allowed = True

choice = await self.scrapper.random(posts["children"], subreddit, cache_id, nsfw_allowed)
choice = await self.scrapper.random(posts["children"], cache_id, nsfw_allowed)

end = time.time()
print("Time Taken: {}s".format(end-start))

if not choice:
return await ctx.send("I couldn't find any images from that subreddit.")
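The start/end pair brackets the whole command so the speed-up from dropping the per-invocation subreddits/search.json request can be checked; random() now takes nsfw_allowed directly instead of re-querying reddit. A small sketch of the same timing pattern around an async call; the sleep stands in for the real reddit and imgur requests:

import asyncio
import time

async def fake_subreddit_command():
    await asyncio.sleep(0.3)  # stand-in for the reddit and imgur round trips

async def main():
    start = time.time()
    await fake_subreddit_command()
    end = time.time()
    print("Time Taken: {}s".format(end - start))

asyncio.run(main())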
