feat: support multi keys
This commit is contained in:
176
simple-ssh-server.go
Normal file
176
simple-ssh-server.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const WELCOME = ", . . . \n" +
|
||||
"| . | | | \n" +
|
||||
"| ) ) ,-. | ,-. ,-. ;-.-. ,-. |- ,-. \n" +
|
||||
"|/|/ |-' | | | | | | | |-' | | | \n" +
|
||||
"' ' `-' ' `-' `-' ' ' ' `-' `-' `-' \n" +
|
||||
" \n" +
|
||||
" . . . . ,-. \n" +
|
||||
" `. | | | | ' ( ` \n" +
|
||||
"--- > |--| ,-: |- |- ,-. ;-. ,-. `-. ,-. ;-. . , ,-. ;-. \n" +
|
||||
" ,' | | | | | | |-' | `-. . ) |-' | |/ |-' | \n" +
|
||||
" ' ' `-` `-' `-' `-' ' `-' `-' `-' ' ' `-' ' \n" +
|
||||
" \n"
|
||||
|
||||
func main() {
|
||||
args := os.Args
|
||||
log.Println("Arguments: ", args)
|
||||
port := ":2222"
|
||||
if len(args) > 1 {
|
||||
port = ":" + args[1]
|
||||
}
|
||||
log.Println("Use port: ", port)
|
||||
|
||||
allowedSshPublicKeys, allowedSshPublicKeysErr := parseAllowedSshPubkeysAsStrings()
|
||||
if allowedSshPublicKeysErr != nil {
|
||||
log.Fatal("Parse sk ecdsa public key(s) failed: ", allowedSshPublicKeysErr)
|
||||
return
|
||||
}
|
||||
log.Println("Found sk ecdsa public keys: ", len(allowedSshPublicKeys))
|
||||
for i, k := range allowedSshPublicKeys {
|
||||
log.Println(i, ">>", k)
|
||||
}
|
||||
|
||||
hostKeyBytes, hostKeyBytesErr := readHostKey()
|
||||
if hostKeyBytesErr != nil {
|
||||
log.Fatal("Load host key failed: ", hostKeyBytesErr)
|
||||
return
|
||||
}
|
||||
|
||||
ssh.Handle(func(s ssh.Session) {
|
||||
io.WriteString(s, WELCOME)
|
||||
|
||||
cmd := exec.Command("/bin/bash")
|
||||
ptyReq, winCh, isPty := s.Pty()
|
||||
if isPty {
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term))
|
||||
cmd.Env = append(cmd.Env, fmt.Sprintf("HOME=/root"))
|
||||
f, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
setWinsize(f, win.Width, win.Height)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
io.Copy(f, s) // stdin
|
||||
}()
|
||||
io.Copy(s, f) // stdout
|
||||
cmd.Wait()
|
||||
s.Exit(0)
|
||||
} else {
|
||||
io.WriteString(s, "No PTY requested.\n")
|
||||
s.Exit(1)
|
||||
}
|
||||
})
|
||||
|
||||
hostKeyOption := ssh.HostKeyPEM(hostKeyBytes)
|
||||
|
||||
publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
marshalPubKey := marshalKey(key)
|
||||
log.Println("Auth public key: ", marshalPubKey, ", from: ", ctx.RemoteAddr())
|
||||
|
||||
for idxAllowedSshPublicKey, allowedSshPublicKey := range allowedSshPublicKeys {
|
||||
if marshalPubKey == allowedSshPublicKey {
|
||||
log.Println("Key allowed: ", idxAllowedSshPublicKey)
|
||||
return true
|
||||
}
|
||||
}
|
||||
log.Println("Key NOT allowed")
|
||||
return false
|
||||
})
|
||||
log.Println("Listening ", port, "...")
|
||||
log.Fatal(ssh.ListenAndServe(port, nil, hostKeyOption, publicKeyOption))
|
||||
}
|
||||
|
||||
func setWinsize(f *os.File, w, h int) {
|
||||
syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ),
|
||||
uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(h), uint16(w), 0, 0})))
|
||||
}
|
||||
|
||||
func parseAllowedSshPubkeysAsStrings() ([]string, error) {
|
||||
pubkeys, err := parseAllowedSshPubkeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var pubkeysAsStrings []string
|
||||
for _, pubkey := range pubkeys {
|
||||
pubkeysAsStrings = append(pubkeysAsStrings, marshalKey(pubkey))
|
||||
}
|
||||
return pubkeysAsStrings, nil
|
||||
}
|
||||
|
||||
func parseAllowedSshPubkeys() ([]gossh.PublicKey, error) {
|
||||
pubkeyBytes, pubkeyErr := ioutil.ReadFile("allowed_keys")
|
||||
if pubkeyErr != nil {
|
||||
return nil, pubkeyErr
|
||||
}
|
||||
var pubkeys []gossh.PublicKey
|
||||
pubkeyLines := string(pubkeyBytes)
|
||||
pubkeySplitedLines := strings.Split(pubkeyLines, "\n")
|
||||
|
||||
for _, pubkeyLine := range pubkeySplitedLines {
|
||||
pubkey := strings.TrimSpace(pubkeyLine)
|
||||
if len(pubkey) > 0 && !strings.HasPrefix(pubkey, "#") {
|
||||
pubkey = strings.Split(pubkey, " ")[1]
|
||||
pubkeyBytes, pubkeyBytesErr := base64.StdEncoding.DecodeString(pubkey)
|
||||
if pubkeyBytesErr != nil {
|
||||
return nil, pubkeyBytesErr
|
||||
}
|
||||
publicKey, publicKeyErr := gossh.ParsePublicKey(pubkeyBytes)
|
||||
if publicKeyErr != nil {
|
||||
return nil, publicKeyErr
|
||||
}
|
||||
pubkeys = append(pubkeys, publicKey)
|
||||
}
|
||||
}
|
||||
return pubkeys, nil
|
||||
}
|
||||
|
||||
func readHostKey() ([]byte, error) {
|
||||
hostKeyEcdsaFile := "/etc/ssh/ssh_host_ecdsa_key"
|
||||
hostKeyEcdsaFileBytes, hostKeyEcdsaFileBytesErr := ioutil.ReadFile(hostKeyEcdsaFile)
|
||||
if hostKeyEcdsaFileBytesErr == nil {
|
||||
log.Println("Found host key: ", hostKeyEcdsaFile)
|
||||
return hostKeyEcdsaFileBytes, nil
|
||||
}
|
||||
|
||||
hostKeyFile := "/etc/ssh/ssh_host_rsa_key"
|
||||
hostKeyFileBytes, hostKeyFileBytesErr := ioutil.ReadFile(hostKeyFile)
|
||||
if hostKeyFileBytesErr == nil {
|
||||
log.Println("Found host key: ", hostKeyFile)
|
||||
return hostKeyFileBytes, nil
|
||||
}
|
||||
|
||||
tempHostKeyFileBytes, tempHostKeyFileBytesErr := ioutil.ReadFile("/Users/hatterjiang/.ssh/id_rsa")
|
||||
if tempHostKeyFileBytesErr == nil {
|
||||
log.Println("!!WARN!! Found host key: ", "~/.ssh/id_rsa")
|
||||
return tempHostKeyFileBytes, nil
|
||||
}
|
||||
return nil, errors.New("Canot read any host key from file")
|
||||
}
|
||||
|
||||
func marshalKey(pubkey ssh.PublicKey) string {
|
||||
return pubkey.Type() + " " + base64.StdEncoding.EncodeToString(pubkey.Marshal())
|
||||
}
|
||||
Reference in New Issue
Block a user