#!/usr/bin/env python3
# Arch Linux packages proxy
#
# Proxies requests, caching files with a ".pkg.tar.*" suffix.
# If the cached file exists, serve that file.
# Otherwise, try to download the file (as normal) and optionally cache it with
# a ".download" extension. If the file exists and is not currently being
# downloaded, resume the download (Range requests).
#
# See the end of this file for an example invocation and client configuration.
import argparse
import http.server
import os
import re
import socket
import traceback
from datetime import datetime
import requests
from contextlib import closing, contextmanager
import fcntl

DATE_FORMAT = '%a, %d %b %Y %H:%M:%S GMT'


def text_to_epoch(text):
    return datetime.strptime(text, DATE_FORMAT).timestamp()


def epoch_to_text(epoch):
    return datetime.fromtimestamp(epoch).strftime(DATE_FORMAT)


class BadRequest(Exception):
    pass


class RequestHandler(http.server.BaseHTTPRequestHandler):
    def send_ok(self, size, headers=None, upstream=None, range_offset=None):
        # Copy the headers so the caller's dict (or a shared default) is never
        # mutated when a Content-Range header is added below.
        headers = dict(headers) if headers else {}
        if range_offset is None:
            code = 200
        else:
            if size is None:
                raise BadRequest("Content-Length missing in upstream response"
                                 " for range request")
            if range_offset >= size:
                # TODO need different status code
                raise BadRequest("416 Requested Range Not Satisfiable")
            code = 206
            content_range = "bytes %d-%d/%d" % (range_offset, size - 1, size)
            headers["Content-Range"] = content_range
            size -= range_offset
        self.log_message('"%s" %d %s %s', self.requestline, code, size,
                         "HIT" if upstream is None else "MISS:%s" % (upstream,))
        self.send_response_only(code)
        if size is not None:
            self.send_header('Content-Length', size)
        for k, v in headers.items():
            self.send_header(k, v)
        self.end_headers()

    def request_data(self, head_only=False, mtime_out=None, range_offset=None):
        """
        Retrieves the full response body. The given "range_offset" serves only
        as a hint for the response to the client; it is not used in the
        upstream request.
        """
        method = "HEAD" if head_only else "GET"
        streamable = not head_only
        status_code = None
        urls = list(self.get_upstream_urls())
        # Try each upstream. If one fails, log it and try another. On success,
        # return the response data. If all upstreams fail, fail the request.
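        # With stream=True the body is not buffered here; it is relayed to the
        # client in 4096-byte chunks by process_upstream_response() below.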
        for i, url in enumerate(urls):
            with closing(requests.request(method, url, stream=streamable)) as r:
                status_code = r.status_code
                if status_code == 200:
                    yield from self.process_upstream_response(r, head_only,
                                                              mtime_out, i,
                                                              range_offset)
                    return
            self.log_message('"%s" %d - SKIP:%d', self.requestline,
                             status_code, i)
        self.log_request(status_code)
        self.send_response_only(status_code)
        self.end_headers()

    def process_upstream_response(self, r, head_only, mtime_out, upstream,
                                  range_offset):
        if r:
            response_headers = {}
            if 'Last-Modified' in r.headers:
                try:
                    mtime = text_to_epoch(r.headers['Last-Modified'])
                    response_headers['Last-Modified'] = epoch_to_text(mtime)
                    if mtime_out:
                        mtime_out[0] = mtime
                except ValueError:
                    self.log_error("Unable to parse Last-Modified header")
            if 'Content-Length' in r.headers:
                self.send_ok(int(r.headers['Content-Length']),
                             response_headers, upstream=upstream,
                             range_offset=range_offset)
            else:
                self.send_ok(None, response_headers, upstream=upstream,
                             range_offset=range_offset)
            if not head_only:
                yield from r.iter_content(4096)

    @contextmanager
    def open_write_cache(self, path):
        if self.server.is_readonly:
            yield None
            return
        temp_path = path + ".download"
        try:
            with open(temp_path, 'wb') as f:
                # Prevent concurrent writers
                fcntl.lockf(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
                yield f
        except OSError as e:
            self.log_error("Failed to create cache file %s: %s", temp_path, e)
            yield None

    def finish_cache(self, mtime):
        path = self.get_local_path()
        temp_path = path + ".download"
        if mtime:
            os.utime(temp_path, times=(mtime, mtime))
        try:
            os.rename(temp_path, path)
        except OSError as e:
            self.log_error("Failed to rename %s", temp_path)
            try:
                os.unlink(temp_path)
            except OSError as e:
                self.log_error("Failed to remove %s", temp_path)

    def parse_range(self):
        value = self.headers.get('Range')
        if value is not None:
            # Only support "continue" range requests, resuming a previous
            # download. Anything more complex is not needed at the moment.
            m = re.match(r'bytes=(?P<from>\d+)-$', value)
            if not m:
                raise BadRequest("Unsupported range request: %s" % value)
            return int(m.group("from"))

    @staticmethod
    def skip_range_chunk(chunk, skip_bytes):
        if skip_bytes:
            chunksize = len(chunk)
            if chunksize > skip_bytes:
                chunk = chunk[skip_bytes:]
                skip_bytes = None
            else:
                chunk = b''
                skip_bytes -= chunksize
        return chunk, (skip_bytes or None)

    @classmethod
    def skip_range(cls, it, skip_bytes):
        for chunk in it:
            chunk, skip_bytes = cls.skip_range_chunk(chunk, skip_bytes)
            if chunk:
                yield chunk

    def request_data_with_cache(self, head_only=False):
        range_offset = self.parse_range()
        if not self.is_cacheable():
            # Not cacheable, directly obtain data and bypass the cache
            remote_data = self.request_data(range_offset=range_offset)
            yield from self.skip_range(remote_data, range_offset)
            return
        path = self.get_local_path()
        try:
            # Try to open the cached file and yield data from it
            stat_info = os.stat(path)
            if stat_info.st_size == 0:
                # Treat empty files as missing.
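                # (an empty file can be left behind when every upstream failed
                # during an earlier download attempt, so refetch it instead of
                # serving a zero-byte response)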
                raise FileNotFoundError
            response_headers = {
                'Last-Modified': epoch_to_text(stat_info.st_mtime)
            }
            self.send_ok(stat_info.st_size, response_headers,
                         range_offset=range_offset)
            if not head_only:
                with open(path, 'rb') as f:
                    if range_offset:
                        f.seek(range_offset)
                    yield from f
        except FileNotFoundError:
            # File does not exist, so try to pipe upstream
            # (optionally writing to cache file)
            mtime_pointer = [None]
            remote_data = self.request_data(head_only=head_only,
                                            mtime_out=mtime_pointer,
                                            range_offset=range_offset)
            if head_only:
                list(remote_data)  # consume yield and StopIteration
            if not head_only and remote_data:
                cache_ok = False
                with self.open_write_cache(path) as cache_file:
                    cache_ok = cache_file is not None
                    if cache_ok:
                        # Overwrite the temporary cache file from begin to end,
                        # but do not include the first "range_offset" bytes in
                        # the response.
                        skip = range_offset
                        for chunk in remote_data:
                            cache_file.write(chunk)
                            chunk, skip = self.skip_range_chunk(chunk, skip)
                            if chunk:
                                yield chunk
                if cache_ok:
                    # Write was successful, now fix mtime and rename
                    self.finish_cache(mtime_pointer[0])
                else:
                    # Cache file unavailable, just pass all data
                    yield from self.skip_range(remote_data, range_offset)

    def do_GET(self):
        try:
            data = self.request_data_with_cache()
            if data:
                for chunk in data:
                    self.wfile.write(chunk)
        except (BrokenPipeError, ConnectionResetError):
            self.log_error("GET %s - (connection aborted)", self.path)
        except BadRequest as e:
            self.log_error("GET %s - Bad Request: %s", self.path, e)
            self.send_response(400)
        except Exception as e:
            self.log_error("GET %s failed: %s", self.path, e)
            traceback.print_exc()
            self.send_response(502)

    def do_HEAD(self):
        try:
            list(self.request_data_with_cache(True))
        except (BrokenPipeError, ConnectionResetError):
            self.log_error("HEAD %s - (connection aborted)", self.path)
        except BadRequest as e:
            self.log_error("HEAD %s - Bad Request: %s", self.path, e)
            self.send_response(400)
        except Exception as e:
            self.log_error("HEAD %s failed: %s", self.path, e)
            traceback.print_exc()
            self.send_response(502)

    def get_upstream_urls(self):
        # If an old version is requested, retrieve the databases from the
        # archive mirror and do not fall back to the regular mirrors.
        if self.server.archive_url and self.is_date_sensitive_request():
            yield self.server.archive_url + self.path
            return
        for prefix in self.server.mirrors:
            yield prefix + self.path

    def get_local_path(self):
        filename = os.path.basename(self.path)
        return os.path.join(self.server.cachedir, filename)

    def is_cacheable(self):
        """Whether the requested file should be cached."""
        # Support .pkg.tar.xz, .pkg.tar.zst, etc.
        basename = os.path.splitext(self.path)[0]
        return basename.endswith(".pkg.tar")

    def is_date_sensitive_request(self):
        """Whether the resource is ephemeral."""
        path = self.path
        if path.endswith(".sig"):
            path = path[:-4]
        suffixes = [".db", ".files", ".abs.tar.gz"]
        return any(path.endswith(suffix) for suffix in suffixes)


class SomeServer(http.server.HTTPServer):
    def __init__(self, addr, handler, args):
        self.allow_reuse_address = True
        if ':' in addr[0]:
            self.address_family = socket.AF_INET6
        super().__init__(addr, handler)
        self.cachedir = args.cachedir
        self.is_readonly = args.readonly
        self.mirrors = [m.rstrip('/') for m in args.mirrors]
        if not args.date:
            self.archive_url = None
        else:
            archive_mirror = "https://archive.archlinux.org/repos/"
            self.archive_url = archive_mirror + args.date
            self.mirrors.append(self.archive_url)

    def dump_config(self):
        yesno = lambda x: "yes" if x else "no"
        print("Listen address: %s:%s" % self.socket.getsockname()[:2])
        print("Cache directory: %s" % self.cachedir)
        print("Read-only cache: %s" % yesno(self.is_readonly))
        print("Using archive: %s" % yesno(self.archive_url))
        print("Mirrors:")
        for mirror in self.mirrors:
            print("  %s" % mirror)


def mirror_url(string):
    scheme = string.split(":", 1)[0]
    if scheme not in ("http", "https"):
        raise argparse.ArgumentTypeError("%s is not a valid URL" % string)
    return string.rstrip("/") + "/"


def parse_date(string):
    m = re.match(r'^(\d{4})([/-]?)(\d{2})\2(\d{2})$', string)
    if not m:
        raise argparse.ArgumentTypeError("%s is not a valid date" % string)
    y, _, m, d = m.groups()
    return "%s/%s/%s" % (y, m, d)


parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--readonly", action="store_true",
                    help="Do not write downloaded results to the cache directory")
parser.add_argument("--cachedir", default=os.getcwd(),
                    help="Cache directory")
parser.add_argument("--port", type=int, default=8001,
                    help="Listen port")
parser.add_argument("--date", type=parse_date,
                    help="Provide a repository snapshot from 'yyyy/mm/dd'")
parser.add_argument("--mirror", dest="mirrors", metavar='URL', nargs="+",
                    type=mirror_url,
                    default=["https://mirror.nl.leaseweb.net/archlinux"],
                    help="Mirror list")

if __name__ == '__main__':
    args = parser.parse_args()
    addr = ('', args.port)
    server = SomeServer(addr, RequestHandler, args)
    server.dump_config()
    server.serve_forever()
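
# Example usage (illustrative; the script name, cache path and port below are
# placeholders, adjust them to your setup):
#
#   ./arch-proxy.py --cachedir /var/cache/arch-proxy --port 8001
#
# Point pacman at the proxy by listing it first in /etc/pacman.d/mirrorlist:
#
#   Server = http://localhost:8001/$repo/os/$arch
#
# Requests are then forwarded to the configured --mirror URLs, and every
# downloaded ".pkg.tar.*" file is kept in --cachedir for later requests.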