diff options
Diffstat (limited to 'upload-svr.go')
-rw-r--r-- | upload-svr.go | 450 |
1 files changed, 450 insertions, 0 deletions
diff --git a/upload-svr.go b/upload-svr.go new file mode 100644 index 0000000..efba0c9 --- /dev/null +++ b/upload-svr.go @@ -0,0 +1,450 @@ +// Accepts POST uploads at / +// Serves HEAD and GET requests for all other files. +// Serves HEAD and GET requests for other directories. +package main + +import ( + "bytes" + "crypto/sha1" + "crypto/sha256" + "crypto/tls" + "flag" + "fmt" + "html" + "io" + "log" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" +) + +const ( + blockSize = 1024 * 1024 + maxFilenameLength = 50 +) + +func isSlashRune(r rune) bool { return r == '/' || r == '\\' } + +func isForbiddenPath(path string) bool { + if !strings.Contains(path, "..") { + return false + } + for _, elm := range strings.FieldsFunc(path, isSlashRune) { + if elm == ".." { + return true + } + } + return false +} + +func uriToFilesystem(uri string) string { + return strings.Trim(filepath.Clean(uri), "/") +} + +func showIndex(w http.ResponseWriter, r *http.Request, msg string) { + if strings.Contains(r.UserAgent(), "curl") || + strings.Contains(r.UserAgent(), "Wget") { + w.Header().Set("Content-Type", "text/plain") + if msg != "" { + w.Write([]byte(msg)) + w.Write([]byte{'\n'}) + } else { + w.Write([]byte(curlUsage)) + } + return + } + + w.Write([]byte(uploadHtml)) + if msg != "" { + msg = strings.ReplaceAll(html.EscapeString(msg), "\n", "<br>\n") + errHtml := fmt.Sprintf("<p>%s</p>\n", msg) + w.Write([]byte(errHtml)) + } + dir := uriToFilesystem(r.URL.Path) + if dir == "" { + dir = "." + } + w.Write([]byte("<pre>\n")) + // TODO split presentation from traversal. + // TODO group directories first and sort alphabetically. + filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + log.Printf("Error walking %v: %v\n", path, err) + return nil + } + relpathFs, err := filepath.Rel(dir, path) + if err != nil { + log.Printf("Error creating relative path %v: %v\n", path, err) + return nil + } + relpath := strings.ReplaceAll(relpathFs, string(filepath.Separator), "/") + // If this is 'dir', then display it as '..' if we are a subdir. + if relpath == "." { + if dir == "." { + return nil + } + relpath = ".." + } + if info.IsDir() { + relpath += "/" + } + disppath := relpath + if len(relpath) > maxFilenameLength { + tail := "..>" + if info.IsDir() { + tail += "/" + } + disppath = disppath[:maxFilenameLength-len(tail)] + tail + } + pad := strings.Repeat(" ", maxFilenameLength-len(relpath)) + sz := "-" + if !info.IsDir() { + sz = strconv.FormatInt(info.Size(), 10) + } + // emulates nginx directory index layout + entryHtml := fmt.Sprintf(`<a href="%s">%s</a>%s %s %20s`+"\n", + html.EscapeString(relpath), + html.EscapeString(disppath), + pad, + info.ModTime().Format("_2-Jan-2006 15:04"), + sz) + w.Write([]byte(entryHtml)) + if path != dir && info.IsDir() { + //log.Printf("Skipping dir: %v\n", path) + return filepath.SkipDir + } + return nil + }) + w.Write([]byte("</pre>\n")) +} + +// Reads from the files once and compare the results. Returns whether all bytes +// are the same and the number of bytes from the second reader. If EOF is +// reached, then "true, 0" will be returned. +func readAndCompareOnce(a, b io.Reader, bufferA, bufferB []byte) (bool, int) { + nA, errA := a.Read(bufferA) + nB, errB := b.Read(bufferB) + if errA != nil && errA != io.EOF { + log.Println("Error reading first file:", errA) + return false, nB + } + if errB != nil && errB != io.EOF { + log.Println("Error reading second file:", errB) + return false, nB + } + if nA != nB || !bytes.Equal(bufferA[:nA], bufferB[:nB]) { + return false, nB + } + return true, nB +} + +// Advances both file handles by size (if possible) and returns true if that +// part is equal. +func isFileHeadEqual(a, b io.Reader, size int64) bool { + ar := io.LimitReader(a, size) + br := io.LimitReader(b, size) + bufferA := make([]byte, blockSize) + bufferB := make([]byte, blockSize) + for { + ok, n := readAndCompareOnce(ar, br, bufferA, bufferB) + if !ok { + return false + } + if n == 0 { + return true + } + } +} + +// Assuming that the initial 'offset' bytes from both files have been verified +// to be the same, continue reading and check whether they are the same. +// extraData is data that was consumed from 'newFile' and must be matched before +// new data can be read. +// Returns whether the file fully matches, the base offset and updated extra +// data. +func isDupe(baseFile, newFile io.Reader, offset int64, extraData []byte) (bool, int64, []byte) { + if len(extraData) > 0 { + if !isFileHeadEqual(baseFile, bytes.NewReader(extraData), int64(len(extraData))) { + return false, offset, extraData + } + offset += int64(len(extraData)) + extraData = extraData[:0] + } + bufferBase := make([]byte, cap(extraData)) + bufferNew := extraData[:cap(extraData)] + for { + ok, n := readAndCompareOnce(baseFile, newFile, bufferBase, bufferNew) + switch { + case !ok: + // not the same, remember read buffer. + return false, offset, extraData[:n] + case ok && n == 0: + // EOF - must be the same + return true, offset, extraData[:0] + case ok && n != 0: + // data was read and was found to be the same. Continue. + offset += int64(n) + } + } +} + +// Save the file using the given name as hint. If a file already exists and has +// the same contents, no new file will be created. Otherwise a new file will be +// created with the provided contents. Returns the actual filename and a status +// message. +func saveFile(userFilename string, r io.Reader) (string, string) { + saveName := path.Base(userFilename) + // basic sanity check. Incomplete for Windows (where names like "NUL" + // and "COM1" are special) - do not care for now. + if saveName == "." || saveName == ".." { + saveName = "file" + } + + // Handle large files in a more efficient way: keep track of "file with + // the same data", "number of bytes within said file that was the same", + // and "data that was read but different". + var baseFile *os.File + var baseSize int64 + newData := make([]byte, 0, blockSize) + defer baseFile.Close() + + // Read a bit once + n, err := r.Read(newData[:cap(newData)]) + if err != nil && err != io.EOF { + return userFilename, fmt.Sprintf("Error reading upload: %v", err.Error()) + } + if n == 0 { + return userFilename, "Uploaded file was empty - ignoring." + } + newData = newData[:n] + + for i := 0; ; i++ { + filename := saveName + if i > 0 { + filename = fmt.Sprintf("%s.%d", filename, i) + } + f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644) + defer f.Close() + if os.IsExist(err) { + if baseFile2, err := os.Open(filename); err == nil { + defer baseFile2.Close() + // If there is an existing base file, check if + // it starts with the same data. + if baseFile != nil { + if _, err := baseFile.Seek(0, os.SEEK_SET); err != nil { + // Seek failed, cannot compare. + continue + } + if !isFileHeadEqual(baseFile, baseFile2, baseSize) { + // different - not a dupe + continue + } + } + var dupe bool + var newSize int64 + dupe, newSize, newData = isDupe(baseFile2, r, baseSize, newData) + if baseSize != newSize { + baseFile = baseFile2 + baseFile2 = nil + baseSize = newSize + } + if dupe { + return filename, "Duplicate file" + } + } + // Duplicate filename. + continue + } + if err != nil { + return filename, fmt.Sprintf("Unable to save file: %v", err) + } + + // Ready to save the file. + if filename == userFilename { + log.Printf("Saving %v\n", filename) + } else { + log.Printf("Saving %v as %v\n", userFilename, filename) + } + var n int64 + if baseFile != nil { + if _, err := baseFile.Seek(0, os.SEEK_SET); err != nil { + // Seek failed, OOPS, cannot reconstruct file! + return filename, fmt.Sprintf("File data disappeared (seek): %v", err) + } + n, err = io.Copy(f, io.LimitReader(baseFile, baseSize)) + if err != nil { + return filename, fmt.Sprintf("File data disappeared (copy): %v", err) + } + if n != baseSize { + return filename, fmt.Sprintf("Wrote %d bytes instead of %d bytes", n, baseSize) + } + } + if len(newData) != 0 { + k, err := f.Write(newData) + if k != len(newData) { + return filename, fmt.Sprintf("Wrote %d+%d bytes instead of %d+%d: %v", n, k, n, len(newData), err) + } + n += int64(len(newData)) + } + m, err := io.Copy(f, r) + if err != nil { + return filename, fmt.Sprintf("Wrote %d bytes, %s", n+m, err) + } + return filename, fmt.Sprintf("OK. Wrote %d bytes", n+m) + } +} + +func handleUpload(w http.ResponseWriter, r *http.Request) { + mr, err := r.MultipartReader() + if err != nil { + log.Print(err) + showIndex(w, r, err.Error()) + return + } + var messages []string + for { + p, err := mr.NextPart() + if err == io.EOF { + break + } + if err != nil { + log.Printf("Upload error: %s\n", err) + messages = append(messages, err.Error()) + break + } + if p.FormName() != "file" { + continue + } + // Save file + saveName, msg := saveFile(p.FileName(), p) + displayName := saveName + if p.FileName() != saveName { + displayName = fmt.Sprintf("%s -> %s", p.FileName(), displayName) + } + msg = fmt.Sprintf("%s - %s", displayName, msg) + log.Println(msg) + messages = append(messages, msg) + } + if messages == nil { + showIndex(w, r, "No files selected") + } else { + showIndex(w, r, strings.Join(messages, "\n")) + } +} + +func serveFile(w http.ResponseWriter, r *http.Request) { + if isForbiddenPath(r.URL.Path) { + http.Error(w, "forbidden URL", http.StatusBadRequest) + return + } + f, err := os.Open(uriToFilesystem(r.URL.Path)) + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + log.Println(err) + return + } + defer f.Close() + http.ServeContent(w, r, r.URL.Path, time.Time{}, f) +} + +var ( + certFile string + keyFile string + address string +) + +func init() { + flag.StringVar(&certFile, "cert", "", "Certificate and optional key file") + flag.StringVar(&keyFile, "key", "", "Key file") + flag.StringVar(&address, "accept", ":1111", "Listen address") +} + +func main() { + flag.Parse() + if flag.NArg() != 0 { + fmt.Fprintln(os.Stderr, "Too many arguments") + os.Exit(2) + } + if certFile != "" && keyFile == "" { + keyFile = certFile + } + + http.HandleFunc("/favicon.ico", http.NotFound) + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + log.Printf("%s %s %s", r.RemoteAddr, r.Method, r.URL.Path) + if isForbiddenPath(r.URL.Path) { + http.Error(w, "forbidden URL", http.StatusBadRequest) + return + } + + isDir := r.URL.Path[len(r.URL.Path)-1] == '/' + if r.Method == "POST" { + // TODO relax this to other locations? + if r.URL.Path != "/" { + http.Error(w, "Uploads are only accepted at /", http.StatusMethodNotAllowed) + return + } + handleUpload(w, r) + } else if r.Method == "HEAD" || r.Method == "GET" { + if isDir { + showIndex(w, r, "") + return + } + serveFile(w, r) + } else { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } + }) + if certFile != "" { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatal(err) + } + server := &http.Server{ + Addr: address, + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{cert}, + }, + } + leafCert := server.TLSConfig.Certificates[0].Certificate[0] + sha256sum := sha256.Sum256(leafCert) + log.Println("SHA-256 Fingerprint:", formatFingerprint(sha256sum[:])) + sha1sum := sha1.Sum(leafCert) + log.Println("SHA1 Fingerprint: ", formatFingerprint(sha1sum[:])) + + log.Println("Listening at", server.Addr, "using TLS") + log.Fatal(server.ListenAndServeTLS("", "")) + } else { + log.Println("Listening at", address) + log.Fatal(http.ListenAndServe(address, nil)) + } +} + +func formatFingerprint(certData []byte) string { + const hextable = "0123456789ABCDEF" + out := make([]byte, len(certData)*3-1) + for i, v := range certData { + if i > 0 { + out[i*3-1] = ':' + } + out[i*3] = hextable[v>>4] + out[i*3+1] = hextable[v&0xf] + } + return string(out) +} + +const uploadHtml = ` +<!doctype html> +<meta charset="UTF-8"> +<meta name="viewport" content="initial-scale=1"> +<form action="" method="POST" enctype="multipart/form-data"> +<input type="file" name="file" multiple> +<input type="submit" value="Upload"> +</form> +<hr> +` +const curlUsage = "Usage: curl -Ffile=@input.txt ...\n" |