summaryrefslogtreecommitdiff
path: root/upload-svr.go
diff options
context:
space:
mode:
Diffstat (limited to 'upload-svr.go')
-rw-r--r--upload-svr.go450
1 files changed, 450 insertions, 0 deletions
diff --git a/upload-svr.go b/upload-svr.go
new file mode 100644
index 0000000..efba0c9
--- /dev/null
+++ b/upload-svr.go
@@ -0,0 +1,450 @@
+// Accepts POST uploads at /
+// Serves HEAD and GET requests for all other files.
+// Serves HEAD and GET requests for other directories.
+package main
+
+import (
+ "bytes"
+ "crypto/sha1"
+ "crypto/sha256"
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "html"
+ "io"
+ "log"
+ "net/http"
+ "os"
+ "path"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ blockSize = 1024 * 1024
+ maxFilenameLength = 50
+)
+
+func isSlashRune(r rune) bool { return r == '/' || r == '\\' }
+
+func isForbiddenPath(path string) bool {
+ if !strings.Contains(path, "..") {
+ return false
+ }
+ for _, elm := range strings.FieldsFunc(path, isSlashRune) {
+ if elm == ".." {
+ return true
+ }
+ }
+ return false
+}
+
+func uriToFilesystem(uri string) string {
+ return strings.Trim(filepath.Clean(uri), "/")
+}
+
+func showIndex(w http.ResponseWriter, r *http.Request, msg string) {
+ if strings.Contains(r.UserAgent(), "curl") ||
+ strings.Contains(r.UserAgent(), "Wget") {
+ w.Header().Set("Content-Type", "text/plain")
+ if msg != "" {
+ w.Write([]byte(msg))
+ w.Write([]byte{'\n'})
+ } else {
+ w.Write([]byte(curlUsage))
+ }
+ return
+ }
+
+ w.Write([]byte(uploadHtml))
+ if msg != "" {
+ msg = strings.ReplaceAll(html.EscapeString(msg), "\n", "<br>\n")
+ errHtml := fmt.Sprintf("<p>%s</p>\n", msg)
+ w.Write([]byte(errHtml))
+ }
+ dir := uriToFilesystem(r.URL.Path)
+ if dir == "" {
+ dir = "."
+ }
+ w.Write([]byte("<pre>\n"))
+ // TODO split presentation from traversal.
+ // TODO group directories first and sort alphabetically.
+ filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
+ if err != nil {
+ log.Printf("Error walking %v: %v\n", path, err)
+ return nil
+ }
+ relpathFs, err := filepath.Rel(dir, path)
+ if err != nil {
+ log.Printf("Error creating relative path %v: %v\n", path, err)
+ return nil
+ }
+ relpath := strings.ReplaceAll(relpathFs, string(filepath.Separator), "/")
+ // If this is 'dir', then display it as '..' if we are a subdir.
+ if relpath == "." {
+ if dir == "." {
+ return nil
+ }
+ relpath = ".."
+ }
+ if info.IsDir() {
+ relpath += "/"
+ }
+ disppath := relpath
+ if len(relpath) > maxFilenameLength {
+ tail := "..>"
+ if info.IsDir() {
+ tail += "/"
+ }
+ disppath = disppath[:maxFilenameLength-len(tail)] + tail
+ }
+ pad := strings.Repeat(" ", maxFilenameLength-len(relpath))
+ sz := "-"
+ if !info.IsDir() {
+ sz = strconv.FormatInt(info.Size(), 10)
+ }
+ // emulates nginx directory index layout
+ entryHtml := fmt.Sprintf(`<a href="%s">%s</a>%s %s %20s`+"\n",
+ html.EscapeString(relpath),
+ html.EscapeString(disppath),
+ pad,
+ info.ModTime().Format("_2-Jan-2006 15:04"),
+ sz)
+ w.Write([]byte(entryHtml))
+ if path != dir && info.IsDir() {
+ //log.Printf("Skipping dir: %v\n", path)
+ return filepath.SkipDir
+ }
+ return nil
+ })
+ w.Write([]byte("</pre>\n"))
+}
+
+// Reads from the files once and compare the results. Returns whether all bytes
+// are the same and the number of bytes from the second reader. If EOF is
+// reached, then "true, 0" will be returned.
+func readAndCompareOnce(a, b io.Reader, bufferA, bufferB []byte) (bool, int) {
+ nA, errA := a.Read(bufferA)
+ nB, errB := b.Read(bufferB)
+ if errA != nil && errA != io.EOF {
+ log.Println("Error reading first file:", errA)
+ return false, nB
+ }
+ if errB != nil && errB != io.EOF {
+ log.Println("Error reading second file:", errB)
+ return false, nB
+ }
+ if nA != nB || !bytes.Equal(bufferA[:nA], bufferB[:nB]) {
+ return false, nB
+ }
+ return true, nB
+}
+
+// Advances both file handles by size (if possible) and returns true if that
+// part is equal.
+func isFileHeadEqual(a, b io.Reader, size int64) bool {
+ ar := io.LimitReader(a, size)
+ br := io.LimitReader(b, size)
+ bufferA := make([]byte, blockSize)
+ bufferB := make([]byte, blockSize)
+ for {
+ ok, n := readAndCompareOnce(ar, br, bufferA, bufferB)
+ if !ok {
+ return false
+ }
+ if n == 0 {
+ return true
+ }
+ }
+}
+
+// Assuming that the initial 'offset' bytes from both files have been verified
+// to be the same, continue reading and check whether they are the same.
+// extraData is data that was consumed from 'newFile' and must be matched before
+// new data can be read.
+// Returns whether the file fully matches, the base offset and updated extra
+// data.
+func isDupe(baseFile, newFile io.Reader, offset int64, extraData []byte) (bool, int64, []byte) {
+ if len(extraData) > 0 {
+ if !isFileHeadEqual(baseFile, bytes.NewReader(extraData), int64(len(extraData))) {
+ return false, offset, extraData
+ }
+ offset += int64(len(extraData))
+ extraData = extraData[:0]
+ }
+ bufferBase := make([]byte, cap(extraData))
+ bufferNew := extraData[:cap(extraData)]
+ for {
+ ok, n := readAndCompareOnce(baseFile, newFile, bufferBase, bufferNew)
+ switch {
+ case !ok:
+ // not the same, remember read buffer.
+ return false, offset, extraData[:n]
+ case ok && n == 0:
+ // EOF - must be the same
+ return true, offset, extraData[:0]
+ case ok && n != 0:
+ // data was read and was found to be the same. Continue.
+ offset += int64(n)
+ }
+ }
+}
+
+// Save the file using the given name as hint. If a file already exists and has
+// the same contents, no new file will be created. Otherwise a new file will be
+// created with the provided contents. Returns the actual filename and a status
+// message.
+func saveFile(userFilename string, r io.Reader) (string, string) {
+ saveName := path.Base(userFilename)
+ // basic sanity check. Incomplete for Windows (where names like "NUL"
+ // and "COM1" are special) - do not care for now.
+ if saveName == "." || saveName == ".." {
+ saveName = "file"
+ }
+
+ // Handle large files in a more efficient way: keep track of "file with
+ // the same data", "number of bytes within said file that was the same",
+ // and "data that was read but different".
+ var baseFile *os.File
+ var baseSize int64
+ newData := make([]byte, 0, blockSize)
+ defer baseFile.Close()
+
+ // Read a bit once
+ n, err := r.Read(newData[:cap(newData)])
+ if err != nil && err != io.EOF {
+ return userFilename, fmt.Sprintf("Error reading upload: %v", err.Error())
+ }
+ if n == 0 {
+ return userFilename, "Uploaded file was empty - ignoring."
+ }
+ newData = newData[:n]
+
+ for i := 0; ; i++ {
+ filename := saveName
+ if i > 0 {
+ filename = fmt.Sprintf("%s.%d", filename, i)
+ }
+ f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644)
+ defer f.Close()
+ if os.IsExist(err) {
+ if baseFile2, err := os.Open(filename); err == nil {
+ defer baseFile2.Close()
+ // If there is an existing base file, check if
+ // it starts with the same data.
+ if baseFile != nil {
+ if _, err := baseFile.Seek(0, os.SEEK_SET); err != nil {
+ // Seek failed, cannot compare.
+ continue
+ }
+ if !isFileHeadEqual(baseFile, baseFile2, baseSize) {
+ // different - not a dupe
+ continue
+ }
+ }
+ var dupe bool
+ var newSize int64
+ dupe, newSize, newData = isDupe(baseFile2, r, baseSize, newData)
+ if baseSize != newSize {
+ baseFile = baseFile2
+ baseFile2 = nil
+ baseSize = newSize
+ }
+ if dupe {
+ return filename, "Duplicate file"
+ }
+ }
+ // Duplicate filename.
+ continue
+ }
+ if err != nil {
+ return filename, fmt.Sprintf("Unable to save file: %v", err)
+ }
+
+ // Ready to save the file.
+ if filename == userFilename {
+ log.Printf("Saving %v\n", filename)
+ } else {
+ log.Printf("Saving %v as %v\n", userFilename, filename)
+ }
+ var n int64
+ if baseFile != nil {
+ if _, err := baseFile.Seek(0, os.SEEK_SET); err != nil {
+ // Seek failed, OOPS, cannot reconstruct file!
+ return filename, fmt.Sprintf("File data disappeared (seek): %v", err)
+ }
+ n, err = io.Copy(f, io.LimitReader(baseFile, baseSize))
+ if err != nil {
+ return filename, fmt.Sprintf("File data disappeared (copy): %v", err)
+ }
+ if n != baseSize {
+ return filename, fmt.Sprintf("Wrote %d bytes instead of %d bytes", n, baseSize)
+ }
+ }
+ if len(newData) != 0 {
+ k, err := f.Write(newData)
+ if k != len(newData) {
+ return filename, fmt.Sprintf("Wrote %d+%d bytes instead of %d+%d: %v", n, k, n, len(newData), err)
+ }
+ n += int64(len(newData))
+ }
+ m, err := io.Copy(f, r)
+ if err != nil {
+ return filename, fmt.Sprintf("Wrote %d bytes, %s", n+m, err)
+ }
+ return filename, fmt.Sprintf("OK. Wrote %d bytes", n+m)
+ }
+}
+
+func handleUpload(w http.ResponseWriter, r *http.Request) {
+ mr, err := r.MultipartReader()
+ if err != nil {
+ log.Print(err)
+ showIndex(w, r, err.Error())
+ return
+ }
+ var messages []string
+ for {
+ p, err := mr.NextPart()
+ if err == io.EOF {
+ break
+ }
+ if err != nil {
+ log.Printf("Upload error: %s\n", err)
+ messages = append(messages, err.Error())
+ break
+ }
+ if p.FormName() != "file" {
+ continue
+ }
+ // Save file
+ saveName, msg := saveFile(p.FileName(), p)
+ displayName := saveName
+ if p.FileName() != saveName {
+ displayName = fmt.Sprintf("%s -> %s", p.FileName(), displayName)
+ }
+ msg = fmt.Sprintf("%s - %s", displayName, msg)
+ log.Println(msg)
+ messages = append(messages, msg)
+ }
+ if messages == nil {
+ showIndex(w, r, "No files selected")
+ } else {
+ showIndex(w, r, strings.Join(messages, "\n"))
+ }
+}
+
+func serveFile(w http.ResponseWriter, r *http.Request) {
+ if isForbiddenPath(r.URL.Path) {
+ http.Error(w, "forbidden URL", http.StatusBadRequest)
+ return
+ }
+ f, err := os.Open(uriToFilesystem(r.URL.Path))
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusNotFound)
+ log.Println(err)
+ return
+ }
+ defer f.Close()
+ http.ServeContent(w, r, r.URL.Path, time.Time{}, f)
+}
+
+var (
+ certFile string
+ keyFile string
+ address string
+)
+
+func init() {
+ flag.StringVar(&certFile, "cert", "", "Certificate and optional key file")
+ flag.StringVar(&keyFile, "key", "", "Key file")
+ flag.StringVar(&address, "accept", ":1111", "Listen address")
+}
+
+func main() {
+ flag.Parse()
+ if flag.NArg() != 0 {
+ fmt.Fprintln(os.Stderr, "Too many arguments")
+ os.Exit(2)
+ }
+ if certFile != "" && keyFile == "" {
+ keyFile = certFile
+ }
+
+ http.HandleFunc("/favicon.ico", http.NotFound)
+ http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
+ log.Printf("%s %s %s", r.RemoteAddr, r.Method, r.URL.Path)
+ if isForbiddenPath(r.URL.Path) {
+ http.Error(w, "forbidden URL", http.StatusBadRequest)
+ return
+ }
+
+ isDir := r.URL.Path[len(r.URL.Path)-1] == '/'
+ if r.Method == "POST" {
+ // TODO relax this to other locations?
+ if r.URL.Path != "/" {
+ http.Error(w, "Uploads are only accepted at /", http.StatusMethodNotAllowed)
+ return
+ }
+ handleUpload(w, r)
+ } else if r.Method == "HEAD" || r.Method == "GET" {
+ if isDir {
+ showIndex(w, r, "")
+ return
+ }
+ serveFile(w, r)
+ } else {
+ http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+ }
+ })
+ if certFile != "" {
+ cert, err := tls.LoadX509KeyPair(certFile, keyFile)
+ if err != nil {
+ log.Fatal(err)
+ }
+ server := &http.Server{
+ Addr: address,
+ TLSConfig: &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ },
+ }
+ leafCert := server.TLSConfig.Certificates[0].Certificate[0]
+ sha256sum := sha256.Sum256(leafCert)
+ log.Println("SHA-256 Fingerprint:", formatFingerprint(sha256sum[:]))
+ sha1sum := sha1.Sum(leafCert)
+ log.Println("SHA1 Fingerprint: ", formatFingerprint(sha1sum[:]))
+
+ log.Println("Listening at", server.Addr, "using TLS")
+ log.Fatal(server.ListenAndServeTLS("", ""))
+ } else {
+ log.Println("Listening at", address)
+ log.Fatal(http.ListenAndServe(address, nil))
+ }
+}
+
+func formatFingerprint(certData []byte) string {
+ const hextable = "0123456789ABCDEF"
+ out := make([]byte, len(certData)*3-1)
+ for i, v := range certData {
+ if i > 0 {
+ out[i*3-1] = ':'
+ }
+ out[i*3] = hextable[v>>4]
+ out[i*3+1] = hextable[v&0xf]
+ }
+ return string(out)
+}
+
+const uploadHtml = `
+<!doctype html>
+<meta charset="UTF-8">
+<meta name="viewport" content="initial-scale=1">
+<form action="" method="POST" enctype="multipart/form-data">
+<input type="file" name="file" multiple>
+<input type="submit" value="Upload">
+</form>
+<hr>
+`
+const curlUsage = "Usage: curl -Ffile=@input.txt ...\n"