diff --git a/src/libutil/hash.cc b/src/libutil/hash.cc index 36de293bb..e8d290d13 100644 --- a/src/libutil/hash.cc +++ b/src/libutil/hash.cc @@ -132,17 +132,24 @@ std::string Hash::to_string(Base base, bool includeType) const return s; } -Hash fromSRI(std::string_view original) { +Hash Hash::fromSRI(std::string_view original) { + auto rest = original; + // Parse the has type before the separater, if there was one. + auto hashRaw = splitPrefix(rest, '-'); + if (!hashRaw) + throw BadHash("hash '%s' is not SRI", original); + HashType parsedType = parseHashType(*hashRaw); + + return Hash(rest, std::make_pair(parsedType, true)); } -Hash::Hash(std::string_view s) : Hash(s, std::nullopt) { } +Hash::Hash(std::string_view s) : Hash(s, std::nullopt) {} -static HashType newFunction(std::string_view & rest, std::optional optType) +static std::pair newFunction(std::string_view & original, std::optional optType) { auto rest = original; - size_t pos = 0; bool isSRI = false; // Parse the has type before the separater, if there was one. @@ -166,16 +173,23 @@ static HashType newFunction(std::string_view & rest, std::optional opt } else { if (optParsedType && optType && *optParsedType != *optType) throw BadHash("hash '%s' should have type '%s'", original, printHashType(*optType)); - return optParsedType ? *optParsedType : *optType; + return { + optParsedType ? *optParsedType : *optType, + isSRI, + }; } } // mutates the string_view Hash::Hash(std::string_view original, std::optional optType) - : Hash(original, newFunction(original, optType)) + : Hash(original, newFunction(original, optType)) {} -Hash::Hash(std::string_view original, HashType type) : Hash(type) { - if (!isSRI && rest.size() == base16Len()) { +Hash::Hash(std::string_view original, std::pair typeAndSRI) + : Hash(typeAndSRI.first) +{ + auto [type, isSRI] = std::move(typeAndSRI); + + if (!isSRI && original.size() == base16Len()) { auto parseHexDigit = [&](char c) { if (c >= '0' && c <= '9') return c - '0'; @@ -186,15 +200,15 @@ Hash::Hash(std::string_view original, HashType type) : Hash(type) { for (unsigned int i = 0; i < hashSize; i++) { hash[i] = - parseHexDigit(rest[pos + i * 2]) << 4 - | parseHexDigit(rest[pos + i * 2 + 1]); + parseHexDigit(original[i * 2]) << 4 + | parseHexDigit(original[i * 2 + 1]); } } - else if (!isSRI && rest.size() == base32Len()) { + else if (!isSRI && original.size() == base32Len()) { - for (unsigned int n = 0; n < rest.size(); ++n) { - char c = rest[rest.size() - n - 1]; + for (unsigned int n = 0; n < original.size(); ++n) { + char c = original[original.size() - n - 1]; unsigned char digit; for (digit = 0; digit < base32Chars.size(); ++digit) /* !!! slow */ if (base32Chars[digit] == c) break; @@ -214,8 +228,8 @@ Hash::Hash(std::string_view original, HashType type) : Hash(type) { } } - else if (isSRI || rest.size() == base64Len()) { - auto d = base64Decode(rest); + else if (isSRI || original.size() == base64Len()) { + auto d = base64Decode(original); if (d.size() != hashSize) throw BadHash("invalid %s hash '%s'", isSRI ? "SRI" : "base-64", original); assert(hashSize); @@ -223,7 +237,7 @@ Hash::Hash(std::string_view original, HashType type) : Hash(type) { } else - throw BadHash("hash '%s' has wrong length for hash type '%s'", rest, printHashType(this->type)); + throw BadHash("hash '%s' has wrong length for hash type '%s'", original, printHashType(this->type)); } Hash newHashAllowEmpty(std::string hashStr, std::optional ht) diff --git a/src/libutil/hash.hh b/src/libutil/hash.hh index 42bd585a2..570e50dd4 100644 --- a/src/libutil/hash.hh +++ b/src/libutil/hash.hh @@ -43,9 +43,11 @@ struct Hash // hash type must be part of string Hash(std::string_view s); + Hash fromSRI(std::string_view original); + private: // type must be provided, s must not include prefix - Hash(std::string_view s, HashType type); + Hash(std::string_view s, std::pair typeAndSRI); void init(); diff --git a/src/libutil/util.cc b/src/libutil/util.cc index 1268b146a..ed43c403f 100644 --- a/src/libutil/util.cc +++ b/src/libutil/util.cc @@ -1433,7 +1433,7 @@ string base64Decode(std::string_view s) char digit = decode[(unsigned char) c]; if (digit == -1) - throw Error("invalid character in Base64 string"); + throw Error("invalid character in Base64 string: '%c'", c); bits += 6; d = d << 6 | digit;