validate subscribe url type, #299

This commit is contained in:
Simon 2023-08-23 21:22:09 +07:00
parent 33ff586af4
commit a2eb42ebb9
No known key found for this signature in database
GPG Key ID: 2C15AA5E89985DD4
4 changed files with 22 additions and 9 deletions

View File

@ -331,7 +331,7 @@ class ChannelApiListView(ApiBaseView):
pending = [i["channel_id"] for i in to_add if i["channel_subscribed"]] pending = [i["channel_id"] for i in to_add if i["channel_subscribed"]]
url_str = " ".join(pending) url_str = " ".join(pending)
subscribe_to.delay(url_str) subscribe_to.delay(url_str, expected_type="channel")
return Response(data) return Response(data)

View File

@ -332,7 +332,7 @@ class SubscriptionHandler:
self.task = task self.task = task
self.to_subscribe = False self.to_subscribe = False
def subscribe(self): def subscribe(self, expected_type=False):
"""subscribe to url_str items""" """subscribe to url_str items"""
if self.task: if self.task:
self.task.send_progress(["Processing form content."]) self.task.send_progress(["Processing form content."])
@ -343,11 +343,16 @@ class SubscriptionHandler:
if self.task: if self.task:
self._notify(idx, item, total) self._notify(idx, item, total)
self.subscribe_type(item) self.subscribe_type(item, expected_type=expected_type)
def subscribe_type(self, item): def subscribe_type(self, item, expected_type):
"""process single item""" """process single item"""
if item["type"] == "playlist": if item["type"] == "playlist":
if expected_type and expected_type != "playlist":
raise TypeError(
f"expected {expected_type} url but got {item.get('type')}"
)
PlaylistSubscription().process_url_str([item]) PlaylistSubscription().process_url_str([item])
return return
@ -360,6 +365,11 @@ class SubscriptionHandler:
else: else:
raise ValueError("failed to subscribe to: " + item["url"]) raise ValueError("failed to subscribe to: " + item["url"])
if expected_type and expected_type != "channel":
raise TypeError(
f"expected {expected_type} url but got {item.get('type')}"
)
self._subscribe(channel_id) self._subscribe(channel_id)
def _subscribe(self, channel_id): def _subscribe(self, channel_id):

View File

@ -343,9 +343,12 @@ def re_sync_thumbs(self):
@shared_task(bind=True, name="subscribe_to", base=BaseTask) @shared_task(bind=True, name="subscribe_to", base=BaseTask)
def subscribe_to(self, url_str): def subscribe_to(self, url_str: str, expected_type: str | bool = False):
"""take a list of urls to subscribe to""" """
SubscriptionHandler(url_str, task=self).subscribe() take a list of urls to subscribe to
optionally validate expected_type channel / playlist
"""
SubscriptionHandler(url_str, task=self).subscribe(expected_type)
@shared_task(bind=True, name="index_playlists", base=BaseTask) @shared_task(bind=True, name="index_playlists", base=BaseTask)

View File

@ -736,7 +736,7 @@ class ChannelView(ArchivistResultsView):
if subscribe_form.is_valid(): if subscribe_form.is_valid():
url_str = request.POST.get("subscribe") url_str = request.POST.get("subscribe")
print(url_str) print(url_str)
subscribe_to.delay(url_str) subscribe_to.delay(url_str, expected_type="channel")
sleep(1) sleep(1)
return redirect("channel", permanent=True) return redirect("channel", permanent=True)
@ -879,7 +879,7 @@ class PlaylistView(ArchivistResultsView):
if subscribe_form.is_valid(): if subscribe_form.is_valid():
url_str = request.POST.get("subscribe") url_str = request.POST.get("subscribe")
print(url_str) print(url_str)
subscribe_to.delay(url_str) subscribe_to.delay(url_str, expected_type="playlist")
sleep(1) sleep(1)
return redirect("playlist") return redirect("playlist")