diff --git a/modules/juce_core/maths/juce_BigInteger.cpp b/modules/juce_core/maths/juce_BigInteger.cpp index d55124506f..ca1405e1bd 100644 --- a/modules/juce_core/maths/juce_BigInteger.cpp +++ b/modules/juce_core/maths/juce_BigInteger.cpp @@ -28,76 +28,110 @@ namespace { - inline size_t bitToIndex (const int bit) noexcept { return (size_t) (bit >> 5); } - inline uint32 bitToMask (const int bit) noexcept { return (uint32) 1 << (bit & 31); } + inline uint32 bitToMask (const int bit) noexcept { return (uint32) 1 << (bit & 31); } + inline size_t bitToIndex (const int bit) noexcept { return (size_t) (bit >> 5); } + inline size_t sizeNeededToHold (int highestBit) noexcept { return (size_t) (highestBit >> 5) + 1; } + + inline int highestBitInInt (uint32 n) noexcept + { + jassert (n != 0); // (the built-in functions may not work for n = 0) + + #if JUCE_GCC || JUCE_CLANG + return 31 - __builtin_clz (n); + #elif JUCE_MSVC + unsigned long highest; + _BitScanReverse (&highest, n); + return (int) highest; + #else + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return countBitsInInt32 (n >> 1); + #endif + } } //============================================================================== BigInteger::BigInteger() - : numValues (4), + : allocatedSize (numPreallocatedInts), highestBit (-1), negative (false) { - values.calloc (numValues + 1); + for (int i = 0; i < numPreallocatedInts; ++i) + preallocated[i] = 0; } BigInteger::BigInteger (const int32 value) - : numValues (4), + : allocatedSize (numPreallocatedInts), highestBit (31), negative (value < 0) { - values.calloc (numValues + 1); - values[0] = (uint32) std::abs (value); + preallocated[0] = (uint32) std::abs (value); + + for (int i = 1; i < numPreallocatedInts; ++i) + preallocated[i] = 0; + highestBit = getHighestBit(); } BigInteger::BigInteger (const uint32 value) - : numValues (4), + : allocatedSize (numPreallocatedInts), highestBit (31), negative (false) { - values.calloc (numValues + 1); - values[0] = value; + preallocated[0] = value; + + for (int i = 1; i < numPreallocatedInts; ++i) + preallocated[i] = 0; + highestBit = getHighestBit(); } BigInteger::BigInteger (int64 value) - : numValues (4), + : allocatedSize (numPreallocatedInts), highestBit (63), negative (value < 0) { - values.calloc (numValues + 1); - if (value < 0) value = -value; - values[0] = (uint32) value; - values[1] = (uint32) (value >> 32); + preallocated[0] = (uint32) value; + preallocated[1] = (uint32) (value >> 32); + + for (int i = 2; i < numPreallocatedInts; ++i) + preallocated[i] = 0; + highestBit = getHighestBit(); } BigInteger::BigInteger (const BigInteger& other) - : numValues ((size_t) jmax ((size_t) 4, bitToIndex (other.highestBit) + 1)), + : allocatedSize (other.allocatedSize), highestBit (other.getHighestBit()), negative (other.negative) { - values.malloc (numValues + 1); - memcpy (values, other.values, sizeof (uint32) * (numValues + 1)); + if (allocatedSize > numPreallocatedInts) + heapAllocation.malloc (allocatedSize); + + memcpy (getValues(), other.getValues(), sizeof (uint32) * allocatedSize); } #if JUCE_COMPILER_SUPPORTS_MOVE_SEMANTICS BigInteger::BigInteger (BigInteger&& other) noexcept - : values (static_cast&&> (other.values)), - numValues (other.numValues), + : heapAllocation (static_cast&&> (other.heapAllocation)), + allocatedSize (other.allocatedSize), highestBit (other.highestBit), negative (other.negative) { + memcpy (preallocated, other.preallocated, sizeof (preallocated)); } BigInteger& BigInteger::operator= (BigInteger&& other) noexcept { - values = static_cast&&> (other.values); - numValues = other.numValues; + heapAllocation = static_cast&&> (other.heapAllocation); + memcpy (preallocated, other.preallocated, sizeof (preallocated)); + allocatedSize = other.allocatedSize; highestBit = other.highestBit; negative = other.negative; return *this; @@ -110,8 +144,11 @@ BigInteger::~BigInteger() void BigInteger::swapWith (BigInteger& other) noexcept { - values.swapWith (other.values); - std::swap (numValues, other.numValues); + for (int i = 0; i < numPreallocatedInts; ++i) + std::swap (preallocated[i], other.preallocated[i]); + + heapAllocation.swapWith (other.heapAllocation); + std::swap (allocatedSize, other.allocatedSize); std::swap (highestBit, other.highestBit); std::swap (negative, other.negative); } @@ -121,44 +158,68 @@ BigInteger& BigInteger::operator= (const BigInteger& other) if (this != &other) { highestBit = other.getHighestBit(); - jassert (other.numValues >= 4); - numValues = (size_t) jmax ((size_t) 4, bitToIndex (highestBit) + 1); + allocatedSize = (size_t) jmax ((size_t) numPreallocatedInts, sizeNeededToHold (highestBit)); + + if (allocatedSize <= numPreallocatedInts) + heapAllocation.free(); + else + heapAllocation.malloc (allocatedSize); + + memcpy (getValues(), other.getValues(), sizeof (uint32) * allocatedSize); negative = other.negative; - values.malloc (numValues + 1); - memcpy (values, other.values, sizeof (uint32) * (numValues + 1)); } return *this; } -void BigInteger::ensureSize (const size_t numVals) +uint32* BigInteger::getValues() const noexcept { - if (numVals + 2 >= numValues) - { - size_t oldSize = numValues; - numValues = ((numVals + 2) * 3) / 2; - values.realloc (numValues + 1); + jassert (heapAllocation != nullptr || allocatedSize <= numPreallocatedInts); - while (oldSize < numValues) - values [oldSize++] = 0; + return heapAllocation != nullptr ? heapAllocation + : (uint32*) preallocated; +} + +uint32* BigInteger::ensureSize (const size_t numVals) +{ + if (numVals > allocatedSize) + { + size_t oldSize = allocatedSize; + allocatedSize = ((numVals + 2) * 3) / 2; + + if (heapAllocation == nullptr) + { + heapAllocation.calloc (allocatedSize); + memcpy (heapAllocation, preallocated, sizeof (uint32) * numPreallocatedInts); + } + else + { + heapAllocation.realloc (allocatedSize); + + for (uint32* values = getValues(); oldSize < allocatedSize; ++oldSize) + values[oldSize] = 0; + } } + + return getValues(); } //============================================================================== bool BigInteger::operator[] (const int bit) const noexcept { return bit <= highestBit && bit >= 0 - && ((values [bitToIndex (bit)] & bitToMask (bit)) != 0); + && ((getValues() [bitToIndex (bit)] & bitToMask (bit)) != 0); } int BigInteger::toInteger() const noexcept { - const int n = (int) (values[0] & 0x7fffffff); + const int n = (int) (getValues()[0] & 0x7fffffff); return negative ? -n : n; } int64 BigInteger::toInt64() const noexcept { + const uint32* values = getValues(); const int64 n = (((int64) (values[1] & 0x7fffffff)) << 32) | values[0]; return negative ? -n : n; } @@ -167,13 +228,12 @@ BigInteger BigInteger::getBitRange (int startBit, int numBits) const { BigInteger r; numBits = jmin (numBits, getHighestBit() + 1 - startBit); - r.ensureSize ((size_t) bitToIndex (numBits)); + uint32* const destValues = r.ensureSize (sizeNeededToHold (numBits)); r.highestBit = numBits; - int i = 0; - while (numBits > 0) + for (int i = 0; numBits > 0;) { - r.values[i++] = getBitRangeAsInt (startBit, (int) jmin (32, numBits)); + destValues[i++] = getBitRangeAsInt (startBit, (int) jmin (32, numBits)); numBits -= 32; startBit += 32; } @@ -198,6 +258,7 @@ uint32 BigInteger::getBitRangeAsInt (const int startBit, int numBits) const noex const size_t pos = bitToIndex (startBit); const int offset = startBit & 31; const int endSpace = 32 - numBits; + const uint32* values = getValues(); uint32 n = ((uint32) values [pos]) >> offset; @@ -223,20 +284,15 @@ void BigInteger::setBitRangeAsInt (const int startBit, int numBits, uint32 value } //============================================================================== -void BigInteger::clear() +void BigInteger::clear() noexcept { - if (numValues > 16) - { - numValues = 4; - values.calloc (numValues + 1); - } - else - { - values.clear (numValues + 1); - } - + heapAllocation.free(); + allocatedSize = numPreallocatedInts; highestBit = -1; negative = false; + + for (int i = 0; i < numPreallocatedInts; ++i) + preallocated[i] = 0; } void BigInteger::setBit (const int bit) @@ -245,11 +301,11 @@ void BigInteger::setBit (const int bit) { if (bit > highestBit) { - ensureSize (bitToIndex (bit)); + ensureSize (sizeNeededToHold (bit)); highestBit = bit; } - values [bitToIndex (bit)] |= bitToMask (bit); + getValues() [bitToIndex (bit)] |= bitToMask (bit); } } @@ -264,7 +320,12 @@ void BigInteger::setBit (const int bit, const bool shouldBeSet) void BigInteger::clearBit (const int bit) noexcept { if (bit >= 0 && bit <= highestBit) - values [bitToIndex (bit)] &= ~bitToMask (bit); + { + getValues() [bitToIndex (bit)] &= ~bitToMask (bit); + + if (bit == highestBit) + highestBit = getHighestBit(); + } } void BigInteger::setRange (int startBit, int numBits, const bool shouldBeSet) @@ -311,31 +372,12 @@ void BigInteger::negate() noexcept #pragma intrinsic (_BitScanReverse) #endif -inline static int highestBitInInt (uint32 n) noexcept -{ - jassert (n != 0); // (the built-in functions may not work for n = 0) - - #if JUCE_GCC || JUCE_CLANG - return 31 - __builtin_clz (n); - #elif JUCE_MSVC - unsigned long highest; - _BitScanReverse (&highest, n); - return (int) highest; - #else - n |= (n >> 1); - n |= (n >> 2); - n |= (n >> 4); - n |= (n >> 8); - n |= (n >> 16); - return countBitsInInt32 (n >> 1); - #endif -} - int BigInteger::countNumberOfSetBits() const noexcept { int total = 0; + const uint32* values = getValues(); - for (int i = (int) bitToIndex (highestBit) + 1; --i >= 0;) + for (int i = (int) sizeNeededToHold (highestBit); --i >= 0;) total += countNumberOfBits (values[i]); return total; @@ -343,19 +385,19 @@ int BigInteger::countNumberOfSetBits() const noexcept int BigInteger::getHighestBit() const noexcept { - for (int i = (int) bitToIndex (highestBit + 1); i >= 0; --i) - { - const uint32 n = values[i]; + const uint32* values = getValues(); - if (n != 0) + for (int i = (int) bitToIndex (highestBit); i >= 0; --i) + if (uint32 n = values[i]) return highestBitInInt (n) + (i << 5); - } return -1; } int BigInteger::findNextSetBit (int i) const noexcept { + const uint32* values = getValues(); + for (; i <= highestBit; ++i) if ((values [bitToIndex (i)] & bitToMask (i)) != 0) return i; @@ -365,6 +407,8 @@ int BigInteger::findNextSetBit (int i) const noexcept int BigInteger::findNextClearBit (int i) const noexcept { + const uint32* values = getValues(); + for (; i <= highestBit; ++i) if ((values [bitToIndex (i)] & bitToMask (i)) == 0) break; @@ -396,23 +440,19 @@ BigInteger& BigInteger::operator+= (const BigInteger& other) } else { - if (other.highestBit > highestBit) - highestBit = other.highestBit; - - ++highestBit; - - const size_t numInts = bitToIndex (highestBit) + 1; - ensureSize (numInts); + highestBit = jmax (highestBit, other.highestBit) + 1; + const size_t numInts = sizeNeededToHold (highestBit); + uint32* const values = ensureSize (numInts); + const uint32* const otherValues = other.getValues(); int64 remainder = 0; - for (size_t i = 0; i <= numInts; ++i) + for (size_t i = 0; i < numInts; ++i) { - if (i < numValues) - remainder += values[i]; + remainder += values[i]; - if (i < other.numValues) - remainder += other.values[i]; + if (i < other.allocatedSize) + remainder += otherValues[i]; values[i] = (uint32) remainder; remainder >>= 32; @@ -449,14 +489,17 @@ BigInteger& BigInteger::operator-= (const BigInteger& other) return *this; } - const size_t numInts = bitToIndex (highestBit) + 1; - const size_t maxOtherInts = bitToIndex (other.highestBit) + 1; + const size_t numInts = sizeNeededToHold (highestBit); + const size_t maxOtherInts = sizeNeededToHold (other.highestBit); + jassert (numInts >= maxOtherInts); + uint32* const values = getValues(); + const uint32* const otherValues = other.getValues(); int64 amountToSubtract = 0; - for (size_t i = 0; i <= numInts; ++i) + for (size_t i = 0; i < numInts; ++i) { - if (i <= maxOtherInts) - amountToSubtract += (int64) other.values[i]; + if (i < maxOtherInts) + amountToSubtract += (int64) otherValues[i]; if (values[i] >= amountToSubtract) { @@ -487,26 +530,29 @@ BigInteger& BigInteger::operator*= (const BigInteger& other) n >>= 5; t >>= 5; - - total.ensureSize ((size_t) (n + t + 2)); + uint32* const totalValues = total.ensureSize (n + t + 2); BigInteger m (other); m.setNegative (false); + const uint32* const mValues = m.getValues(); + const uint32* const values = getValues(); + for (int i = 0; i <= t; ++i) { uint32 c = 0; for (int j = 0; j <= n; ++j) { - uint64 uv = (uint64) total.values[i + j] + (uint64) values[j] * (uint64) m.values[i] + (uint64) c; - total.values[i + j] = (uint32) uv; + uint64 uv = (uint64) totalValues[i + j] + (uint64) values[j] * (uint64) mValues[i] + (uint64) c; + totalValues[i + j] = (uint32) uv; c = uv >> 32; } - total.values[i + n + 1] = c; + totalValues[i + n + 1] = c; } + total.highestBit = total.getHighestBit(); total.setNegative (wasNegative ^ other.isNegative()); swapWith (total); @@ -571,12 +617,13 @@ BigInteger& BigInteger::operator|= (const BigInteger& other) if (other.highestBit >= 0) { - ensureSize (bitToIndex (other.highestBit)); + uint32* const values = ensureSize (sizeNeededToHold (other.highestBit)); + const uint32* const otherValues = other.getValues(); int n = (int) bitToIndex (other.highestBit) + 1; while (--n >= 0) - values[n] |= other.values[n]; + values[n] |= otherValues[n]; if (other.highestBit > highestBit) highestBit = other.highestBit; @@ -592,13 +639,16 @@ BigInteger& BigInteger::operator&= (const BigInteger& other) // this operation doesn't take into account negative values.. jassert (isNegative() == other.isNegative()); - int n = (int) numValues; + uint32* const values = getValues(); + const uint32* const otherValues = other.getValues(); - while (n > (int) other.numValues) + int n = (int) allocatedSize; + + while (n > (int) other.allocatedSize) values[--n] = 0; while (--n >= 0) - values[n] &= other.values[n]; + values[n] &= otherValues[n]; if (other.highestBit < highestBit) highestBit = other.highestBit; @@ -614,12 +664,13 @@ BigInteger& BigInteger::operator^= (const BigInteger& other) if (other.highestBit >= 0) { - ensureSize (bitToIndex (other.highestBit)); + uint32* const values = ensureSize (sizeNeededToHold (other.highestBit)); + const uint32* const otherValues = other.getValues(); int n = (int) bitToIndex (other.highestBit) + 1; while (--n >= 0) - values[n] ^= other.values[n]; + values[n] ^= otherValues[n]; if (other.highestBit > highestBit) highestBit = other.highestBit; @@ -679,9 +730,12 @@ int BigInteger::compareAbsolute (const BigInteger& other) const noexcept if (h1 > h2) return 1; if (h1 < h2) return -1; - for (int i = (int) bitToIndex (h1) + 1; --i >= 0;) - if (values[i] != other.values[i]) - return (values[i] > other.values[i]) ? 1 : -1; + const uint32* const values = getValues(); + const uint32* const otherValues = other.getValues(); + + for (int i = (int) bitToIndex (h1); i >= 0; --i) + if (values[i] != otherValues[i]) + return values[i] > otherValues[i] ? 1 : -1; return 0; } @@ -706,19 +760,19 @@ void BigInteger::shiftLeft (int bits, const int startBit) } else { - ensureSize (bitToIndex (highestBit + bits) + 1); + uint32* const values = ensureSize (sizeNeededToHold (highestBit + bits)); const size_t wordsToMove = bitToIndex (bits); - size_t top = 1 + bitToIndex (highestBit); + size_t numOriginalInts = bitToIndex (highestBit); highestBit += bits; if (wordsToMove > 0) { - for (int i = (int) top; --i >= 0;) - values [(size_t) i + wordsToMove] = values [i]; + for (int i = (int) numOriginalInts; i >= 0; --i) + values[(size_t) i + wordsToMove] = values[i]; for (size_t j = 0; j < wordsToMove; ++j) - values [j] = 0; + values[j] = 0; bits &= 31; } @@ -727,10 +781,10 @@ void BigInteger::shiftLeft (int bits, const int startBit) { const int invBits = 32 - bits; - for (size_t i = top + 1 + wordsToMove; --i > wordsToMove;) - values[i] = (values[i] << bits) | (values [i - 1] >> invBits); + for (size_t i = bitToIndex (highestBit); i > wordsToMove; --i) + values[i] = (values[i] << bits) | (values[i - 1] >> invBits); - values [wordsToMove] = values [wordsToMove] << bits; + values[wordsToMove] = values[wordsToMove] << bits; } highestBit = getHighestBit(); @@ -748,7 +802,7 @@ void BigInteger::shiftRight (int bits, const int startBit) } else { - if (bits > highestBit) + if (bits >= highestBit) { clear(); } @@ -757,15 +811,16 @@ void BigInteger::shiftRight (int bits, const int startBit) const size_t wordsToMove = bitToIndex (bits); size_t top = 1 + bitToIndex (highestBit) - wordsToMove; highestBit -= bits; + uint32* const values = getValues(); if (wordsToMove > 0) { size_t i; for (i = 0; i < top; ++i) - values [i] = values [i + wordsToMove]; + values[i] = values[i + wordsToMove]; for (i = 0; i < wordsToMove; ++i) - values [top + i] = 0; + values[top + i] = 0; bits &= 31; } @@ -773,10 +828,10 @@ void BigInteger::shiftRight (int bits, const int startBit) if (bits != 0) { const int invBits = 32 - bits; - --top; + for (size_t i = 0; i < top; ++i) - values[i] = (values[i] >> bits) | (values [i + 1] << invBits); + values[i] = (values[i] >> bits) | (values[i + 1] << invBits); values[top] = (values[top] >> bits); } @@ -1112,6 +1167,7 @@ MemoryBlock BigInteger::toMemoryBlock() const { const int numBytes = (getHighestBit() + 8) >> 3; MemoryBlock mb ((size_t) numBytes); + const uint32* const values = getValues(); for (int i = 0; i < numBytes; ++i) mb[i] = (char) ((values[i / 4] >> ((i & 3) * 8)) & 0xff); @@ -1122,14 +1178,13 @@ MemoryBlock BigInteger::toMemoryBlock() const void BigInteger::loadFromMemoryBlock (const MemoryBlock& data) { const size_t numBytes = data.getSize(); - numValues = 1 + (numBytes / sizeof (uint32)); - values.malloc (numValues + 1); + const size_t numInts = 1 + (numBytes / sizeof (uint32)); + uint32* const values = ensureSize (numInts); - for (int i = 0; i < (int) numValues - 1; ++i) + for (int i = 0; i < (int) numInts - 1; ++i) values[i] = (uint32) ByteOrder::littleEndianInt (addBytesToPointer (data.getData(), sizeof (uint32) * (size_t) i)); - values[numValues - 1] = 0; - values[numValues] = 0; + values[numInts - 1] = 0; for (int i = (int) (numBytes & ~3u); i < (int) numBytes; ++i) this->setBitRangeAsInt (i << 3, 8, (uint32) data [i]); @@ -1153,7 +1208,7 @@ public: BigInteger b; while (b < 2) - r.fillBitsRandomly (b, 0, r.nextInt (150) + 1); + r.fillBitsRandomly (b, 0, r.nextInt (200) + 1); return b; } @@ -1172,15 +1227,29 @@ public: BigInteger b1 (getBigRandom(r)), b2 (getBigRandom(r)); + if ((j % 100) == 1 || j == 1) b1 = BigInteger(); + if ((j % 100) == 2 || j == 1) b2 = BigInteger(); + + expect (((b2 << 4) >> 4) == b2); + expect (((b2 << 32) >> 32) == b2); + expect (((b2 << 200) >> 200) == b2); + + expect (((b2 & BigInteger (1)) | BigInteger (1)).isOne()); + expect (((b2 | BigInteger (1)) & BigInteger (1)).isOne()); + expect ((b2 ^ b2).isZero()); + expect ((b2 >> 300).isZero()); + expect ((((b2 >> 16) << 16) & BigInteger (0xffff)).isZero()); + BigInteger b3 = b1 + b2; - expect (b3 > b1 && b3 > b2); + expect ((b3 > b1 && b3 > b2) || (b1.isZero() || b2.isZero())); + expect ((b3 > b1 && b3 > b2) || (b1.isZero() || b2.isZero())); expect (b3 - b1 == b2); expect (b3 - b2 == b1); BigInteger b4 = b1 * b2; - expect (b4 > b1 && b4 > b2); - expect (b4 / b1 == b2); - expect (b4 / b2 == b1); + expect ((b4 > b1 && b4 > b2) || b4.isZero()); + expect (b4 / b1 == b2 || b1.isZero()); + expect (b4 / b2 == b1 || b2.isZero()); // TODO: should add tests for other ops (although they also get pretty well tested in the RSA unit test) diff --git a/modules/juce_core/maths/juce_BigInteger.h b/modules/juce_core/maths/juce_BigInteger.h index ada0e89e5b..4a663317bb 100644 --- a/modules/juce_core/maths/juce_BigInteger.h +++ b/modules/juce_core/maths/juce_BigInteger.h @@ -106,7 +106,7 @@ public: //============================================================================== /** Resets the value to 0. */ - void clear(); + void clear() noexcept; /** Clears a particular bit in the number. */ void clearBit (int bitNumber) noexcept; @@ -325,12 +325,15 @@ public: private: //============================================================================== - HeapBlock values; - size_t numValues; + enum { numPreallocatedInts = 4 }; + HeapBlock heapAllocation; + uint32 preallocated[numPreallocatedInts]; + size_t allocatedSize; int highestBit; bool negative; - void ensureSize (size_t); + uint32* getValues() const noexcept; + uint32* ensureSize (size_t); void shiftLeft (int bits, int startBit); void shiftRight (int bits, int startBit);