diff --git a/src/fediblockhole/__init__.py b/src/fediblockhole/__init__.py index 9be1b78..67b1f06 100755 --- a/src/fediblockhole/__init__.py +++ b/src/fediblockhole/__init__.py @@ -59,16 +59,16 @@ def sync_blocklists(conf: argparse.Namespace): # Add extra export fields if defined in config export_fields.extend(conf.export_fields) - blocklists = {} + blocklists = [] # Fetch blocklists from URLs if not conf.no_fetch_url: - blocklists = fetch_from_urls(blocklists, conf.blocklist_url_sources, - import_fields, conf.save_intermediate, conf.savedir, export_fields) + blocklists.extend(fetch_from_urls(conf.blocklist_url_sources, + import_fields, conf.save_intermediate, conf.savedir, export_fields)) # Fetch blocklists from remote instances if not conf.no_fetch_instance: - blocklists = fetch_from_instances(blocklists, conf.blocklist_instance_sources, - import_fields, conf.save_intermediate, conf.savedir, export_fields) + blocklists.extend(fetch_from_instances(conf.blocklist_instance_sources, + import_fields, conf.save_intermediate, conf.savedir, export_fields)) # Merge blocklists into an update dict merged = merge_blocklists(blocklists, conf.mergeplan) @@ -80,48 +80,48 @@ def sync_blocklists(conf: argparse.Namespace): # Save the final mergelist, if requested if conf.blocklist_savefile: log.info(f"Saving merged blocklist to {conf.blocklist_savefile}") - save_blocklist_to_file(merged.values(), conf.blocklist_savefile, export_fields) + save_blocklist_to_file(merged, conf.blocklist_savefile, export_fields) # Push the blocklist to destination instances if not conf.no_push_instance: log.info("Pushing domain blocks to instances...") for dest in conf.blocklist_instance_destinations: - domain = dest['domain'] + target = dest['domain'] token = dest['token'] scheme = dest.get('scheme', 'https') max_followed_severity = BlockSeverity(dest.get('max_followed_severity', 'silence')) - push_blocklist(token, domain, merged.values(), conf.dryrun, import_fields, max_followed_severity, scheme) + push_blocklist(token, target, merged, conf.dryrun, import_fields, max_followed_severity, scheme) -def apply_allowlists(merged: dict, conf: argparse.Namespace, allowlists: dict): +def apply_allowlists(merged: Blocklist, conf: argparse.Namespace, allowlists: dict): """Apply allowlists """ # Apply allows specified on the commandline for domain in conf.allow_domains: log.info(f"'{domain}' allowed by commandline, removing any blocks...") - if domain in merged: - del merged[domain] + if domain in merged.blocks: + del merged.blocks[domain] # Apply allows from URLs lists log.info("Removing domains from URL allowlists...") - for key, alist in allowlists.items(): - log.debug(f"Processing allows from '{key}'...") - for allowed in alist: + for alist in allowlists: + log.debug(f"Processing allows from '{alist.origin}'...") + for allowed in alist.blocks.values(): domain = allowed.domain log.debug(f"Removing allowlisted domain '{domain}' from merged list.") - if domain in merged: - del merged[domain] + if domain in merged.blocks: + del merged.blocks[domain] return merged -def fetch_allowlists(conf: argparse.Namespace) -> dict: +def fetch_allowlists(conf: argparse.Namespace) -> Blocklist: """ """ if conf.allowlist_url_sources: - allowlists = fetch_from_urls({}, conf.allowlist_url_sources, ALLOWLIST_IMPORT_FIELDS) + allowlists = fetch_from_urls(conf.allowlist_url_sources, ALLOWLIST_IMPORT_FIELDS, conf.save_intermediate, conf.savedir) return allowlists - return {} + return Blocklist() -def fetch_from_urls(blocklists: dict, url_sources: dict, +def fetch_from_urls(url_sources: dict, import_fields: list=IMPORT_FIELDS, save_intermediate: bool=False, savedir: str=None, export_fields: list=EXPORT_FIELDS) -> dict: @@ -131,7 +131,7 @@ def fetch_from_urls(blocklists: dict, url_sources: dict, @returns: A dict of blocklists, same as input, but (possibly) modified """ log.info("Fetching domain blocks from URLs...") - + blocklists = [] for item in url_sources: url = item['url'] # If import fields are provided, they override the global ones passed in @@ -144,14 +144,14 @@ def fetch_from_urls(blocklists: dict, url_sources: dict, listformat = item.get('format', 'csv') with urlr.urlopen(url) as fp: rawdata = fp.read(URL_BLOCKLIST_MAXSIZE).decode('utf-8') - blocklists[url] = parse_blocklist(rawdata, listformat, import_fields, max_severity) - - if save_intermediate: - save_intermediate_blocklist(blocklists[url], url, savedir, export_fields) + bl = parse_blocklist(rawdata, url, listformat, import_fields, max_severity) + blocklists.append(bl) + if save_intermediate: + save_intermediate_blocklist(bl, savedir, export_fields) return blocklists -def fetch_from_instances(blocklists: dict, sources: dict, +def fetch_from_instances(sources: dict, import_fields: list=IMPORT_FIELDS, save_intermediate: bool=False, savedir: str=None, export_fields: list=EXPORT_FIELDS) -> dict: @@ -161,12 +161,13 @@ def fetch_from_instances(blocklists: dict, sources: dict, @returns: A dict of blocklists, same as input, but (possibly) modified """ log.info("Fetching domain blocks from instances...") + blocklists = [] for item in sources: domain = item['domain'] admin = item.get('admin', False) token = item.get('token', None) scheme = item.get('scheme', 'https') - itemsrc = f"{scheme}://{domain}/api" + # itemsrc = f"{scheme}://{domain}/api" # If import fields are provided, they override the global ones passed in source_import_fields = item.get('import_fields', None) @@ -174,15 +175,15 @@ def fetch_from_instances(blocklists: dict, sources: dict, # Ensure we always use the default fields import_fields = IMPORT_FIELDS.extend(source_import_fields) - # Add the blocklist with the domain as the source key - blocklists[itemsrc] = fetch_instance_blocklist(domain, token, admin, import_fields, scheme) + bl = fetch_instance_blocklist(domain, token, admin, import_fields, scheme) + blocklists.append(bl) if save_intermediate: - save_intermediate_blocklist(blocklists[itemsrc], domain, savedir, export_fields) + save_intermediate_blocklist(bl, savedir, export_fields) return blocklists def merge_blocklists(blocklists: list[Blocklist], mergeplan: str='max', threshold: int=0, - threshold_type: str='count') -> dict: + threshold_type: str='count') -> Blocklist: """Merge fetched remote blocklists into a bulk update @param blocklists: A dict of lists of DomainBlocks, keyed by source. Each value is a list of DomainBlocks @@ -199,7 +200,7 @@ def merge_blocklists(blocklists: list[Blocklist], mergeplan: str='max', count_of_mentions / number_of_blocklists. @param returns: A dict of DomainBlocks keyed by domain """ - merged = {} + merged = Blocklist('fediblockhole.merge_blocklists') num_blocklists = len(blocklists) @@ -209,7 +210,7 @@ def merge_blocklists(blocklists: list[Blocklist], mergeplan: str='max', for bl in blocklists: for block in bl.values(): if '*' in block.domain: - log.debug(f"Domain '{domain}' is obfuscated. Skipping it.") + log.debug(f"Domain '{block.domain}' is obfuscated. Skipping it.") continue elif block.domain in domain_blocks: domain_blocks[block.domain].append(block) @@ -224,40 +225,17 @@ def merge_blocklists(blocklists: list[Blocklist], mergeplan: str='max', domain_threshold_level = len(domain_blocks[domain]) / num_blocklists else: raise ValueError(f"Unsupported threshold type '{threshold_type}'. Supported values are: 'count', 'pct'") - + if domain_threshold_level >= threshold: # Add first block in the list to merged - merged[domain] = domain_blocks[domain][0] + block = domain_blocks[domain][0] # Merge the others with this record - for block in domain_blocks[domain][1:]: - merged[domain] = apply_mergeplan(merged[domain], block, mergeplan) - + for newblock in domain_blocks[domain][1:]: + block = apply_mergeplan(block, newblock, mergeplan) + merged.blocks[block.domain] = block + return merged - # for key, blist in blocklists.items(): - # log.debug(f"processing blocklist from: {key} ...") - # for newblock in blist: - # domain = newblock.domain - # # If the domain has two asterisks in it, it's obfuscated - # # and we can't really use it, so skip it and do the next one - # if '*' in domain: - # log.debug(f"Domain '{domain}' is obfuscated. Skipping it.") - # continue - - # elif domain in merged: - # log.debug(f"Overlapping block for domain {domain}. Merging...") - # blockdata = apply_mergeplan(merged[domain], newblock, mergeplan) - - # else: - # # New block - # blockdata = newblock - - # # end if - # log.debug(f"blockdata is: {blockdata}") - # merged[domain] = blockdata - # # end for - # return merged - def apply_mergeplan(oldblock: DomainBlock, newblock: DomainBlock, mergeplan: str='max') -> dict: """Use a mergeplan to decide how to merge two overlapping block definitions @@ -282,10 +260,10 @@ def apply_mergeplan(oldblock: DomainBlock, newblock: DomainBlock, mergeplan: str # How do we override an earlier block definition? if mergeplan in ['max', None]: # Use the highest block level found (the default) - log.debug(f"Using 'max' mergeplan.") + # log.debug(f"Using 'max' mergeplan.") if newblock.severity > oldblock.severity: - log.debug(f"New block severity is higher. Using that.") + # log.debug(f"New block severity is higher. Using that.") blockdata['severity'] = newblock.severity # For 'reject_media', 'reject_reports', and 'obfuscate' if @@ -314,7 +292,7 @@ def apply_mergeplan(oldblock: DomainBlock, newblock: DomainBlock, mergeplan: str else: raise NotImplementedError(f"Mergeplan '{mergeplan}' not implemented.") - log.debug(f"Block severity set to {blockdata['severity']}") + # log.debug(f"Block severity set to {blockdata['severity']}") return DomainBlock(**blockdata) @@ -396,17 +374,19 @@ def fetch_instance_blocklist(host: str, token: str=None, admin: bool=False, url = f"{scheme}://{host}{api_path}" - blocklist = [] + blockdata = [] link = True - while link: response = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT) if response.status_code != 200: log.error(f"Cannot fetch remote blocklist: {response.content}") raise ValueError("Unable to fetch domain block list: %s", response) - blocklist.extend( parse_blocklist(response.content, parse_format, import_fields) ) - + # Each block of returned data is a JSON list of dicts + # so we parse them and append them to the fetched list + # of JSON data we need to parse. + + blockdata.extend(json.loads(response.content.decode('utf-8'))) # Parse the link header to find the next url to fetch # This is a weird and janky way of doing pagination but # hey nothing we can do about it we just have to deal @@ -424,6 +404,8 @@ def fetch_instance_blocklist(host: str, token: str=None, admin: bool=False, urlstring, rel = next.split('; ') url = urlstring.strip('<').rstrip('>') + blocklist = parse_blocklist(blockdata, url, parse_format, import_fields) + return blocklist def delete_block(token: str, host: str, id: int, scheme: str='https'): @@ -513,13 +495,9 @@ def update_known_block(token: str, host: str, block: DomainBlock, scheme: str='h """Update an existing domain block with information in blockdict""" api_path = "/api/v1/admin/domain_blocks/" - try: - id = block.id - blockdata = block._asdict() - del blockdata['id'] - except KeyError: - import pdb - pdb.set_trace() + id = block.id + blockdata = block._asdict() + del blockdata['id'] url = f"{scheme}://{host}{api_path}{id}" @@ -553,7 +531,7 @@ def add_block(token: str, host: str, blockdata: DomainBlock, scheme: str='https' raise ValueError(f"Something went wrong: {response.status_code}: {response.content}") -def push_blocklist(token: str, host: str, blocklist: list[dict], +def push_blocklist(token: str, host: str, blocklist: list[DomainBlock], dryrun: bool=False, import_fields: list=['domain', 'severity'], max_followed_severity:BlockSeverity=BlockSeverity('silence'), @@ -561,8 +539,7 @@ def push_blocklist(token: str, host: str, blocklist: list[dict], ): """Push a blocklist to a remote instance. - Merging the blocklist with the existing list the instance has, - updating existing entries if they exist. + Updates existing entries if they exist, creates new blocks if they don't. @param token: The Bearer token for OAUTH API authentication @param host: The instance host, FQDN or IP @@ -577,15 +554,16 @@ def push_blocklist(token: str, host: str, blocklist: list[dict], serverblocks = fetch_instance_blocklist(host, token, True, import_fields, scheme) # # Convert serverblocks to a dictionary keyed by domain name - knownblocks = {row.domain: row for row in serverblocks} + # knownblocks = {row.domain: row for row in serverblocks} - for newblock in blocklist: + for newblock in blocklist.values(): log.debug(f"Processing block: {newblock}") - oldblock = knownblocks.get(newblock.domain, None) - if oldblock: + if newblock.domain in serverblocks: log.debug(f"Block already exists for {newblock.domain}, checking for differences...") + oldblock = serverblocks[newblock.domain] + change_needed = is_change_needed(oldblock, newblock, import_fields) # Is the severity changing? @@ -644,15 +622,14 @@ def load_config(configfile: str): conf = toml.load(configfile) return conf -def save_intermediate_blocklist( - blocklist: list[dict], source: str, - filedir: str, +def save_intermediate_blocklist(blocklist: Blocklist, filedir: str, export_fields: list=['domain','severity']): """Save a local copy of a blocklist we've downloaded """ # Invent a filename based on the remote source # If the source was a URL, convert it to something less messy # If the source was a remote domain, just use the name of the domain + source = blocklist.origin log.debug(f"Saving intermediate blocklist from {source}") source = source.replace('/','-') filename = f"{source}.csv" @@ -660,7 +637,7 @@ def save_intermediate_blocklist( save_blocklist_to_file(blocklist, filepath, export_fields) def save_blocklist_to_file( - blocklist: list[DomainBlock], + blocklist: Blocklist, filepath: str, export_fields: list=['domain','severity']): """Save a blocklist we've downloaded from a remote source @@ -670,18 +647,22 @@ def save_blocklist_to_file( @param export_fields: Which fields to include in the export. """ try: - blocklist = sorted(blocklist, key=lambda x: x.domain) + sorted_list = sorted(blocklist.blocks.items()) except KeyError: log.error("Field 'domain' not found in blocklist.") - log.debug(f"blocklist is: {blocklist}") + log.debug(f"blocklist is: {sorted_list}") + except AttributeError: + log.error("Attribute error!") + import pdb + pdb.set_trace() log.debug(f"export fields: {export_fields}") with open(filepath, "w") as fp: writer = csv.DictWriter(fp, export_fields, extrasaction='ignore') writer.writeheader() - for item in blocklist: - writer.writerow(item._asdict()) + for key, value in sorted_list: + writer.writerow(value) def augment_args(args, tomldata: str=None): """Augment commandline arguments with config file parameters diff --git a/src/fediblockhole/blocklists.py b/src/fediblockhole/blocklists.py index f79f3d2..7a9e44f 100644 --- a/src/fediblockhole/blocklists.py +++ b/src/fediblockhole/blocklists.py @@ -41,7 +41,7 @@ class BlocklistParser(object): """ Base class for parsing blocklists """ - preparse = False + do_preparse = False def __init__(self, import_fields: list=['domain', 'severity'], max_severity: str='suspend'): @@ -63,7 +63,7 @@ class BlocklistParser(object): @param blocklist: An Iterable of blocklist items @returns: A dict of DomainBlocks, keyed by domain """ - if self.preparse: + if self.do_preparse: blockdata = self.preparse(blockdata) parsed_list = Blocklist(origin) @@ -82,12 +82,13 @@ class BlocklistParser(object): class BlocklistParserJSON(BlocklistParser): """Parse a JSON formatted blocklist""" - preparse = True + do_preparse = True def preparse(self, blockdata) -> Iterable: - """Parse the blockdata as JSON - """ - return json.loads(blockdata) + """Parse the blockdata as JSON if needed""" + if type(blockdata) == type(''): + return json.loads(blockdata) + return blockdata def parse_item(self, blockitem: dict) -> DomainBlock: # Remove fields we don't want to import @@ -131,7 +132,7 @@ class BlocklistParserCSV(BlocklistParser): The parser expects the CSV data to include a header with the field names. """ - preparse = True + do_preparse = True def preparse(self, blockdata) -> Iterable: """Use a csv.DictReader to create an iterable from the blockdata @@ -237,6 +238,7 @@ def parse_blocklist( max_severity: str='suspend'): """Parse a blocklist in the given format """ - parser = FORMAT_PARSERS[format](import_fields, max_severity) log.debug(f"parsing {format} blocklist with import_fields: {import_fields}...") + + parser = FORMAT_PARSERS[format](import_fields, max_severity) return parser.parse_blocklist(blockdata, origin) \ No newline at end of file diff --git a/tests/test_allowlist.py b/tests/test_allowlist.py index 902b301..ddd53b9 100644 --- a/tests/test_allowlist.py +++ b/tests/test_allowlist.py @@ -4,6 +4,7 @@ import pytest from util import shim_argparse from fediblockhole.const import DomainBlock +from fediblockhole.blocklists import Blocklist from fediblockhole import fetch_allowlists, apply_allowlists def test_cmdline_allow_removes_domain(): @@ -11,17 +12,13 @@ def test_cmdline_allow_removes_domain(): """ conf = shim_argparse(['-A', 'removeme.org']) - merged = { + merged = Blocklist('test_allowlist.merged', { 'example.org': DomainBlock('example.org'), 'example2.org': DomainBlock('example2.org'), 'removeme.org': DomainBlock('removeme.org'), 'keepblockingme.org': DomainBlock('keepblockingme.org'), - } + }) - # allowlists = { - # 'testlist': [ DomainBlock('removeme.org', 'noop'), ] - # } - merged = apply_allowlists(merged, conf, {}) with pytest.raises(KeyError): @@ -32,16 +29,18 @@ def test_allowlist_removes_domain(): """ conf = shim_argparse() - merged = { + merged = Blocklist('test_allowlist.merged', { 'example.org': DomainBlock('example.org'), 'example2.org': DomainBlock('example2.org'), 'removeme.org': DomainBlock('removeme.org'), 'keepblockingme.org': DomainBlock('keepblockingme.org'), - } + }) - allowlists = { - 'testlist': [ DomainBlock('removeme.org', 'noop'), ] - } + allowlists = [ + Blocklist('test_allowlist', { + 'removeme.org': DomainBlock('removeme.org', 'noop'), + }) + ] merged = apply_allowlists(merged, conf, allowlists) @@ -53,19 +52,19 @@ def test_allowlist_removes_tld(): """ conf = shim_argparse() - merged = { + merged = Blocklist('test_allowlist.merged', { '.cf': DomainBlock('.cf'), 'example.org': DomainBlock('example.org'), '.tk': DomainBlock('.tk'), 'keepblockingme.org': DomainBlock('keepblockingme.org'), - } + }) - allowlists = { - 'list1': [ - DomainBlock('.cf', 'noop'), - DomainBlock('.tk', 'noop'), - ] - } + allowlists = [ + Blocklist('test_allowlist.list1', { + '.cf': DomainBlock('.cf', 'noop'), + '.tk': DomainBlock('.tk', 'noop'), + }) + ] merged = apply_allowlists(merged, conf, allowlists)