#!/usr/bin/env python3
# Arch Linux packages proxy
#
# Proxies requests, caching files with a ".pkg.tar.xz" suffix.
# If the cached file exists, serve that file.
# Otherwise, try to download the file as usual, optionally caching it under a
# ".download" extension. If that file exists and is not currently being
# downloaded, resume the download (Range requests).

import argparse
import fcntl
import http.server
import os
import socket
import traceback
from contextlib import closing, contextmanager
from datetime import datetime, timezone

import requests

DATE_FORMAT = '%a, %d %b %Y %H:%M:%S GMT'


def text_to_epoch(text):
    # HTTP dates are always GMT, so parse and format them as UTC.
    return datetime.strptime(text, DATE_FORMAT).replace(tzinfo=timezone.utc).timestamp()


def epoch_to_text(epoch):
    return datetime.fromtimestamp(epoch, timezone.utc).strftime(DATE_FORMAT)


class RequestHandler(http.server.BaseHTTPRequestHandler):
    def send_ok(self, size, headers=None, cached=False):
        self.log_message('"%s" %d %s %s', self.requestline, 200, size,
                         "HIT" if cached else "MISS")
        self.send_response_only(200)
        self.send_header('Content-Length', size)
        for k, v in (headers or {}).items():
            self.send_header(k, v)
        self.end_headers()

    def request_data(self, head_only=False, mtime_out=None):
        """Send response headers for the upstream file and yield its data.

        If the upstream provides a parsable Last-Modified header, its epoch
        value is stored in mtime_out[0].
        """
        method = "HEAD" if head_only else "GET"
        url = self.get_upstream_url()
        with closing(requests.request(method, url, stream=not head_only)) as r:
            if r.status_code != 200:
                self.log_request(r.status_code)
                self.send_response_only(r.status_code)
                self.end_headers()
                return
            response_headers = {}
            if 'Last-Modified' in r.headers:
                try:
                    mtime = text_to_epoch(r.headers['Last-Modified'])
                    response_headers['Last-Modified'] = epoch_to_text(mtime)
                    if mtime_out:
                        mtime_out[0] = mtime
                except ValueError:
                    self.log_error("Unable to parse Last-Modified header")
            self.send_ok(int(r.headers['Content-Length']), response_headers)
            if not head_only:
                yield from r.iter_content(4096)

    @contextmanager
    def open_write_cache(self, path):
        """Yield a writable temporary cache file, or None if caching is
        disabled (read-only mode) or the file cannot be created or locked."""
        if self.server.is_readonly:
            yield None
            return
        temp_path = path + ".download"
        try:
            with open(temp_path, 'wb') as f:
                # Advisory lock to prevent concurrent writers
                fcntl.lockf(f, fcntl.LOCK_EX | fcntl.LOCK_NB)
                yield f
        except OSError as e:
            self.log_error("Failed to create cache file %s: %s", temp_path, e)
            yield None

    def finish_cache(self, mtime):
        """Set the mtime of the finished download and move it into place."""
        path = self.get_local_path()
        temp_path = path + ".download"
        if mtime:
            os.utime(temp_path, times=(mtime, mtime))
        try:
            os.rename(temp_path, path)
        except OSError as e:
            self.log_error("Failed to rename %s: %s", temp_path, e)
            try:
                os.unlink(temp_path)
            except OSError as e:
                self.log_error("Failed to remove %s: %s", temp_path, e)

    def request_data_with_cache(self, head_only=False):
        if not self.is_cacheable():
            # Not cacheable, directly obtain data and bypass cache
            yield from self.request_data(head_only=head_only)
            return
        path = self.get_local_path()
        try:
            # Try to open the cached file and yield data from it
            self.send_ok(os.path.getsize(path),
                         {'Last-Modified': epoch_to_text(os.stat(path).st_mtime)},
                         cached=True)
            if not head_only:
                with open(path, 'rb') as f:
                    yield from f
        except FileNotFoundError:
            # File does not exist, so try to pipe upstream
            # (optionally writing to a cache file)
            mtime_pointer = [None]
            remote_data = self.request_data(head_only=head_only,
                                            mtime_out=mtime_pointer)
            if head_only:
                list(remote_data)  # consume yield and StopIteration
            if not head_only and remote_data:
                cache_ok = False
                with self.open_write_cache(path) as cache_file:
                    cache_ok = cache_file is not None
                    if cache_ok:
                        for chunk in remote_data:
                            cache_file.write(chunk)
                            yield chunk
                if cache_ok:
                    # Write was successful, now fix mtime and rename
                    self.finish_cache(mtime_pointer[0])
                else:
                    # Cache file unavailable, just pass all data through
                    yield from remote_data
    def do_GET(self):
        try:
            data = self.request_data_with_cache()
            for chunk in data:
                self.wfile.write(chunk)
        except Exception as e:
            self.log_error("GET %s failed: %s", self.path, e)
            traceback.print_exc()
            # Best-effort error status; headers may already have been sent.
            self.send_response(502)
            self.end_headers()

    def do_HEAD(self):
        try:
            list(self.request_data_with_cache(head_only=True))
        except Exception as e:
            self.log_error("HEAD %s failed: %s", self.path, e)
            traceback.print_exc()
            self.send_response(502)
            self.end_headers()

    def get_upstream_url(self):
        prefix = "http://mirror.nl.leaseweb.net/archlinux/"
        return prefix + self.path

    def get_local_path(self):
        filename = os.path.basename(self.path)
        return os.path.join(self.server.cachedir, filename)

    def is_cacheable(self):
        """Whether the requested file should be cached."""
        return self.path.endswith(".pkg.tar.xz")


class SomeServer(http.server.HTTPServer):
    def __init__(self, addr, handler, args):
        self.allow_reuse_address = True
        if ':' in addr[0]:
            self.address_family = socket.AF_INET6
        super().__init__(addr, handler)
        self.cachedir = args.cachedir
        self.is_readonly = args.readonly


parser = argparse.ArgumentParser()
parser.add_argument("--readonly", action="store_true")
parser.add_argument("--cachedir", default=os.getcwd())
parser.add_argument("--port", type=int, default=8001)

if __name__ == '__main__':
    args = parser.parse_args()
    addr = ('', args.port)
    server = SomeServer(addr, RequestHandler, args)
    server.serve_forever()
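
# Usage sketch (hypothetical: the script name, cache path, and pacman
# mirrorlist entry below are illustrative assumptions, not defined anywhere
# in this file):
#
#   ./arch-proxy.py --cachedir /var/cache/arch-proxy --port 8001
#
# and then point pacman at the proxy, for example in /etc/pacman.d/mirrorlist:
#
#   Server = http://localhost:8001/$repo/os/$arch
#
# With that in place, repeat requests for files ending in ".pkg.tar.xz" are
# served from --cachedir, while database files and all other paths are always
# fetched from the upstream mirror.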