package main

import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/binary"
	"encoding/hex"
	"flag"
	"fmt"
)

// verbose enables a hex dump of the forged message (set with -verbose).
var verbose bool

// main demonstrates a SHA-256 length extension attack against the naive MAC
// SHA256(secretKey || data). It signs some legitimate data, then forges a
// signature that verifies for (legitimateData || padding || maliciousData)
// using only the legitimate signature and the length of the secret key.
func main() {
	flag.BoolVar(&verbose, "verbose", false, "verbose")
	flag.Parse()

	// 256-bit (32-byte) secret key. In real life it should be generated with
	// crypto/rand; a fixed value keeps the demo reproducible.
	secretKey := []byte("secretsecretsecretsecretsecretse")

	legitimateData := []byte("user_id=1&role=user")
	legitimateSignature := sign(secretKey, legitimateData)

	fmt.Printf("SecretKey: %s\n", hex.EncodeToString(secretKey))
	fmt.Printf("Legitimate Data: %s\n", string(legitimateData))
	fmt.Printf("Legitimate Signature SHA256(SecretKey || LegitimateData): %s\n", hex.EncodeToString(legitimateSignature))
	fmt.Printf("Verify LegitimateSignature == SHA256(SecretKey || LegitimateData): %v\n", verifySignature(secretKey, legitimateSignature, legitimateData))

	// Print with explicit newlines: Println with a constant argument ending in
	// "\n" is rejected by go vet's printf check, which also runs under go test.
	fmt.Print("\n---------------------------------------------------------------------------------------------------\n\n")

	maliciousData := []byte("&something=true&role=admin")
	// The attacker needs only the length of the secret key, never its value.
	maliciousMessage := generateMaliciousMessage(uint64(len(secretKey)), legitimateData, maliciousData)
	maliciousSignature := forgeSignature(legitimateSignature, maliciousData, uint64(len(secretKey)+len(legitimateData)))

	fmt.Printf("Malicious Data: %s\n", string(maliciousData))
	if verbose {
		fmt.Println("Malicious Message (LegitimateData || padding || MaliciousData):")
		fmt.Println(hex.Dump(maliciousMessage))
	}
	fmt.Printf("Malicious Signature: %s\n", hex.EncodeToString(maliciousSignature))
	fmt.Printf("Verify MaliciousSignature == SHA256(SecretKey, MaliciousMessage): %v\n", verifySignature(secretKey, maliciousSignature, maliciousMessage))
}

// forgeSignature performs a length extension attack. It resumes the SHA-256
// computation whose final state is legitimateSignature (= SHA256(secretKey ||
// data)) and appends maliciousData. The result equals
// SHA256(secretKey || data || generatePadding(secretKeyAndDataLength) || maliciousData).
//
// secretKeyAndDataLength is len(secretKey) + len(data), in bytes.
func forgeSignature(legitimateSignature []byte, maliciousData []byte, secretKeyAndDataLength uint64) (forgedSignature []byte) {
	d := loadSha256(legitimateSignature, secretKeyAndDataLength)
	d.Write(maliciousData)
	return d.Sum(nil)
}

// generateMaliciousMessage builds the message whose signature forgeSignature
// can forge without knowing the secret key. Its format is:
//
//	legitimateData || padding || maliciousData
//
// padding is the SHA-256 padding of (secretKey || legitimateData), so the
// verifier hashes exactly the blocks the legitimate signature already covers
// before reaching maliciousData.
func generateMaliciousMessage(secretKeyLength uint64, legitimateData []byte, maliciousData []byte) (message []byte) {
	padding := generatePadding(secretKeyLength + uint64(len(legitimateData)))
	message = make([]byte, 0, len(legitimateData)+len(padding)+len(maliciousData))
	message = append(message, legitimateData...)
	message = append(message, padding...)
	message = append(message, maliciousData...)
	return message
}

// generatePadding returns the SHA-256 padding that follows a message of
// secretKeyAndDataLength bytes (here secretKey || data). The padding is a 0x80
// byte, then zero bytes until the length is 56 mod 64, then the message length
// in bits as a 64-bit big-endian integer. The padded message therefore fills
// whole 64-byte (512-bit) blocks, and the returned slice is 9 to 72 bytes long.
//
// The format is defined in RFC 6234, section 4.1:
// https://www.rfc-editor.org/rfc/rfc6234#page-8
// The implementation is adapted from the padding logic in Go's sha256.go.
func generatePadding(secretKeyAndDataLength uint64) []byte {
	var tmp [64 + 8]byte // 0x80 and zero fill (at most 64 bytes) + 8-byte length
	tmp[0] = 0x80

	// t = number of bytes (0x80 plus zeros) needed to reach 56 mod 64.
	var t uint64
	if rem := secretKeyAndDataLength % 64; rem < 56 {
		t = 56 - rem
	} else {
		t = 64 + 56 - rem
	}

	padding := tmp[:t+8]
	// Message length in bits.
	binary.BigEndian.PutUint64(padding[t:], secretKeyAndDataLength<<3)
	return padding
}

// verifySignature reports whether signatureToVerify == SHA256(secretKey || data).
// The comparison is constant-time so that timing does not reveal how many
// leading bytes of the signature are correct.
func verifySignature(secretKey []byte, signatureToVerify []byte, data []byte) (isValid bool) {
	expected := sign(secretKey, data)
	return subtle.ConstantTimeCompare(expected, signatureToVerify) == 1
}

// sign computes the naive MAC SHA256(secretKey || data). This construction is
// vulnerable to length extension, which is what this program demonstrates.
// Real code should use crypto/hmac instead.
func sign(secretKey []byte, data []byte) (signature []byte) {
	h := sha256.New()
	// hash.Hash.Write never returns an error.
	h.Write(secretKey)
	h.Write(data)
	return h.Sum(nil)
}

// loadSha256 returns a digest whose chaining state is taken from a plain
// 32-byte SHA-256 hash, so that further Writes continue the computation that
// produced hashBytes. It is a slightly modified version of
// digest.UnmarshalBinary: it reads a normal SHA-256 hash instead of the
// "proprietary" encoding generated by digest.MarshalBinary.
//
// secretKeyAndDataLength is the length in bytes of the message that was
// hashed. It panics if hashBytes is not sha256.Size bytes long.
func loadSha256(hashBytes []byte, secretKeyAndDataLength uint64) (hash *digest) {
	if len(hashBytes) != sha256.Size {
		panic("loadSha256: not a valid SHA256 hash")
	}
	hash = new(digest)
	hash.Reset()
	// The hash is the eight 32-bit state words, in order, read via consumeUint32.
	for i := range hash.h {
		hashBytes, hash.h[i] = consumeUint32(hashBytes)
	}
	// hash.len is the number of bytes already consumed, including the padding,
	// i.e. a whole number of 64-byte blocks.
	// NOTE(review): relies on Reset leaving nx (buffered bytes) at 0, as in
	// crypto/sha256, so no partial block is pending.
	hash.len = secretKeyAndDataLength + uint64(len(generatePadding(secretKeyAndDataLength)))
	return hash
}