package main import ( "encoding/base64" "errors" "fmt" "io" "io/ioutil" "log" "os" "os/exec" "strings" "syscall" "unsafe" "github.com/creack/pty" "github.com/gliderlabs/ssh" gossh "golang.org/x/crypto/ssh" ) const WELCOME = ", . . . \n" + "| . | | | \n" + "| ) ) ,-. | ,-. ,-. ;-.-. ,-. |- ,-. \n" + "|/|/ |-' | | | | | | | |-' | | | \n" + "' ' `-' ' `-' `-' ' ' ' `-' `-' `-' \n" + " \n" + " . . . . ,-. \n" + " `. | | | | ' ( ` \n" + "--- > |--| ,-: |- |- ,-. ;-. ,-. `-. ,-. ;-. . , ,-. ;-. \n" + " ,' | | | | | | |-' | `-. . ) |-' | |/ |-' | \n" + " ' ' `-` `-' `-' `-' ' `-' `-' `-' ' ' `-' ' \n" + " \n" func main() { args := os.Args log.Println("Arguments: ", args) port := ":2222" if len(args) > 1 { port = ":" + args[1] } log.Println("Use port: ", port) allowedSshPublicKeys, allowedSshPublicKeysErr := parseAllowedSshPubkeysAsStrings() if allowedSshPublicKeysErr != nil { log.Fatal("Parse sk ecdsa public key(s) failed: ", allowedSshPublicKeysErr) return } log.Println("Found sk ecdsa public keys: ", len(allowedSshPublicKeys)) for i, k := range allowedSshPublicKeys { log.Println(i, ">>", k) } hostKeyBytes, hostKeyBytesErr := readHostKey() if hostKeyBytesErr != nil { log.Fatal("Load host key failed: ", hostKeyBytesErr) return } ssh.Handle(func(s ssh.Session) { io.WriteString(s, WELCOME) cmd := exec.Command("/bin/bash") ptyReq, winCh, isPty := s.Pty() if isPty { cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=/root")) f, err := pty.Start(cmd) if err != nil { panic(err) } go func() { for win := range winCh { setWinsize(f, win.Width, win.Height) } }() go func() { io.Copy(f, s) // stdin }() io.Copy(s, f) // stdout cmd.Wait() s.Exit(0) } else { io.WriteString(s, "No PTY requested.\n") s.Exit(1) } }) hostKeyOption := ssh.HostKeyPEM(hostKeyBytes) publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { marshalPubKey := marshalKey(key) log.Println("Auth public key: ", marshalPubKey, ", from: ", ctx.RemoteAddr()) for idxAllowedSshPublicKey, allowedSshPublicKey := range allowedSshPublicKeys { if marshalPubKey == allowedSshPublicKey { log.Println("Key allowed: ", idxAllowedSshPublicKey) return true } } log.Println("Key NOT allowed") return false }) log.Println("Listening ", port, "...") log.Fatal(ssh.ListenAndServe(port, nil, hostKeyOption, publicKeyOption)) } func setWinsize(f *os.File, w, h int) { syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0}))) } func parseAllowedSshPubkeysAsStrings() ([]string, error) { pubkeys, err := parseAllowedSshPubkeys() if err != nil { return nil, err } var pubkeysAsStrings []string for _, pubkey := range pubkeys { pubkeysAsStrings = append(pubkeysAsStrings, marshalKey(pubkey)) } return pubkeysAsStrings, nil } func parseAllowedSshPubkeys() ([]gossh.PublicKey, error) { pubkeyBytes, pubkeyErr := ioutil.ReadFile("allowed_keys") if pubkeyErr != nil { return nil, pubkeyErr } var pubkeys []gossh.PublicKey pubkeyLines := string(pubkeyBytes) pubkeySplitedLines := strings.Split(pubkeyLines, "\n") for _, pubkeyLine := range pubkeySplitedLines { pubkey := strings.TrimSpace(pubkeyLine) // Comments starts with `#` if len(pubkey) > 0 && !strings.HasPrefix(pubkey, "#") { pubkey = strings.Split(pubkey, " ")[1] pubkeyBytes, pubkeyBytesErr := base64.StdEncoding.DecodeString(pubkey) if pubkeyBytesErr != nil { return nil, pubkeyBytesErr } publicKey, publicKeyErr := gossh.ParsePublicKey(pubkeyBytes) if publicKeyErr != nil { return nil, publicKeyErr } pubkeys = append(pubkeys, publicKey) } } return pubkeys, nil } func readHostKey() ([]byte, error) { hostKeyEd25519File := "/etc/ssh/ssh_host_ed25519_key" hostKeyEd25519FileBytes, hostKeyEd25519FileBytesErr := ioutil.ReadFile(hostKeyEd25519File) if hostKeyEd25519FileBytesErr == nil { log.Println("Found host key: ", hostKeyEd25519File) return hostKeyEd25519FileBytes, nil } hostKeyEcdsaFile := "/etc/ssh/ssh_host_ecdsa_key" hostKeyEcdsaFileBytes, hostKeyEcdsaFileBytesErr := ioutil.ReadFile(hostKeyEcdsaFile) if hostKeyEcdsaFileBytesErr == nil { log.Println("Found host key: ", hostKeyEcdsaFile) return hostKeyEcdsaFileBytes, nil } hostKeyFile := "/etc/ssh/ssh_host_rsa_key" hostKeyFileBytes, hostKeyFileBytesErr := ioutil.ReadFile(hostKeyFile) if hostKeyFileBytesErr == nil { log.Println("Found host key: ", hostKeyFile) return hostKeyFileBytes, nil } tempHostKeyFileBytes, tempHostKeyFileBytesErr := ioutil.ReadFile("/Users/hatterjiang/.ssh/id_rsa") if tempHostKeyFileBytesErr == nil { log.Println("!!WARN!! Found host key: ", "~/.ssh/id_rsa") return tempHostKeyFileBytes, nil } return nil, errors.New("Canot read any host key from file") } func marshalKey(pubkey ssh.PublicKey) string { return pubkey.Type() + " " + base64.StdEncoding.EncodeToString(pubkey.Marshal()) }