diff --git a/main.go b/main.go index 5497ae6..0253270 100644 --- a/main.go +++ b/main.go @@ -1,8 +1,11 @@ package main import ( + "encoding/base64" + "errors" "fmt" "io" + "io/ioutil" "log" "os" "os/exec" @@ -15,15 +18,37 @@ import ( gossh "golang.org/x/crypto/ssh" ) +const SK_ECDSA_PUBKEY = "sk-ecdsa-sha2-nistp256@openssh.com AAAAInNrLWVjZHNhLXNoYTItbmlzdHAyNTZAb3BlbnNzaC5jb20AAAAIbmlzdHAyNTYAAABBBIEy/KQzi+q7uqufEtqHXusQbpT9GVM2j1jNhU83VI8T8VOy4nWX9STNU+qpcwp6l1wqhYZSmMRkXF+3CwCAssAAAAAEc3NoOg== fido-u2f" +const WELCOME = ", . . . \n" + + "| . | | | \n" + + "| ) ) ,-. | ,-. ,-. ;-.-. ,-. |- ,-. \n" + + "|/|/ |-' | | | | | | | |-' | | | \n" + + "' ' `-' ' `-' `-' ' ' ' `-' `-' `-' \n" + + " \n" + + " . . . . ,-. \n" + + " `. | | | | ' ( ` \n" + + "--- > |--| ,-: |- |- ,-. ;-. ,-. `-. ,-. ;-. . , ,-. ;-. \n" + + " ,' | | | | | | |-' | `-. . ) |-' | |/ |-' | \n" + + " ' ' `-` `-' `-' `-' ' `-' `-' `-' ' ' `-' ' \n" + + " " + func main() { + sshPublicKey, sshPublicKeyErr := parseSshPubkey(SK_ECDSA_PUBKEY) + if sshPublicKeyErr != nil { + log.Fatal("Parse sk ecdsa public key failed: ", sshPublicKeyErr) + return + } + marshalSshPublicKey := marshalKey(sshPublicKey) + log.Println("Found sk ecdsa public key: ", marshalSshPublicKey) + hostKeyBytes, hostKeyBytesErr := readHostKey() + if hostKeyBytesErr != nil { + log.Fatal("Load host key failed: ", hostKeyBytesErr) + return + } + ssh.Handle(func(s ssh.Session) { - authorizedKey := gossh.MarshalAuthorizedKey(s.PublicKey()) - io.WriteString(s, fmt.Sprintf("\r\n%s\r\n", strings.Repeat("-", 88))) - io.WriteString(s, fmt.Sprintf("public key used by %s:\n", s.User())) - io.WriteString(s, fmt.Sprintf("%v", s.PublicKey())) - io.WriteString(s, fmt.Sprintf("\r\n%s\r\n", strings.Repeat("-", 88))) - s.Write(authorizedKey) - io.WriteString(s, fmt.Sprintf("\r\n%s\r\n", strings.Repeat("-", 88))) + // authorizedKey := gossh.MarshalAuthorizedKey(s.PublicKey()) + io.WriteString(s, WELCOME) cmd := exec.Command("/bin/bash") ptyReq, winCh, isPty := s.Pty() @@ -49,19 +74,52 @@ func main() { } }) + hostKeyOption := ssh.HostKeyPEM(hostKeyBytes) + publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool { - log.Println("type: ", key.Type()) - if key.Type() == "sk-ecdsa-sha2-nistp256@openssh.com" { - return true - } - return false - //return true // allow all keys, or use ssh.KeysEqual() to compare against known keys + marshalPubKey := marshalKey(key) + log.Println("Auth public key: ", marshalPubKey, ", from: ", ctx.RemoteAddr()) + isAllowed := marshalPubKey == marshalSshPublicKey + log.Println("Key allowed: ", isAllowed) + return isAllowed }) - log.Println("Listening :222...") - log.Fatal(ssh.ListenAndServe(":2222", nil, publicKeyOption)) + log.Println("Listening :2222...") + log.Fatal(ssh.ListenAndServe(":2222", nil, hostKeyOption, publicKeyOption)) } func setWinsize(f *os.File, w, h int) { syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0}))) } + +func parseSshPubkey(pubkey string) (gossh.PublicKey, error) { + if strings.Contains(pubkey, " ") { + pubkey = strings.Split(pubkey, " ")[1] + } + pubkeyBytes, pubkeyBytesErr := base64.StdEncoding.DecodeString(pubkey) + if pubkeyBytesErr != nil { + return nil, pubkeyBytesErr + } + publicKey, publicKeyErr := gossh.ParsePublicKey(pubkeyBytes) + if publicKeyErr != nil { + return nil, publicKeyErr + } + return publicKey, nil +} + +func readHostKey() ([]byte, error) { + hostKeyFile := "/etc/ssh/ssh_host_rsa_key" + hostKeyFileBytes, hostKeyFileBytesErr := ioutil.ReadFile(hostKeyFile) + if hostKeyFileBytesErr == nil { + return hostKeyFileBytes, nil + } + tempHostKeyFileBytes, tempHostKeyFileBytesErr := ioutil.ReadFile("/Users/hatterjiang/.ssh/id_rsa") + if tempHostKeyFileBytesErr == nil { + return tempHostKeyFileBytes, nil + } + return nil, errors.New("Canot read any host key from file") +} + +func marshalKey(pubkey ssh.PublicKey) string { + return pubkey.Type() + " " + base64.StdEncoding.EncodeToString(pubkey.Marshal()) +}