diff --git a/.gitignore b/.gitignore index 3af2f7dd..c087c881 100644 --- a/.gitignore +++ b/.gitignore @@ -1,75 +1,19 @@ -# project specific +ACVP-Server-1.1.0.39/ +SLH-DSA-.*-1.rsp +.DS_STORE +.vscode/launch.json -xtest -__pycache__ - -_build -_prof -*.vvp -*.log -*.drc -*.bgn -*.bit -*.mmi -*.jou -xvlog.* -build -firmware.* -config.h -obj_dir/* -prof -pqse -syn_out - -# Prerequisites -*.d - -# Object files +# executables *.o -*.ko -*.obj -*.elf - -# Linker output -*.ilk -*.map -*.exp - -# Precompiled Headers -*.gch -*.pch - -# Libraries -*.lib -*.a -*.la -*.lo - -# Shared objects (inc. Windows DLLs) -*.dll -*.so -*.so.* -*.dylib - -# Executables -*.exe -*.out -*.app -*.i*86 -*.x86_64 -*.hex - -# Debug files -*.dSYM/ -*.su -*.idb -*.pdb - -# Kernel Module Compile Results -*.mod* -*.cmd -.tmp_versions/ -modules.order -Module.symvers -Mkfile.old -dkms.conf +./slh_sha2 +./slh_dsa +./ACVP_test_functions +./ACVP_sig_test +./kat_test +./sha3_f1600 +./ACVP_keygen_test +./sha2_256 +./sha2_512 +./slh_shake +./ACVP_ver_test +./sha3_api \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..9249addb --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,48 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + + + "version": "0.2.0", + "configurations": [ + { + "name": "Debug kat_test", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/slh/kat_test", + "args": [], + "stopAtEntry": false, + "cwd": "${workspaceFolder}", + "environment": [], + "externalConsole": false, + "MIMode": "lldb" // or "gdb" for Linux/WSL + }, + { + "name": "Debug ACVP_keygen_test", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/slh/ACVP_keygen_test", + "args": [], + "stopAtEntry": false, + "cwd": "${workspaceFolder}", + "environment": [], + "externalConsole": false, + "MIMode": "lldb" + }, + { + "name": "Debug ACVP_sig_test", + "type": "cppdbg", + "request": "launch", + "program": "${workspaceFolder}/slh/ACVP_sig_test", + "args": [], + "stopAtEntry": false, + "cwd": "${workspaceFolder}", + "environment": [], + "externalConsole": false, + "MIMode": "lldb" + } + ] +} + + diff --git a/slh/.gitignore b/slh/.gitignore new file mode 100644 index 00000000..fa4018f2 --- /dev/null +++ b/slh/.gitignore @@ -0,0 +1,12 @@ +./slh_sha2 +./slh_dsa +./ACVP_test_functions +./ACVP_sig_test +./kat_test +./sha3_f1600 +./ACVP_keygen_test +./sha2_256 +./sha2_512 +./slh_shake +./ACVP_ver_test +./sha3_api diff --git a/slh/ACVP_keygen_test b/slh/ACVP_keygen_test new file mode 100755 index 00000000..387e0994 Binary files /dev/null and b/slh/ACVP_keygen_test differ diff --git a/slh/ACVP_keygen_test.c b/slh/ACVP_keygen_test.c new file mode 100644 index 00000000..ec29b601 --- /dev/null +++ b/slh/ACVP_keygen_test.c @@ -0,0 +1,120 @@ +#include +#include +#include + +#include "ACVP_tests.h" + +// hardcoded example usage ******************************************************************* + +/* +from JSON -> + +"parameterSet": "SLH-DSA-SHA2-128s", +"skSeed": "AC379F047FAAB2004F3AE32350AC9A3D", +"skPrf": "829FFF0AA59E956A87F3971C4D58E710", +"pkSeed": "0566D240CC519834322EAFBCC73C79F5", +"sk": "AC379F047FAAB2004F3AE32350AC9A3D829FFF0AA59E956A87F3971C4D58E7100566D240CC519834322EAFBCC73C79F5A4B84F02E8BF0CBD54017B2D3C494B57", +"pk": "0566D240CC519834322EAFBCC73C79F5A4B84F02E8BF0CBD54017B2D3C494B57" +*/ + +/* +static int hardcoded_keygen_rbg(uint8_t *x, size_t xlen) +{ + size_t n = slh_dsa_sha2_128s.n; + + char skSeedString[] = "AC379F047FAAB2004F3AE32350AC9A3D"; + char skPrfString[] = "829FFF0AA59E956A87F3971C4D58E710"; + char pkSeedString[] = "0566D240CC519834322EAFBCC73C79F5"; + + hexStringToByteArray(skSeedString,x); + hexStringToByteArray(skPrfString,x+n); + hexStringToByteArray(pkSeedString,x+2*n); + + return 0; +} + +int main() +{ + uint8_t pk[MAX_PK_BYTES] = {0}; + uint8_t sk[MAX_SK_BYTES] = {0}; + uint8_t pk_expected[MAX_PK_BYTES] = {0}; + uint8_t sk_expected[MAX_SK_BYTES] = {0}; + + char skString[] = "AC379F047FAAB2004F3AE32350AC9A3D829FFF0AA59E956A87F3971C4D58E7100566D240CC519834322EAFBCC73C79F5A4B84F02E8BF0CBD54017B2D3C494B57"; + char pkString[] = "0566D240CC519834322EAFBCC73C79F5A4B84F02E8BF0CBD54017B2D3C494B57"; + hexStringToByteArray(skString, sk_expected); + hexStringToByteArray(pkString, pk_expected); + + slh_keygen(pk,sk,&hardcoded_keygen_rbg,&slh_dsa_sha2_128s); + + if(memcmp(sk,sk_expected,sizeof(sk))) + { + printf("SK does not match expected value! \r\n"); + return -1; + } + + else if(memcmp(pk,pk_expected,sizeof(pk))) + { + printf("PK does not match expected value! \r\n"); + return -1; + } + + printf("All tests passed! \r\n"); + return 0; +} +*/ + +/* +arg 1 := tgId +arg 2 := tcId +arg 3 := prmSet +arg 4 := skSeed +arg 5 := skPrf +arg 6 := pkSeed +arg 7 := sk +arg 8 := pk +*/ + +int main(int argc, char *argv[]) +{ + /* one of the args is arg 0 */ + if(argc != 9) + { + printf("%d \r\n", argc); + printf("Usage: ./ACVP_keygen_test \r\n"); + return -1; + } + + uint8_t sk[MAX_SK_BYTES] = {0}; + uint8_t pk[MAX_PK_BYTES] = {0}; + uint8_t skExpected[MAX_SK_BYTES] = {0}; + uint8_t pkExpected[MAX_PK_BYTES] = {0}; + + /* process arg vars */ + uint8_t tgId = (uint8_t)atoi(argv[1]); + uint8_t tcId = (uint8_t)atoi(argv[2]); + selectPrmSet(argv[3],&prmSet_g); + seed_g.skSeedString = argv[4]; + seed_g.skPrfString = argv[5]; + seed_g.pkSeedString = argv[6]; + hexStringToByteArray(argv[7], skExpected); + hexStringToByteArray(argv[8], pkExpected); + + slh_keygen(pk,sk,fixedKeygenRbg,prmSet_g); + + if(memcmp(sk,skExpected,sizeof(sk))) + { + printf("SK does not match expected value in keygen for (tgId,tcId): (%d,%d)! \r\n", tgId, tcId); + return -1; + } + + else if(memcmp(pk,pkExpected,sizeof(pk))) + { + printf("PK does not match expected value in keygen for (tgId,tcId): (%d,%d)! \r\n", tgId, tcId); + return -1; + } + + printf("Correct output for (tgId,tcId): (%d,%d)! \r\n", tgId, tcId); + + return 0; +} \ No newline at end of file diff --git a/slh/ACVP_sig_test b/slh/ACVP_sig_test new file mode 100755 index 00000000..1198d77a Binary files /dev/null and b/slh/ACVP_sig_test differ diff --git a/slh/ACVP_sig_test.c b/slh/ACVP_sig_test.c new file mode 100644 index 00000000..1802942e --- /dev/null +++ b/slh/ACVP_sig_test.c @@ -0,0 +1,90 @@ + +#include +#include +#include + +#include "ACVP_tests.h" +#include "slh_dsa.h" + +/* +arg 1 := tgId +arg 2 := tcId +arg 3 := prmSet +arg 4 := deterministic +arg 5 := interface +arg 6 := m_sz +arg 7 := m +arg 8 := ctxlen +arg 9 := ctxStr +arg 10 := sk +arg 11 := sig +arg 12 := addRnd +*/ + + +int main(int argc, char *argv[]) +{ + /* one of the args is arg 0 */ + if(argc != 13) + { + printf("Arg count is %d! \r\n", argc); + printf("Usage: ./ACVP_sig_test \r\n"); + return -1; + } + + /* process arg vars */ + uint8_t tgId = (uint8_t)atoi(argv[1]); + uint8_t tcId = (uint8_t)atoi(argv[2]); + selectPrmSet(argv[3],&prmSet_g); + deterministic_g = (!strcmp(argv[4],"True")) ? SLH_DETERMINISTIC : SLH_NON_DETERMINISTIC; + interface_e interface = (!strcmp(argv[5],"internal")) ? SLH_INTERNAL : SLH_EXTERNAL; + size_t m_sz = atoi(argv[6])/2; + uint8_t *m = (uint8_t *)malloc(m_sz * sizeof(uint8_t)); + if(m_sz != 0) + { + hexStringToByteArray(argv[7], m); + } + size_t ctxLen = atoi(argv[8])/2; + uint8_t *ctxStr = (uint8_t *)malloc(ctxLen * sizeof(uint8_t)); + + if(interface == SLH_EXTERNAL && ctxLen != 0) + { + hexStringToByteArray(argv[9], ctxStr); + } + uint8_t sk[MAX_SK_BYTES] = {0}; + hexStringToByteArray(argv[10], sk); + size_t sig_sz = strlen(argv[11]); + uint8_t *sig = (uint8_t *)malloc(sig_sz * sizeof(uint8_t)); + memset(sig,0,sig_sz); + uint8_t *sigExpected = (uint8_t *)malloc(sig_sz * sizeof(uint8_t)); + hexStringToByteArray(argv[11], sigExpected); + if(deterministic_g == SLH_NON_DETERMINISTIC){hexStringToByteArray(argv[12], addRnd_g);} + + if(interface == SLH_EXTERNAL) + { + slh_sign(sig,m,m_sz, ctxStr,ctxLen, sk,fixedSigRbg,prmSet_g); + } + else { + slh_sign_internal(sig,m,m_sz,sk,0,0,prmSet_g,addRnd_g); + } + + if(memcmp(sig,sigExpected,sig_sz)) + { + printf("Signature does not match expected value for (tgId,tcId): (%d,%d)! \r\n", tgId, tcId); + + free(m); + free(ctxStr); + free(sig); + free(sigExpected); + + return -1; + } + printf("Signature matches expected value for (tgId,tcId): (%d,%d)! \r\n", tgId, tcId); + + free(m); + free(ctxStr); + free(sig); + free(sigExpected); + + return 0; +} \ No newline at end of file diff --git a/slh/ACVP_test_functions.c b/slh/ACVP_test_functions.c new file mode 100644 index 00000000..2100ab7d --- /dev/null +++ b/slh/ACVP_test_functions.c @@ -0,0 +1,119 @@ +#include +#include +#include + +#include "ACVP_tests.h" + +// common functions ************************************************************************** +uint8_t hexCharToDecimal(char c) +{ + if (c >= '0' && c <= '9') { + return (uint8_t) (c - '0'); + } else if (c >= 'a' && c <= 'f') { + return (uint8_t) (c - 'a' + 10); + } else if (c >= 'A' && c <= 'F') { + return (uint8_t) (c - 'A' + 10); + } else { + fprintf(stderr, "Invalid hex character: %c\n", c); + return 0; + } +} + +void hexStringToByteArray(const char *hexString, uint8_t *byteArray) +{ + size_t len = strlen(hexString); + + if (len % 2 != 0) { + fprintf(stderr, "Hex string must have an even number of characters\n"); + exit(EXIT_FAILURE); + } + + for (size_t i = 0, j = 0; i < len; i += 2, j++) { + byteArray[j] = (uint8_t) ((hexCharToDecimal(hexString[i]) << 4) | hexCharToDecimal(hexString[i + 1])); + } +} + +char decimalToHexChar(uint8_t d) +{ + if (d <= 9) { + return '0' + d; + } else if (d <= 15) { + return 'A' + (d - 10); + } else { + fprintf(stderr, "Invalid value: %u\n", d); + return '?'; + } +} + +void byteArrayToHexString(const uint8_t *byteArray, size_t byteArrayLen, char *hexString) +{ + for (size_t i = 0; i < byteArrayLen; i++) { + uint8_t byte = byteArray[i]; + hexString[2 * i] = decimalToHexChar((byte >> 4) & 0x0F); + hexString[2 * i + 1] = decimalToHexChar(byte & 0x0F); + } + hexString[2 * byteArrayLen] = '\0'; +} + +void selectPrmSet(char *prmSetString, const slh_param_t **prmSet) +{ + // Compare prmSetString with the names of the parameter sets + if (strcmp(prmSetString, "SLH-DSA-SHA2-128s") == 0) { + *prmSet = &slh_dsa_sha2_128s; + } + else if (strcmp(prmSetString, "SLH-DSA-SHAKE-128s") == 0) { + *prmSet = &slh_dsa_shake_128s; + } + else if (strcmp(prmSetString, "SLH-DSA-SHA2-128f") == 0) { + *prmSet = &slh_dsa_sha2_128f; + } + else if (strcmp(prmSetString, "SLH-DSA-SHAKE-128f") == 0) { + *prmSet = &slh_dsa_shake_128f; + } + else if (strcmp(prmSetString, "SLH-DSA-SHA2-192s") == 0) { + *prmSet = &slh_dsa_sha2_192s; + } + else if (strcmp(prmSetString, "SLH-DSA-SHAKE-192s") == 0) { + *prmSet = &slh_dsa_shake_192s; + } + else if (strcmp(prmSetString, "SLH-DSA-SHA2-192f") == 0) { + *prmSet = &slh_dsa_sha2_192f; + } + else if (strcmp(prmSetString, "SLH-DSA-SHAKE-192f") == 0) { + *prmSet = &slh_dsa_shake_192f; + } + else if (strcmp(prmSetString, "SLH-DSA-SHA2-256s") == 0) { + *prmSet = &slh_dsa_sha2_256s; + } + else if (strcmp(prmSetString, "SLH-DSA-SHAKE-256s") == 0) { + *prmSet = &slh_dsa_shake_256s; + } + else if (strcmp(prmSetString, "SLH-DSA-SHA2-256f") == 0) { + *prmSet = &slh_dsa_sha2_256f; + } + else if (strcmp(prmSetString, "SLH-DSA-SHAKE-256f") == 0) { + *prmSet = &slh_dsa_shake_256f; + } + else { + // Handle unknown prmSetString (if needed) + *prmSet = NULL; // or provide a default set + } +} + +int fixedKeygenRbg(uint8_t *x, size_t xlen) +{ + size_t n = prmSet_g->n; + + hexStringToByteArray(seed_g.skSeedString,x); + hexStringToByteArray(seed_g.skPrfString,x+n); + hexStringToByteArray(seed_g.pkSeedString,x+2*n); + + return 0; +} + +int fixedSigRbg(uint8_t *x, size_t xlen) +{ + memcpy(x,addRnd_g,xlen); + + return 0; +} \ No newline at end of file diff --git a/slh/ACVP_tests.h b/slh/ACVP_tests.h new file mode 100644 index 00000000..492cda7f --- /dev/null +++ b/slh/ACVP_tests.h @@ -0,0 +1,28 @@ +#include +#include +#include + +#include "slh_param.h" +#include "slh_dsa.h" + +// defines *********************************************************************************** +#define ACVP_TEST + +#define MAX_N 32 +#define MAX_PK_BYTES 2 * MAX_N +#define MAX_SK_BYTES 4 * MAX_N +#define MAX_SIG_BYTES 1000 *4 8.4 + +// types ************************************************************************************** +typedef enum interface_e +{ + SLH_INTERNAL, + SLH_EXTERNAL +} interface_e; + +// functions *********************************************************************************** +void hexStringToByteArray(const char *hexString, uint8_t *byteArray); +void byteArrayToHexString(const uint8_t *byteArray, size_t byteArrayLen, char *hexString); +void selectPrmSet(char *prmSetString, const slh_param_t **prmSet); +int fixedKeygenRbg(uint8_t *x, size_t xlen); +int fixedSigRbg(uint8_t *x, size_t xlen); diff --git a/slh/ACVP_tests.py b/slh/ACVP_tests.py new file mode 100644 index 00000000..06a6c6e9 --- /dev/null +++ b/slh/ACVP_tests.py @@ -0,0 +1,154 @@ +import subprocess +import json + +''' +Keygen tests format + +arg 1 := tgId +arg 2 := tcId +arg 3 := prmSet +arg 4 := skSeed +arg 5 := skPrf +arg 6 := pkSeed +arg 7 := sk +arg 8 := pk +''' + +# Creates command to call C program to run a keygen testcase +def makeKeygenTestCmd(tgId,tcId,prmSet,skSeed,skPrf,pkSeed,sk,pk): + command = "./ACVP_keygen_test " + command += (str(tgId) + " ") + command += (str(tcId) + " ") + command += (str(prmSet) + " ") + command += (str(skSeed) + " ") + command += (str(skPrf) + " ") + command += (str(pkSeed) + " ") + command += (str(sk) + " ") + command += str(pk) + return command + +# Creates command to call C program to run a keygen testcase +def makeSigTestCmd(tgId,tcId,prmSet,deterministic,interface,m_sz,m,ctxLen,ctxStr,sk,sig,addRnd): + command = "./ACVP_sig_test " + command += (str(tgId) + " ") + command += (str(tcId) + " ") + command += (str(prmSet) + " ") + command += (str(deterministic) + " ") + command += (str(interface) + " ") + command += (str(m_sz) + " ") + command += (str(m) + " ") + command += (str(ctxLen) + " ") + command += (str(ctxStr) + " ") + command += (str(sk) + " ") + command += (str(sig) + " ") + command += str(addRnd) + return command + +def makeVerTestCmd(tgId,tcId,prmSet,interface,m_sz,m,ctxLen,ctxStr,pk,sig,testPassed): + command = "./ACVP_ver_test " + command += (str(tgId) + " ") + command += (str(tcId) + " ") + command += (str(prmSet) + " ") + command += (str(interface) + " ") + command += (str(m_sz) + " ") + command += (str(m) + " ") + command += (str(ctxLen) + " ") + command += (str(ctxStr) + " ") + command += (str(pk) + " ") + command += (str(sig) + " ") + command += (str(testPassed)) + return command + +# Function performs all keygen tests for a given parameter set +def keygenTestPrmSet(prmSet): + with open("../ACVP-Server-1.1.0.39/gen-val/json-files/SLH-DSA-keyGen-FIPS205/internalProjection.json", 'r') as fp: + slh_dsa_kg_acvp = json.load(fp) + + for variant in slh_dsa_kg_acvp["testGroups"]: + if variant["parameterSet"] == prmSet: + for testCase in variant["tests"]: + tgId = str(variant["tgId"]) + tcId = str(testCase["tcId"]) + prmSet = prmSet + skSeed = str(testCase["skSeed"]) + skPrf = str(testCase["skPrf"]) + pkSeed = str(testCase["pkSeed"]) + sk = str(testCase["sk"]) + pk = str(testCase["pk"]) + + subprocess.run(makeKeygenTestCmd(tgId,tcId,prmSet,skSeed,skPrf,pkSeed,sk,pk), shell=True) + +# Function performs all signing tests for a given parameter set +def sigTestPrmSet(prmSet): + with open("../ACVP-Server-1.1.0.39/gen-val/json-files/SLH-DSA-sigGen-FIPS205/internalProjection.json", 'r') as fp: + slh_dsa_sig_acvp = json.load(fp) + + for variant in slh_dsa_sig_acvp["testGroups"]: + if variant["parameterSet"] == prmSet and variant["preHash"] != "preHash": + for testCase in variant["tests"]: + tgId = str(variant["tgId"]) + tcId = str(testCase["tcId"]) + prmSet = prmSet + deterministic = str(variant["deterministic"]) + interface = str(variant["signatureInterface"]) + m = str(testCase["message"]) + m_sz = len(m) + if m_sz == 0: + m = "-" + ctxStr = str(testCase["context"]) if interface == "external" else "-" + ctxLen = len(ctxStr) + if ctxLen == 0: + ctxStr = "-" + sk = str(testCase["sk"]) + sig = str(testCase["signature"]) + addRnd = str(testCase["additionalRandomness"]) if deterministic == "False" else "-" + + subprocess.run(makeSigTestCmd(tgId,tcId,prmSet,deterministic,interface,m_sz,m,ctxLen,ctxStr,sk,sig,addRnd), shell=True) + +# Function performs all verification tests for a given parameter set +def verTestPrmSet(prmSet): + with open("../ACVP-Server-1.1.0.39/gen-val/json-files/SLH-DSA-sigVer-FIPS205/internalProjection.json", 'r') as fp: + slh_dsa_ver_acvp = json.load(fp) + + for variant in slh_dsa_ver_acvp["testGroups"]: + if variant["parameterSet"] == prmSet and variant["preHash"] != "preHash": + for testCase in variant["tests"]: + tgId = str(variant["tgId"]) + tcId = str(testCase["tcId"]) + prmSet = prmSet + interface = str(variant["signatureInterface"]) + m = str(testCase["message"]) + m_sz = len(m) + if m_sz == 0: + m = "-" + ctxStr = str(testCase["context"]) if interface == "external" else "-" + ctxLen = len(ctxStr) + if ctxLen == 0: + ctxStr = "-" + pk = str(testCase["pk"]) + sig = str(testCase["signature"]) + testPassed = testCase["testPassed"] + + subprocess.run(makeVerTestCmd(tgId,tcId,prmSet,interface,m_sz,m,ctxLen,ctxStr,pk,sig,testPassed), shell=True) + + +# Variable Declaration +prmSets = ["SLH-DSA-SHA2-128s", "SLH-DSA-SHAKE-128s", + "SLH-DSA-SHA2-128f", "SLH-DSA-SHAKE-128f", + "SLH-DSA-SHA2-192s", "SLH-DSA-SHAKE-192s", + "SLH-DSA-SHA2-192f", "SLH-DSA-SHAKE-192f", + "SLH-DSA-SHA2-256s", "SLH-DSA-SHAKE-256s", + "SLH-DSA-SHA2-256f", "SLH-DSA-SHAKE-256f"] + +# Begin main program +# Keygen tests +for prmSet in prmSets: + keygenTestPrmSet(prmSet) + +# Signing tests +for prmSet in prmSets: + sigTestPrmSet(prmSet) + +# Verification tests +for prmSet in prmSets: + verTestPrmSet(prmSet) \ No newline at end of file diff --git a/slh/ACVP_ver_test b/slh/ACVP_ver_test new file mode 100755 index 00000000..7b92bf79 Binary files /dev/null and b/slh/ACVP_ver_test differ diff --git a/slh/ACVP_ver_test.c b/slh/ACVP_ver_test.c new file mode 100644 index 00000000..32d279dd --- /dev/null +++ b/slh/ACVP_ver_test.c @@ -0,0 +1,84 @@ + +#include +#include +#include + +#include "ACVP_tests.h" +#include "slh_dsa.h" + +/* +arg 1 := tgId +arg 2 := tcId +arg 3 := prmSet +arg 4 := interface +arg 5 := m_sz +arg 6 := m +arg 7 := ctxlen +arg 8 := ctxStr +arg 9 := pk +arg 10 := sig +arg 11 := validSig +*/ + + +int main(int argc, char *argv[]) +{ + /* one of the args is arg 0 */ + if(argc != 12) + { + printf("Arg count is %d! \r\n", argc); + printf("Usage: ./ACVP_ver_test \r\n"); + return -1; + } + + /* process arg vars */ + uint8_t tgId = (uint8_t)atoi(argv[1]); + uint8_t tcId = (uint8_t)atoi(argv[2]); + selectPrmSet(argv[3],&prmSet_g); + interface_e interface = (!strcmp(argv[5],"internal")) ? SLH_INTERNAL : SLH_EXTERNAL; + size_t m_sz = atoi(argv[5])/2; + uint8_t *m = (uint8_t *)malloc(m_sz * sizeof(uint8_t)); + if(m_sz != 0) + { + hexStringToByteArray(argv[6], m); + } + size_t ctxLen = atoi(argv[7])/2; + uint8_t *ctxStr = (uint8_t *)malloc(ctxLen * sizeof(uint8_t)); + if(interface == SLH_EXTERNAL && ctxLen != 0) + { + hexStringToByteArray(argv[8], ctxStr); + } + uint8_t pk[MAX_PK_BYTES] = {0}; + hexStringToByteArray(argv[9], pk); + size_t sig_sz = strlen(argv[10]); + uint8_t *sig = (uint8_t *)malloc(sig_sz * sizeof(uint8_t)); + memset(sig,0,sig_sz); + bool validSig = (!strcmp(argv[11],"true")); + bool result; + + if(interface == SLH_EXTERNAL) + { + result = slh_verify(m,m_sz,sig,ctxStr,ctxLen,pk,prmSet_g); + } + else { + result = slh_verify_internal(ctxStr,0,m,m_sz,sig,pk,prmSet_g); + } + + if(result != validSig) + { + printf("Verification result does not match expectation for (tgId,tcId): (%d,%d)! \r\n", tgId, tcId); + + free(m); + free(ctxStr); + free(sig); + + return -1; + } + printf("Verification result matches expectation for (tgId,tcId): (%d,%d)! \r\n", tgId, tcId); + + free(m); + free(ctxStr); + free(sig); + + return 0; +} \ No newline at end of file diff --git a/slh/Makefile b/slh/Makefile index 03da1e29..4f15a21f 100644 --- a/slh/Makefile +++ b/slh/Makefile @@ -5,9 +5,25 @@ # No hardware acceleration if used like this -- mainly useful for # debug purposes. -XBIN ?= kat_test +TARGET ?= kat_test CSRC = $(wildcard *.c) +DEBUG ?= 0 + +EXCLUDE_LIST = kat_test.c ACVP_keygen_test.c ACVP_sig_test.c ACVP_ver_test.c + +EXCLUDE_LIST := $(filter-out $(TARGET).c,$(EXCLUDE_LIST)) +CSRC := $(filter-out $(EXCLUDE_LIST),$(CSRC)) + +# ifeq ($(TARGET),kat_test) +# CSRC += ../drv/kat_drbg.c +# endif CSRC += ../drv/kat_drbg.c + +ifeq ($(DEBUG),1) + CFLAGS += -g -O0 -DDEBUG + BUILD_MODE := debug +endif + OBJS = $(CSRC:.c=.o) KATNUM ?= 1 @@ -17,17 +33,16 @@ CFLAGS += -Wall -march=native -Ofast -DNDEBUG CFLAGS += -I. -I../drv -DKATNUM=$(KATNUM) LDLIBS += -$(XBIN): $(OBJS) - $(CC) $(LDFLAGS) $(CFLAGS) -o $(XBIN) $(OBJS) $(LDLIBS) +$(TARGET): $(OBJS) + $(CC) $(LDFLAGS) $(CFLAGS) -o $(TARGET) $(OBJS) $(LDLIBS) %.o: %.[cS] $(CC) $(CFLAGS) -c $^ -o $@ -test: $(XBIN) - ./$(XBIN) +test: $(TARGET) + ./$(TARGET) sha256sum *.rsp > t.log cat t.log ../kat/kat$(KATNUM)-sha256.txt | sort | uniq -c -w 64 @echo "NOTE >2<- indicates test vector comparison match, 1 is a failure." clean: - $(RM) -rf $(XBIN) $(OBJS) *.rsp *.req *.log - + $(RM) -rf $(TARGET) $(OBJS) *.rsp *.req *.log \ No newline at end of file diff --git a/slh/kat_test.c b/slh/kat_test.c index 69af7975..6baf04e6 100644 --- a/slh/kat_test.c +++ b/slh/kat_test.c @@ -28,14 +28,14 @@ int iut_randombytes(uint8_t *x, size_t xlen) } static void kat_hex(FILE *fh, const char *label, - const uint8_t *x, size_t xlen) + const uint8_t *x, size_t xlen) { - size_t i; - fprintf(fh, "%s = ", label); - for (i = 0; i < xlen; i++) { - fprintf(fh, "%02X", x[i]); - } - fprintf(fh, "\n"); +size_t i; +fprintf(fh, "%s = ", label); +for (i = 0; i < xlen; i++) { +fprintf(fh, "%02X", x[i]); +} +fprintf(fh, "\n"); } int kat_test(const slh_param_t *iut, int katnum) @@ -99,7 +99,10 @@ int kat_test(const slh_param_t *iut, int katnum) kat_hex(fh, "pk", pk, pk_sz); kat_hex(fh, "sk", sk, sk_sz); - sm_sz = slh_sign(sm, msg, msg_sz, sk, &iut_randombytes, iut); + uint8_t ctx_str[4]; + size_t ctx_str_len = 4; + + sm_sz = slh_sign(sm, msg, msg_sz, ctx_str,ctx_str_len, sk, &iut_randombytes, iut); memcpy(sm + sm_sz, msg, msg_sz); sm_sz += msg_sz; @@ -109,7 +112,7 @@ int kat_test(const slh_param_t *iut, int katnum) fprintf(fh, "\n"); assert(sm_sz == sig_sz + msg_sz); - if (!slh_verify(sm + sig_sz, msg_sz, sm, pk, iut)) { + if (!slh_verify(sm + sig_sz, msg_sz, sm,ctx_str, ctx_str_len, pk, iut)) { fail++; fprintf(stderr, "[FAIL] slh_verify() fails.\n"); } @@ -121,7 +124,7 @@ int kat_test(const slh_param_t *iut, int katnum) (((uint32_t) seed[7]) << 24); xbit %= (8 * sm_sz); sm[xbit >> 3] ^= 1 << (xbit & 7); - if (slh_verify(sm + sig_sz, msg_sz, sm, pk, iut)) { + if (slh_verify(sm + sig_sz, msg_sz, sm, pk, iut,ctx_str,ctx_str_len)) { fail++; fprintf(stderr, "[FAIL] slh_verify() forgery bit= %u.\n", xbit); } @@ -159,9 +162,10 @@ int main(int argc, char **argv) iut_n < 12) { fail += kat_test(test_iut[iut_n], 1); } else { - for (iut_n = 0; test_iut[iut_n] != NULL; iut_n++) { - fail += kat_test(test_iut[iut_n], KATNUM); - } + // for (iut_n = 0; test_iut[iut_n] != NULL; iut_n++) { + // fail += kat_test(test_iut[iut_n], KATNUM); + // } + kat_test(&slh_dsa_shake_128s, KATNUM); } printf("[INFO] test_slh_dsa() fail= %d\n", fail); diff --git a/slh/slh_ctx.h b/slh/slh_ctx.h index de9b29f4..b0e49053 100644 --- a/slh/slh_ctx.h +++ b/slh/slh_ctx.h @@ -10,12 +10,13 @@ #include "sha2_api.h" // some structural sizes -#define SLH_MAX_N 32 -#define SLH_MAX_LEN (2 * SLH_MAX_N + 3) -#define SLH_MAX_K 35 -#define SLH_MAX_M 49 -#define SLH_MAX_HP 9 -#define SLH_MAX_A 14 +#define SLH_MAX_N 32 +#define SLH_MAX_LEN (2 * SLH_MAX_N + 3) +#define SLH_MAX_K 35 +#define SLH_MAX_M 49 +#define SLH_MAX_HP 9 +#define SLH_MAX_A 14 +#define SLH_MAX_CTX_STR_LEN 255 // context struct slh_ctx_s { diff --git a/slh/slh_dsa.c b/slh/slh_dsa.c index 2997a22f..c0cec53c 100644 --- a/slh/slh_dsa.c +++ b/slh/slh_dsa.c @@ -7,6 +7,7 @@ #include "slh_ctx.h" #include "slh_adrs.h" #include +#include // === Internal @@ -39,7 +40,7 @@ size_t slh_sig_sz(const slh_param_t *prm) } // === Compute the base 2**b representation of X. -// Algorithm 3: base_2b(X, b, out_len) +// Algorithm 4: base_2b(X, b, out_len) static inline size_t base_2b( uint32_t *v, const uint8_t *x, uint32_t b, size_t v_len) @@ -89,15 +90,15 @@ static inline size_t base_16( uint32_t *v, const uint8_t *x, int v_len) } // === Chaining function used in WOTS+ -// Algorithm 4: chain(X, i, s, PK.seed, ADRS) +// Algorithm 5: chain(X, i, s, PK.seed, ADRS) // (see prm->chain) // === Generate a WOTS+ public key. -// Algorithm 5: wots_PKgen(SK.seed, PK.seed, ADRS) +// Algorithm 6: wots_PKgen(SK.seed, PK.seed, ADRS) // (see xmms_node) // === Generate a WOTS+ signature on an n-byte message. -// Algorithm 6: wots_sign(M, SK.seed, PK.seed, ADRS) +// Algorithm 7: wots_sign(M, SK.seed, PK.seed, ADRS) // (wots_csum is a shared helper function for algorithms 6 and 7) static void wots_csum(uint32_t *vm, const uint8_t *m, const slh_param_t *prm) @@ -146,7 +147,7 @@ static size_t wots_sign(slh_ctx_t *ctx, uint8_t *sig, const uint8_t *m) } // === Compute a WOTS+ public key from a message and its signature. -// Algorithm 7: wots_PKFromSig(sig, M, PK.seed, ADRS) +// Algorithm 8: wots_PKFromSig(sig, M, PK.seed, ADRS) static void wots_pk_from_sig( slh_ctx_t *ctx, uint8_t *pk, const uint8_t *sig, @@ -175,7 +176,7 @@ static void wots_pk_from_sig( slh_ctx_t *ctx, uint8_t *pk, } // === Compute the root of a Merkle subtree of WOTS+ public keys. -// Algorithm 8: xmss_node(SK.seed, i, z, PK.seed, ADRS) +// Algorithm 9: xmss_node(SK.seed, i, z, PK.seed, ADRS) static void xmss_node( slh_ctx_t *ctx, uint8_t *node, uint32_t i, uint32_t z) @@ -196,7 +197,7 @@ static void xmss_node( slh_ctx_t *ctx, uint8_t *node, adrs_set_key_pair_address(ctx, i); // === Generate a WOTS+ public key. - // Algorithm 5: wots_PKgen(SK.seed, PK.seed, ADRS) + // Algorithm 6: wots_PKgen(SK.seed, PK.seed, ADRS) sk = tmp; for (k = 0; k < len; k++) { adrs_set_chain_address(ctx, k); @@ -222,7 +223,7 @@ static void xmss_node( slh_ctx_t *ctx, uint8_t *node, } // === Generate an XMSS signature. -// Algorithm 9: xmss_sign(M, SK.seed, idx, PK.seed, ADRS) +// Algorithm 10: xmss_sign(M, SK.seed, idx, PK.seed, ADRS) static size_t xmss_sign(slh_ctx_t *ctx, uint8_t *sx, const uint8_t *m, uint32_t idx) @@ -252,7 +253,7 @@ static size_t xmss_sign(slh_ctx_t *ctx, uint8_t *sx, const uint8_t *m, } // === Compute an XMSS public key from an XMSS signature. -// Algorithm 10: xmss_PKFromSig(idx, SIGXMSS, M, PK.seed, ADRS) +// Algorithm 11: xmss_PKFromSig(idx, SIGXMSS, M, PK.seed, ADRS) static void xmss_pk_from_sig( slh_ctx_t *ctx, uint8_t *root, uint32_t idx, const uint8_t *sig, const uint8_t *m) @@ -287,7 +288,7 @@ static void xmss_pk_from_sig( slh_ctx_t *ctx, uint8_t *root, uint32_t idx, // === Generate a hypertree signature. -// Algorithm 11: ht_sign(M, SK.seed, PK.seed, idx_tree, idx_leaf ) +// Algorithm 12: ht_sign(M, SK.seed, PK.seed, idx_tree, idx_leaf ) static size_t ht_sign( slh_ctx_t *ctx, uint8_t *sh, uint8_t *m, uint64_t i_tree, uint32_t i_leaf) @@ -317,7 +318,7 @@ static size_t ht_sign( slh_ctx_t *ctx, uint8_t *sh, uint8_t *m, // === Verify a hypertree signature. -// Algorithm 12: ht_verify(M, SIG_HT, PK.seed, idx_tree, idx_leaf, PK.root) +// Algorithm 13: ht_verify(M, SIG_HT, PK.seed, idx_tree, idx_leaf, PK.root) static bool ht_verify( slh_ctx_t *ctx, const uint8_t *m, const uint8_t *sig_ht, @@ -352,12 +353,12 @@ static bool ht_verify( slh_ctx_t *ctx, const uint8_t *m, } // === Generate a FORS private-key value. -// Algorithm 13: fors_SKgen(SK.seed, PK.seed, ADRS, idx) +// Algorithm 14: fors_SKgen(SK.seed, PK.seed, ADRS, idx) // ( see prm->fors_hash() ) // === Compute the root of a Merkle subtree of FORS public values. -// Algorithm 14: fors_node(SK.seed, i, z, PK.seed, ADRS) +// Algorithm 15: fors_node(SK.seed, i, z, PK.seed, ADRS) static void fors_node( slh_ctx_t *ctx, uint8_t *node, uint32_t i, uint32_t z) @@ -392,7 +393,7 @@ static void fors_node( slh_ctx_t *ctx, uint8_t *node, // === Generate a FORS signature. -// Algorithm 15: fors_sign(md, SK.seed, PK.seed, ADRS) +// Algorithm 16: fors_sign(md, SK.seed, PK.seed, ADRS) static size_t fors_sign(slh_ctx_t *ctx, uint8_t *sf, const uint8_t *md) { @@ -421,7 +422,7 @@ static size_t fors_sign(slh_ctx_t *ctx, uint8_t *sf, const uint8_t *md) } // === Compute a FORS public key from a FORS signature. -// Algorithm 16: fors_pkFromSig(SIGFORS , md, PK.seed, ADRS) +// Algorithm 17: fors_pkFromSig(SIGFORS , md, PK.seed, ADRS) static void fors_pk_from_sig( slh_ctx_t *ctx, uint8_t *pk, const uint8_t *sf, const uint8_t *md) @@ -486,15 +487,30 @@ size_t slh_sk_sz(const slh_param_t *prm) return 4 * prm->n; } -// === Generate an SLH-DSA key pair. -// Algorithm 17: slh_keygen() +// === Deterministic portion of SLH-DSA key pair generation. +// Algorithm 18: slh_keygen_internal() -int slh_keygen(uint8_t *pk, uint8_t *sk, - int (*rbg)(uint8_t *x, size_t xlen), const slh_param_t *prm) +int slh_keygen_internal(uint8_t *pk, uint8_t *sk, const slh_param_t *prm, slh_ctx_t *ctx) { + uint8_t pk_root[SLH_MAX_N]; + size_t n = prm->n; + + adrs_zero(ctx); + adrs_set_layer_address(ctx, prm->d - 1); + xmss_node(ctx, pk_root, 0, prm->hp); + // fill pk_root + memcpy(sk + 3 * n, pk_root, n); + memcpy(pk + n, pk_root, n); + return 0; +} + +// === Random portion of SLH-DSA key pair generation. +// Algorithm 21: slh_keygen() +int slh_keygen(uint8_t *pk, uint8_t *sk, int (*rbg)(uint8_t *x, size_t xlen), + const slh_param_t *prm) +{ slh_ctx_t ctx; - uint8_t pk_root[SLH_MAX_N]; size_t n = prm->n; rbg(sk, 3 * n); // SK.seed || SK.prf || PK.seed @@ -502,18 +518,11 @@ int slh_keygen(uint8_t *pk, uint8_t *sk, memset(sk + 3 * n, 0x00, n); // PK.root not generated yet prm->mk_ctx(&ctx, NULL, sk, prm); // fill in partial - adrs_zero(&ctx); - adrs_set_layer_address(&ctx, prm->d - 1); - xmss_node(&ctx, pk_root, 0, prm->hp); - - // fill pk_root - memcpy(sk + 3 * n, pk_root, n); - memcpy(pk + n, pk_root, n); - return 0; + return slh_keygen_internal(pk, sk, prm, &ctx); } // === Generate an SLH-DSA signature. -// Algorithm 18: slh_sign(M, SK) +// Algorithm 22: slh_sign(M, SK) // (Shared helper function for algorithms 18 and 19.) @@ -563,28 +572,32 @@ size_t slh_do_sign( slh_ctx_t *ctx, uint8_t *sig, const uint8_t *digest) return sig_sz; } -size_t slh_sign(uint8_t *sig, const uint8_t *m, size_t m_sz, - const uint8_t *sk, int (*rbg)(uint8_t *x, size_t xlen), - const slh_param_t *prm) +// Deterministic portion of signing +// Algorithm 19: slh_sign_internal(M, SK, addrnd) +size_t slh_sign_internal(uint8_t *sig, + const uint8_t *m, size_t m_sz, const uint8_t *sk, + const uint8_t *pre, size_t pre_sz, + const slh_param_t *prm, uint8_t *add_rnd) { - slh_ctx_t ctx; uint8_t opt_rand[SLH_MAX_N]; uint8_t digest[SLH_MAX_M]; + slh_ctx_t ctx; // set up secret key etc prm->mk_ctx(&ctx, NULL, sk, prm); -#ifdef SLH_DETERMINISTIC - memcpy(opt_rand, ctx.pk_seed, prm->n); -#else - rbg(opt_rand, prm->n); -#endif + #ifdef SLH_DETERMINISTIC + memcpy(opt_rand, ctx.pk_seed, prm->n); + #endif + #ifndef SLH_DETERMINISTIC + memcpy(opt_rand, add_rnd, prm->n); + #endif // randomized hashing; R uint8_t *r = sig; size_t sig_sz = prm->n; - prm->prf_msg(&ctx, r, opt_rand, m, m_sz); - prm->h_msg(&ctx, digest, r, m, m_sz); + prm->prf_msg(&ctx, r, opt_rand, pre, pre_sz, m, m_sz); + prm->h_msg(&ctx, digest, r, pre, pre_sz, m, m_sz); // create FORS and HT signature parts sig_sz += slh_do_sign(&ctx, sig + sig_sz, digest); @@ -592,10 +605,46 @@ size_t slh_sign(uint8_t *sig, const uint8_t *m, size_t m_sz, return sig_sz; } +// Pure signing wrapper function +size_t slh_sign(uint8_t *sig, const uint8_t *m, size_t m_sz, + uint8_t *ctx_str, size_t ctx_str_len, const uint8_t *sk, + int (*rbg)(uint8_t *x, size_t xlen), const slh_param_t *prm) +{ + if (ctx_str_len > SLH_MAX_CTX_STR_LEN) + { + return 0; + } + + uint8_t add_rnd[SLH_MAX_N]; + + #ifdef SLH_DETERMINISTIC + // add_rnd not needed here so is uninitialized + #endif + #ifndef SLH_DETERMINISTIC + if (rbg(add_rnd, prm->n) != 0) + { + return 0; + } + #endif + + uint8_t pre[MAX_PRE_SIZE]; + size_t pre_sz = 2 + ctx_str_len; + + pre[0] = 0; + pre[1] = ctx_str_len; + for (size_t i = 0; i < ctx_str_len; i++) + { + pre[2 + i] = ctx_str[i]; + } + + return slh_sign_internal(sig,m,m_sz,sk,pre,pre_sz,prm,add_rnd); +} + // === Verify an SLH-DSA signature. -// Algorithm 19: slh_verify(M, SIG, PK) +// Algorithm 20: slh_verify_internal(M, SIG, PK) -bool slh_verify(const uint8_t *m, size_t m_sz, +bool slh_verify_internal(const uint8_t *pre, size_t pre_sz, + const uint8_t *m, size_t m_sz, const uint8_t *sig, const uint8_t *pk, const slh_param_t *prm) { @@ -609,7 +658,7 @@ bool slh_verify(const uint8_t *m, size_t m_sz, const uint8_t *sig_ht = sig + ((1 + prm->k*(1 + prm->a)) * prm->n); prm->mk_ctx(&ctx, pk, NULL, prm); - prm->h_msg(&ctx, digest, r, m, m_sz); + prm->h_msg(&ctx, digest, r, pre, pre_sz, m, m_sz); const uint8_t *md = digest; uint64_t i_tree = 0; @@ -627,3 +676,26 @@ bool slh_verify(const uint8_t *m, size_t m_sz, return sig_ok; } +// Algorithm 24: slh_verify(M, SIG, PK) +// Pure signature verification +bool slh_verify(const uint8_t *m, size_t m_sz, + const uint8_t *sig, uint8_t *ctx_str, size_t ctx_str_len, + const uint8_t *pk, const slh_param_t *prm) +{ + if (ctx_str_len > SLH_MAX_CTX_STR_LEN) + { + return false; + } + + uint8_t pre[MAX_PRE_SIZE]; + size_t pre_sz = 2 + ctx_str_len; + + pre[0] = 0; + pre[1] = ctx_str_len; + for (size_t i = 0; i < ctx_str_len; i++) + { + pre[2 + i] = ctx_str[i]; + } + + return slh_verify_internal(pre,pre_sz,m,m_sz,sig,pk,prm); +} diff --git a/slh/slh_dsa.h b/slh/slh_dsa.h index 39a5f698..75e0e782 100644 --- a/slh/slh_dsa.h +++ b/slh/slh_dsa.h @@ -14,8 +14,28 @@ extern "C" { #include #include +#define MAX_PRE_SIZE 257 +#define SLH_DETERMINISTIC + typedef struct slh_param_s slh_param_t; +// struct +// { +// char *skSeedString; +// char *skPrfString; +// char *pkSeedString; +// } seed_g; + +// const slh_param_t *prmSet_g; + +// enum +// { +// SLH_DETERMINISTIC, +// SLH_NON_DETERMINISTIC +// } deterministic_g; + +// uint8_t addRnd_g[32]; + // === SLH-DSA parameter sets extern const slh_param_t slh_dsa_sha2_128s; extern const slh_param_t slh_dsa_shake_128s; @@ -49,17 +69,26 @@ int slh_keygen( uint8_t *pk, uint8_t *sk, int (*rbg)(uint8_t *x, size_t xlen), const slh_param_t *prm); +// Internal interface to slh_sign +size_t slh_sign_internal(uint8_t *sig, + const uint8_t *m, size_t m_sz, const uint8_t *sk, + const uint8_t *pre, size_t pre_sz, + const slh_param_t *prm, uint8_t *add_rnd); + // Generate a SLH-DSA signature. -size_t slh_sign(uint8_t *sig, - const uint8_t *m, size_t m_sz, - const uint8_t *sk, - int (*rbg)(uint8_t *x, size_t xlen), - const slh_param_t *prm); +size_t slh_sign(uint8_t *sig, const uint8_t *m, size_t m_sz, + uint8_t *ctx_str, size_t ctx_str_len, const uint8_t *sk, + int (*rbg)(uint8_t *x, size_t xlen), const slh_param_t *prm); // Verify an SLH-DSA signature. bool slh_verify(const uint8_t *m, size_t m_sz, - const uint8_t *sig, const uint8_t *pk, - const slh_param_t *prm); + const uint8_t *sig, uint8_t *ctx_str, size_t ctx_str_len, + const uint8_t *pk, const slh_param_t *prm); + +bool slh_verify_internal(const uint8_t *pre, size_t pre_sz, + const uint8_t *m, size_t m_sz, + const uint8_t *sig, const uint8_t *pk, + const slh_param_t *prm); #ifdef __cplusplus } diff --git a/slh/slh_param.h b/slh/slh_param.h index 1df47076..46841c35 100644 --- a/slh/slh_param.h +++ b/slh/slh_param.h @@ -40,9 +40,11 @@ struct slh_param_s { void (*wots_chain)(slh_ctx_t *ctx, uint8_t *tmp, uint32_t s); void (*fors_hash)(slh_ctx_t *ctx, uint8_t *tmp, uint32_t s); void (*h_msg)(slh_ctx_t *ctx, uint8_t *h, const uint8_t *r, + const uint8_t *pre, size_t pre_sz, const uint8_t *m, size_t m_sz); void (*prf)(slh_ctx_t *ctx, uint8_t *h); void (*prf_msg)(slh_ctx_t *ctx, uint8_t *h, const uint8_t *opt_rand, + const uint8_t *pre, size_t pre_sz, const uint8_t *m, size_t m_sz); void (*h_f)(slh_ctx_t *ctx, uint8_t *h, const uint8_t *m1); void (*h_h)(slh_ctx_t *ctx, uint8_t *h, diff --git a/slh/slh_sha2.c b/slh/slh_sha2.c index afee1006..7b45bf9c 100644 --- a/slh/slh_sha2.c +++ b/slh/slh_sha2.c @@ -15,7 +15,8 @@ // MGF1-SHA-256(R || PK.seed || SHA-256(R ||PK.seed || PK.root || M), m) static void sha2_256_h_msg( slh_ctx_t *ctx, uint8_t *h, - const uint8_t *r, const uint8_t *m, size_t m_sz) + const uint8_t *r, const uint8_t *pre, size_t pre_sz, + const uint8_t *m, size_t m_sz) { sha256_t sha2; uint8_t mgf[16 + 16 + 32 + 4]; @@ -30,6 +31,7 @@ static void sha2_256_h_msg( slh_ctx_t *ctx, uint8_t *h, sha256_update(&sha2, r, n); sha256_update(&sha2, ctx->pk_seed, n); sha256_update(&sha2, ctx->pk_root, n); + sha256_update(&sha2, pre, pre_sz); sha256_update(&sha2, m, m_sz); sha256_final(&sha2, mgf + 2 * n); @@ -60,7 +62,8 @@ static void sha2_256_h_msg( slh_ctx_t *ctx, uint8_t *h, // MGF1-SHA-512(R || PK.seed || SHA-512(R || PK.seed || PK.root || M), m) static void sha2_512_h_msg( slh_ctx_t *ctx, uint8_t *h, - const uint8_t *r, const uint8_t *m, size_t m_sz) + const uint8_t *r, const uint8_t *pre, size_t pre_sz, + const uint8_t *m, size_t m_sz) { sha512_t sha2; uint8_t mgf[32 + 32 + 64 + 4]; @@ -75,6 +78,7 @@ static void sha2_512_h_msg( slh_ctx_t *ctx, uint8_t *h, sha512_update(&sha2, r, n); sha512_update(&sha2, ctx->pk_seed, n); sha512_update(&sha2, ctx->pk_root, n); + sha512_update(&sha2, pre, pre_sz); sha512_update(&sha2, m, m_sz); sha512_final(&sha2, mgf + 2 * n); @@ -139,6 +143,7 @@ static void sha256_prf( slh_ctx_t *ctx, static void sha256_prf_msg( slh_ctx_t *ctx, uint8_t *h, const uint8_t *opt_rand, + const uint8_t *pre, size_t pre_sz, const uint8_t *m, size_t m_sz) { unsigned i; @@ -156,6 +161,7 @@ static void sha256_prf_msg( slh_ctx_t *ctx, sha256_init(&sha2); sha256_update(&sha2, pad, 64); sha256_update(&sha2, opt_rand, n); + sha256_update(&sha2, pre, pre_sz); sha256_update(&sha2, m, m_sz); sha256_final(&sha2, buf); @@ -175,6 +181,7 @@ static void sha256_prf_msg( slh_ctx_t *ctx, static void sha512_prf_msg( slh_ctx_t *ctx, uint8_t *h, const uint8_t *opt_rand, + const uint8_t *pre, size_t pre_sz, const uint8_t *m, size_t m_sz) { unsigned i; @@ -192,6 +199,7 @@ static void sha512_prf_msg( slh_ctx_t *ctx, sha512_init(&sha2); sha512_update(&sha2, pad, 128); sha512_update(&sha2, opt_rand, n); + sha512_update(&sha2, pre, pre_sz); sha512_update(&sha2, m, m_sz); sha512_final(&sha2, buf); diff --git a/slh/slh_shake.c b/slh/slh_shake.c index 77c1f2a8..8cd0255f 100644 --- a/slh/slh_shake.c +++ b/slh/slh_shake.c @@ -16,6 +16,7 @@ static void shake_h_msg( slh_ctx_t *ctx, uint8_t *h, const uint8_t *r, + const uint8_t *pre, size_t pre_sz, const uint8_t *m, size_t m_sz) { sha3_ctx_t sha3; @@ -25,6 +26,7 @@ static void shake_h_msg( slh_ctx_t *ctx, shake_update(&sha3, r, n); shake_update(&sha3, ctx->pk_seed, n); shake_update(&sha3, ctx->pk_root, n); + shake_update(&sha3, pre, pre_sz); shake_update(&sha3, m, m_sz); shake_out(&sha3, h, ctx->prm->m); @@ -59,6 +61,7 @@ static void shake_prf(slh_ctx_t *ctx, uint8_t *h) static void shake_prf_msg( slh_ctx_t *ctx, uint8_t *h, const uint8_t *opt_rand, + const uint8_t *pre, size_t pre_sz, const uint8_t *m, size_t m_sz) { sha3_ctx_t sha3; @@ -67,6 +70,7 @@ static void shake_prf_msg( slh_ctx_t *ctx, shake256_init(&sha3); shake_update(&sha3, ctx->sk_prf, n); shake_update(&sha3, opt_rand, n); + shake_update(&sha3, pre, pre_sz); shake_update(&sha3, m, m_sz); shake_out(&sha3, h, n);