1
0
Fork 0
mirror of https://github.com/juce-framework/JUCE.git synced 2026-01-10 23:44:24 +00:00

Modified BigInteger so that small (128-bit) values don't require heap allocation

This commit is contained in:
jules 2016-08-26 09:42:28 +01:00
parent 28bb28a642
commit c0c912ab4c
2 changed files with 213 additions and 141 deletions

View file

@ -28,76 +28,110 @@
namespace
{
inline size_t bitToIndex (const int bit) noexcept { return (size_t) (bit >> 5); }
inline uint32 bitToMask (const int bit) noexcept { return (uint32) 1 << (bit & 31); }
inline uint32 bitToMask (const int bit) noexcept { return (uint32) 1 << (bit & 31); }
inline size_t bitToIndex (const int bit) noexcept { return (size_t) (bit >> 5); }
inline size_t sizeNeededToHold (int highestBit) noexcept { return (size_t) (highestBit >> 5) + 1; }
inline int highestBitInInt (uint32 n) noexcept
{
jassert (n != 0); // (the built-in functions may not work for n = 0)
#if JUCE_GCC || JUCE_CLANG
return 31 - __builtin_clz (n);
#elif JUCE_MSVC
unsigned long highest;
_BitScanReverse (&highest, n);
return (int) highest;
#else
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return countBitsInInt32 (n >> 1);
#endif
}
}
//==============================================================================
BigInteger::BigInteger()
: numValues (4),
: allocatedSize (numPreallocatedInts),
highestBit (-1),
negative (false)
{
values.calloc (numValues + 1);
for (int i = 0; i < numPreallocatedInts; ++i)
preallocated[i] = 0;
}
BigInteger::BigInteger (const int32 value)
: numValues (4),
: allocatedSize (numPreallocatedInts),
highestBit (31),
negative (value < 0)
{
values.calloc (numValues + 1);
values[0] = (uint32) std::abs (value);
preallocated[0] = (uint32) std::abs (value);
for (int i = 1; i < numPreallocatedInts; ++i)
preallocated[i] = 0;
highestBit = getHighestBit();
}
BigInteger::BigInteger (const uint32 value)
: numValues (4),
: allocatedSize (numPreallocatedInts),
highestBit (31),
negative (false)
{
values.calloc (numValues + 1);
values[0] = value;
preallocated[0] = value;
for (int i = 1; i < numPreallocatedInts; ++i)
preallocated[i] = 0;
highestBit = getHighestBit();
}
BigInteger::BigInteger (int64 value)
: numValues (4),
: allocatedSize (numPreallocatedInts),
highestBit (63),
negative (value < 0)
{
values.calloc (numValues + 1);
if (value < 0)
value = -value;
values[0] = (uint32) value;
values[1] = (uint32) (value >> 32);
preallocated[0] = (uint32) value;
preallocated[1] = (uint32) (value >> 32);
for (int i = 2; i < numPreallocatedInts; ++i)
preallocated[i] = 0;
highestBit = getHighestBit();
}
BigInteger::BigInteger (const BigInteger& other)
: numValues ((size_t) jmax ((size_t) 4, bitToIndex (other.highestBit) + 1)),
: allocatedSize (other.allocatedSize),
highestBit (other.getHighestBit()),
negative (other.negative)
{
values.malloc (numValues + 1);
memcpy (values, other.values, sizeof (uint32) * (numValues + 1));
if (allocatedSize > numPreallocatedInts)
heapAllocation.malloc (allocatedSize);
memcpy (getValues(), other.getValues(), sizeof (uint32) * allocatedSize);
}
#if JUCE_COMPILER_SUPPORTS_MOVE_SEMANTICS
BigInteger::BigInteger (BigInteger&& other) noexcept
: values (static_cast<HeapBlock<uint32>&&> (other.values)),
numValues (other.numValues),
: heapAllocation (static_cast<HeapBlock<uint32>&&> (other.heapAllocation)),
allocatedSize (other.allocatedSize),
highestBit (other.highestBit),
negative (other.negative)
{
memcpy (preallocated, other.preallocated, sizeof (preallocated));
}
BigInteger& BigInteger::operator= (BigInteger&& other) noexcept
{
values = static_cast<HeapBlock<uint32>&&> (other.values);
numValues = other.numValues;
heapAllocation = static_cast<HeapBlock<uint32>&&> (other.heapAllocation);
memcpy (preallocated, other.preallocated, sizeof (preallocated));
allocatedSize = other.allocatedSize;
highestBit = other.highestBit;
negative = other.negative;
return *this;
@ -110,8 +144,11 @@ BigInteger::~BigInteger()
void BigInteger::swapWith (BigInteger& other) noexcept
{
values.swapWith (other.values);
std::swap (numValues, other.numValues);
for (int i = 0; i < numPreallocatedInts; ++i)
std::swap (preallocated[i], other.preallocated[i]);
heapAllocation.swapWith (other.heapAllocation);
std::swap (allocatedSize, other.allocatedSize);
std::swap (highestBit, other.highestBit);
std::swap (negative, other.negative);
}
@ -121,44 +158,68 @@ BigInteger& BigInteger::operator= (const BigInteger& other)
if (this != &other)
{
highestBit = other.getHighestBit();
jassert (other.numValues >= 4);
numValues = (size_t) jmax ((size_t) 4, bitToIndex (highestBit) + 1);
allocatedSize = (size_t) jmax ((size_t) numPreallocatedInts, sizeNeededToHold (highestBit));
if (allocatedSize <= numPreallocatedInts)
heapAllocation.free();
else
heapAllocation.malloc (allocatedSize);
memcpy (getValues(), other.getValues(), sizeof (uint32) * allocatedSize);
negative = other.negative;
values.malloc (numValues + 1);
memcpy (values, other.values, sizeof (uint32) * (numValues + 1));
}
return *this;
}
void BigInteger::ensureSize (const size_t numVals)
uint32* BigInteger::getValues() const noexcept
{
if (numVals + 2 >= numValues)
{
size_t oldSize = numValues;
numValues = ((numVals + 2) * 3) / 2;
values.realloc (numValues + 1);
jassert (heapAllocation != nullptr || allocatedSize <= numPreallocatedInts);
while (oldSize < numValues)
values [oldSize++] = 0;
return heapAllocation != nullptr ? heapAllocation
: (uint32*) preallocated;
}
uint32* BigInteger::ensureSize (const size_t numVals)
{
if (numVals > allocatedSize)
{
size_t oldSize = allocatedSize;
allocatedSize = ((numVals + 2) * 3) / 2;
if (heapAllocation == nullptr)
{
heapAllocation.calloc (allocatedSize);
memcpy (heapAllocation, preallocated, sizeof (uint32) * numPreallocatedInts);
}
else
{
heapAllocation.realloc (allocatedSize);
for (uint32* values = getValues(); oldSize < allocatedSize; ++oldSize)
values[oldSize] = 0;
}
}
return getValues();
}
//==============================================================================
bool BigInteger::operator[] (const int bit) const noexcept
{
return bit <= highestBit && bit >= 0
&& ((values [bitToIndex (bit)] & bitToMask (bit)) != 0);
&& ((getValues() [bitToIndex (bit)] & bitToMask (bit)) != 0);
}
int BigInteger::toInteger() const noexcept
{
const int n = (int) (values[0] & 0x7fffffff);
const int n = (int) (getValues()[0] & 0x7fffffff);
return negative ? -n : n;
}
int64 BigInteger::toInt64() const noexcept
{
const uint32* values = getValues();
const int64 n = (((int64) (values[1] & 0x7fffffff)) << 32) | values[0];
return negative ? -n : n;
}
@ -167,13 +228,12 @@ BigInteger BigInteger::getBitRange (int startBit, int numBits) const
{
BigInteger r;
numBits = jmin (numBits, getHighestBit() + 1 - startBit);
r.ensureSize ((size_t) bitToIndex (numBits));
uint32* const destValues = r.ensureSize (sizeNeededToHold (numBits));
r.highestBit = numBits;
int i = 0;
while (numBits > 0)
for (int i = 0; numBits > 0;)
{
r.values[i++] = getBitRangeAsInt (startBit, (int) jmin (32, numBits));
destValues[i++] = getBitRangeAsInt (startBit, (int) jmin (32, numBits));
numBits -= 32;
startBit += 32;
}
@ -198,6 +258,7 @@ uint32 BigInteger::getBitRangeAsInt (const int startBit, int numBits) const noex
const size_t pos = bitToIndex (startBit);
const int offset = startBit & 31;
const int endSpace = 32 - numBits;
const uint32* values = getValues();
uint32 n = ((uint32) values [pos]) >> offset;
@ -223,20 +284,15 @@ void BigInteger::setBitRangeAsInt (const int startBit, int numBits, uint32 value
}
//==============================================================================
void BigInteger::clear()
void BigInteger::clear() noexcept
{
if (numValues > 16)
{
numValues = 4;
values.calloc (numValues + 1);
}
else
{
values.clear (numValues + 1);
}
heapAllocation.free();
allocatedSize = numPreallocatedInts;
highestBit = -1;
negative = false;
for (int i = 0; i < numPreallocatedInts; ++i)
preallocated[i] = 0;
}
void BigInteger::setBit (const int bit)
@ -245,11 +301,11 @@ void BigInteger::setBit (const int bit)
{
if (bit > highestBit)
{
ensureSize (bitToIndex (bit));
ensureSize (sizeNeededToHold (bit));
highestBit = bit;
}
values [bitToIndex (bit)] |= bitToMask (bit);
getValues() [bitToIndex (bit)] |= bitToMask (bit);
}
}
@ -264,7 +320,12 @@ void BigInteger::setBit (const int bit, const bool shouldBeSet)
void BigInteger::clearBit (const int bit) noexcept
{
if (bit >= 0 && bit <= highestBit)
values [bitToIndex (bit)] &= ~bitToMask (bit);
{
getValues() [bitToIndex (bit)] &= ~bitToMask (bit);
if (bit == highestBit)
highestBit = getHighestBit();
}
}
void BigInteger::setRange (int startBit, int numBits, const bool shouldBeSet)
@ -311,31 +372,12 @@ void BigInteger::negate() noexcept
#pragma intrinsic (_BitScanReverse)
#endif
inline static int highestBitInInt (uint32 n) noexcept
{
jassert (n != 0); // (the built-in functions may not work for n = 0)
#if JUCE_GCC || JUCE_CLANG
return 31 - __builtin_clz (n);
#elif JUCE_MSVC
unsigned long highest;
_BitScanReverse (&highest, n);
return (int) highest;
#else
n |= (n >> 1);
n |= (n >> 2);
n |= (n >> 4);
n |= (n >> 8);
n |= (n >> 16);
return countBitsInInt32 (n >> 1);
#endif
}
int BigInteger::countNumberOfSetBits() const noexcept
{
int total = 0;
const uint32* values = getValues();
for (int i = (int) bitToIndex (highestBit) + 1; --i >= 0;)
for (int i = (int) sizeNeededToHold (highestBit); --i >= 0;)
total += countNumberOfBits (values[i]);
return total;
@ -343,19 +385,19 @@ int BigInteger::countNumberOfSetBits() const noexcept
int BigInteger::getHighestBit() const noexcept
{
for (int i = (int) bitToIndex (highestBit + 1); i >= 0; --i)
{
const uint32 n = values[i];
const uint32* values = getValues();
if (n != 0)
for (int i = (int) bitToIndex (highestBit); i >= 0; --i)
if (uint32 n = values[i])
return highestBitInInt (n) + (i << 5);
}
return -1;
}
int BigInteger::findNextSetBit (int i) const noexcept
{
const uint32* values = getValues();
for (; i <= highestBit; ++i)
if ((values [bitToIndex (i)] & bitToMask (i)) != 0)
return i;
@ -365,6 +407,8 @@ int BigInteger::findNextSetBit (int i) const noexcept
int BigInteger::findNextClearBit (int i) const noexcept
{
const uint32* values = getValues();
for (; i <= highestBit; ++i)
if ((values [bitToIndex (i)] & bitToMask (i)) == 0)
break;
@ -396,23 +440,19 @@ BigInteger& BigInteger::operator+= (const BigInteger& other)
}
else
{
if (other.highestBit > highestBit)
highestBit = other.highestBit;
++highestBit;
const size_t numInts = bitToIndex (highestBit) + 1;
ensureSize (numInts);
highestBit = jmax (highestBit, other.highestBit) + 1;
const size_t numInts = sizeNeededToHold (highestBit);
uint32* const values = ensureSize (numInts);
const uint32* const otherValues = other.getValues();
int64 remainder = 0;
for (size_t i = 0; i <= numInts; ++i)
for (size_t i = 0; i < numInts; ++i)
{
if (i < numValues)
remainder += values[i];
remainder += values[i];
if (i < other.numValues)
remainder += other.values[i];
if (i < other.allocatedSize)
remainder += otherValues[i];
values[i] = (uint32) remainder;
remainder >>= 32;
@ -449,14 +489,17 @@ BigInteger& BigInteger::operator-= (const BigInteger& other)
return *this;
}
const size_t numInts = bitToIndex (highestBit) + 1;
const size_t maxOtherInts = bitToIndex (other.highestBit) + 1;
const size_t numInts = sizeNeededToHold (highestBit);
const size_t maxOtherInts = sizeNeededToHold (other.highestBit);
jassert (numInts >= maxOtherInts);
uint32* const values = getValues();
const uint32* const otherValues = other.getValues();
int64 amountToSubtract = 0;
for (size_t i = 0; i <= numInts; ++i)
for (size_t i = 0; i < numInts; ++i)
{
if (i <= maxOtherInts)
amountToSubtract += (int64) other.values[i];
if (i < maxOtherInts)
amountToSubtract += (int64) otherValues[i];
if (values[i] >= amountToSubtract)
{
@ -487,26 +530,29 @@ BigInteger& BigInteger::operator*= (const BigInteger& other)
n >>= 5;
t >>= 5;
total.ensureSize ((size_t) (n + t + 2));
uint32* const totalValues = total.ensureSize (n + t + 2);
BigInteger m (other);
m.setNegative (false);
const uint32* const mValues = m.getValues();
const uint32* const values = getValues();
for (int i = 0; i <= t; ++i)
{
uint32 c = 0;
for (int j = 0; j <= n; ++j)
{
uint64 uv = (uint64) total.values[i + j] + (uint64) values[j] * (uint64) m.values[i] + (uint64) c;
total.values[i + j] = (uint32) uv;
uint64 uv = (uint64) totalValues[i + j] + (uint64) values[j] * (uint64) mValues[i] + (uint64) c;
totalValues[i + j] = (uint32) uv;
c = uv >> 32;
}
total.values[i + n + 1] = c;
totalValues[i + n + 1] = c;
}
total.highestBit = total.getHighestBit();
total.setNegative (wasNegative ^ other.isNegative());
swapWith (total);
@ -571,12 +617,13 @@ BigInteger& BigInteger::operator|= (const BigInteger& other)
if (other.highestBit >= 0)
{
ensureSize (bitToIndex (other.highestBit));
uint32* const values = ensureSize (sizeNeededToHold (other.highestBit));
const uint32* const otherValues = other.getValues();
int n = (int) bitToIndex (other.highestBit) + 1;
while (--n >= 0)
values[n] |= other.values[n];
values[n] |= otherValues[n];
if (other.highestBit > highestBit)
highestBit = other.highestBit;
@ -592,13 +639,16 @@ BigInteger& BigInteger::operator&= (const BigInteger& other)
// this operation doesn't take into account negative values..
jassert (isNegative() == other.isNegative());
int n = (int) numValues;
uint32* const values = getValues();
const uint32* const otherValues = other.getValues();
while (n > (int) other.numValues)
int n = (int) allocatedSize;
while (n > (int) other.allocatedSize)
values[--n] = 0;
while (--n >= 0)
values[n] &= other.values[n];
values[n] &= otherValues[n];
if (other.highestBit < highestBit)
highestBit = other.highestBit;
@ -614,12 +664,13 @@ BigInteger& BigInteger::operator^= (const BigInteger& other)
if (other.highestBit >= 0)
{
ensureSize (bitToIndex (other.highestBit));
uint32* const values = ensureSize (sizeNeededToHold (other.highestBit));
const uint32* const otherValues = other.getValues();
int n = (int) bitToIndex (other.highestBit) + 1;
while (--n >= 0)
values[n] ^= other.values[n];
values[n] ^= otherValues[n];
if (other.highestBit > highestBit)
highestBit = other.highestBit;
@ -679,9 +730,12 @@ int BigInteger::compareAbsolute (const BigInteger& other) const noexcept
if (h1 > h2) return 1;
if (h1 < h2) return -1;
for (int i = (int) bitToIndex (h1) + 1; --i >= 0;)
if (values[i] != other.values[i])
return (values[i] > other.values[i]) ? 1 : -1;
const uint32* const values = getValues();
const uint32* const otherValues = other.getValues();
for (int i = (int) bitToIndex (h1); i >= 0; --i)
if (values[i] != otherValues[i])
return values[i] > otherValues[i] ? 1 : -1;
return 0;
}
@ -706,19 +760,19 @@ void BigInteger::shiftLeft (int bits, const int startBit)
}
else
{
ensureSize (bitToIndex (highestBit + bits) + 1);
uint32* const values = ensureSize (sizeNeededToHold (highestBit + bits));
const size_t wordsToMove = bitToIndex (bits);
size_t top = 1 + bitToIndex (highestBit);
size_t numOriginalInts = bitToIndex (highestBit);
highestBit += bits;
if (wordsToMove > 0)
{
for (int i = (int) top; --i >= 0;)
values [(size_t) i + wordsToMove] = values [i];
for (int i = (int) numOriginalInts; i >= 0; --i)
values[(size_t) i + wordsToMove] = values[i];
for (size_t j = 0; j < wordsToMove; ++j)
values [j] = 0;
values[j] = 0;
bits &= 31;
}
@ -727,10 +781,10 @@ void BigInteger::shiftLeft (int bits, const int startBit)
{
const int invBits = 32 - bits;
for (size_t i = top + 1 + wordsToMove; --i > wordsToMove;)
values[i] = (values[i] << bits) | (values [i - 1] >> invBits);
for (size_t i = bitToIndex (highestBit); i > wordsToMove; --i)
values[i] = (values[i] << bits) | (values[i - 1] >> invBits);
values [wordsToMove] = values [wordsToMove] << bits;
values[wordsToMove] = values[wordsToMove] << bits;
}
highestBit = getHighestBit();
@ -748,7 +802,7 @@ void BigInteger::shiftRight (int bits, const int startBit)
}
else
{
if (bits > highestBit)
if (bits >= highestBit)
{
clear();
}
@ -757,15 +811,16 @@ void BigInteger::shiftRight (int bits, const int startBit)
const size_t wordsToMove = bitToIndex (bits);
size_t top = 1 + bitToIndex (highestBit) - wordsToMove;
highestBit -= bits;
uint32* const values = getValues();
if (wordsToMove > 0)
{
size_t i;
for (i = 0; i < top; ++i)
values [i] = values [i + wordsToMove];
values[i] = values[i + wordsToMove];
for (i = 0; i < wordsToMove; ++i)
values [top + i] = 0;
values[top + i] = 0;
bits &= 31;
}
@ -773,10 +828,10 @@ void BigInteger::shiftRight (int bits, const int startBit)
if (bits != 0)
{
const int invBits = 32 - bits;
--top;
for (size_t i = 0; i < top; ++i)
values[i] = (values[i] >> bits) | (values [i + 1] << invBits);
values[i] = (values[i] >> bits) | (values[i + 1] << invBits);
values[top] = (values[top] >> bits);
}
@ -1112,6 +1167,7 @@ MemoryBlock BigInteger::toMemoryBlock() const
{
const int numBytes = (getHighestBit() + 8) >> 3;
MemoryBlock mb ((size_t) numBytes);
const uint32* const values = getValues();
for (int i = 0; i < numBytes; ++i)
mb[i] = (char) ((values[i / 4] >> ((i & 3) * 8)) & 0xff);
@ -1122,14 +1178,13 @@ MemoryBlock BigInteger::toMemoryBlock() const
void BigInteger::loadFromMemoryBlock (const MemoryBlock& data)
{
const size_t numBytes = data.getSize();
numValues = 1 + (numBytes / sizeof (uint32));
values.malloc (numValues + 1);
const size_t numInts = 1 + (numBytes / sizeof (uint32));
uint32* const values = ensureSize (numInts);
for (int i = 0; i < (int) numValues - 1; ++i)
for (int i = 0; i < (int) numInts - 1; ++i)
values[i] = (uint32) ByteOrder::littleEndianInt (addBytesToPointer (data.getData(), sizeof (uint32) * (size_t) i));
values[numValues - 1] = 0;
values[numValues] = 0;
values[numInts - 1] = 0;
for (int i = (int) (numBytes & ~3u); i < (int) numBytes; ++i)
this->setBitRangeAsInt (i << 3, 8, (uint32) data [i]);
@ -1153,7 +1208,7 @@ public:
BigInteger b;
while (b < 2)
r.fillBitsRandomly (b, 0, r.nextInt (150) + 1);
r.fillBitsRandomly (b, 0, r.nextInt (200) + 1);
return b;
}
@ -1172,15 +1227,29 @@ public:
BigInteger b1 (getBigRandom(r)),
b2 (getBigRandom(r));
if ((j % 100) == 1 || j == 1) b1 = BigInteger();
if ((j % 100) == 2 || j == 1) b2 = BigInteger();
expect (((b2 << 4) >> 4) == b2);
expect (((b2 << 32) >> 32) == b2);
expect (((b2 << 200) >> 200) == b2);
expect (((b2 & BigInteger (1)) | BigInteger (1)).isOne());
expect (((b2 | BigInteger (1)) & BigInteger (1)).isOne());
expect ((b2 ^ b2).isZero());
expect ((b2 >> 300).isZero());
expect ((((b2 >> 16) << 16) & BigInteger (0xffff)).isZero());
BigInteger b3 = b1 + b2;
expect (b3 > b1 && b3 > b2);
expect ((b3 > b1 && b3 > b2) || (b1.isZero() || b2.isZero()));
expect ((b3 > b1 && b3 > b2) || (b1.isZero() || b2.isZero()));
expect (b3 - b1 == b2);
expect (b3 - b2 == b1);
BigInteger b4 = b1 * b2;
expect (b4 > b1 && b4 > b2);
expect (b4 / b1 == b2);
expect (b4 / b2 == b1);
expect ((b4 > b1 && b4 > b2) || b4.isZero());
expect (b4 / b1 == b2 || b1.isZero());
expect (b4 / b2 == b1 || b2.isZero());
// TODO: should add tests for other ops (although they also get pretty well tested in the RSA unit test)

View file

@ -106,7 +106,7 @@ public:
//==============================================================================
/** Resets the value to 0. */
void clear();
void clear() noexcept;
/** Clears a particular bit in the number. */
void clearBit (int bitNumber) noexcept;
@ -325,12 +325,15 @@ public:
private:
//==============================================================================
HeapBlock<uint32> values;
size_t numValues;
enum { numPreallocatedInts = 4 };
HeapBlock<uint32> heapAllocation;
uint32 preallocated[numPreallocatedInts];
size_t allocatedSize;
int highestBit;
bool negative;
void ensureSize (size_t);
uint32* getValues() const noexcept;
uint32* ensureSize (size_t);
void shiftLeft (int bits, int startBit);
void shiftRight (int bits, int startBit);