diff options
Diffstat (limited to 'courgette')
-rw-r--r-- | courgette/assembly_program.cc | 47 | ||||
-rw-r--r-- | courgette/encoded_program.cc | 165 | ||||
-rw-r--r-- | courgette/encoded_program.h | 24 | ||||
-rw-r--r-- | courgette/encoded_program_unittest.cc | 12 | ||||
-rw-r--r-- | courgette/ensemble_apply.cc | 7 | ||||
-rw-r--r-- | courgette/memory_allocator.h | 41 | ||||
-rw-r--r-- | courgette/streams.cc | 74 | ||||
-rw-r--r-- | courgette/streams.h | 31 | ||||
-rw-r--r-- | courgette/streams_unittest.cc | 30 | ||||
-rw-r--r-- | courgette/third_party/bsdiff_apply.cc | 10 | ||||
-rw-r--r-- | courgette/third_party/bsdiff_create.cc | 41 |
11 files changed, 313 insertions, 169 deletions
diff --git a/courgette/assembly_program.cc b/courgette/assembly_program.cc index 1dc9ecf..ea6e4d2 100644 --- a/courgette/assembly_program.cc +++ b/courgette/assembly_program.cc @@ -12,6 +12,7 @@ #include <vector> #include "base/logging.h" +#include "base/scoped_ptr.h" #include "courgette/courgette.h" #include "courgette/encoded_program.h" @@ -287,26 +288,35 @@ void AssemblyProgram::AssignRemainingIndexes(RVAToLabel* labels) { << " infill " << fill_infill_count; } -typedef void (EncodedProgram::*DefineLabelMethod)(int index, RVA value); +typedef CheckBool (EncodedProgram::*DefineLabelMethod)(int index, RVA value); #if defined(OS_WIN) __declspec(noinline) #endif -static void DefineLabels(const RVAToLabel& labels, - EncodedProgram* encoded_format, - DefineLabelMethod define_label) { - for (RVAToLabel::const_iterator p = labels.begin(); p != labels.end(); ++p) { +static CheckBool DefineLabels(const RVAToLabel& labels, + EncodedProgram* encoded_format, + DefineLabelMethod define_label) { + bool ok = true; + for (RVAToLabel::const_iterator p = labels.begin(); + ok && p != labels.end(); + ++p) { Label* label = p->second; - (encoded_format->*define_label)(label->index_, label->rva_); + ok = (encoded_format->*define_label)(label->index_, label->rva_); } + return ok; } EncodedProgram* AssemblyProgram::Encode() const { - EncodedProgram* encoded = new EncodedProgram(); - + scoped_ptr<EncodedProgram> encoded(new EncodedProgram()); encoded->set_image_base(image_base_); - DefineLabels(abs32_labels_, encoded, &EncodedProgram::DefineAbs32Label); - DefineLabels(rel32_labels_, encoded, &EncodedProgram::DefineRel32Label); + + if (!DefineLabels(abs32_labels_, encoded.get(), + &EncodedProgram::DefineAbs32Label) || + !DefineLabels(rel32_labels_, encoded.get(), + &EncodedProgram::DefineRel32Label)) { + return NULL; + } + encoded->EndLabels(); for (size_t i = 0; i < instructions_.size(); ++i) { @@ -315,26 +325,31 @@ EncodedProgram* AssemblyProgram::Encode() const { switch (instruction->op()) { case ORIGIN: { OriginInstruction* org = static_cast<OriginInstruction*>(instruction); - encoded->AddOrigin(org->origin_rva()); + if (!encoded->AddOrigin(org->origin_rva())) + return NULL; break; } case DEFBYTE: { uint8 b = static_cast<ByteInstruction*>(instruction)->byte_value(); - encoded->AddCopy(1, &b); + if (!encoded->AddCopy(1, &b)) + return NULL; break; } case REL32: { Label* label = static_cast<InstructionWithLabel*>(instruction)->label(); - encoded->AddRel32(label->index_); + if (!encoded->AddRel32(label->index_)) + return NULL; break; } case ABS32: { Label* label = static_cast<InstructionWithLabel*>(instruction)->label(); - encoded->AddAbs32(label->index_); + if (!encoded->AddAbs32(label->index_)) + return NULL; break; } case MAKERELOCS: { - encoded->AddMakeRelocs(); + if (!encoded->AddMakeRelocs()) + return NULL; break; } default: { @@ -343,7 +358,7 @@ EncodedProgram* AssemblyProgram::Encode() const { } } - return encoded; + return encoded.release(); } Instruction* AssemblyProgram::GetByteInstruction(uint8 byte) { diff --git a/courgette/encoded_program.cc b/courgette/encoded_program.cc index 5169c16..ad41bca 100644 --- a/courgette/encoded_program.cc +++ b/courgette/encoded_program.cc @@ -40,14 +40,15 @@ EncodedProgram::~EncodedProgram() {} // Serializes a vector of integral values using Varint32 coding. template<typename T, typename A> -void WriteVector(const std::vector<T, A>& items, SinkStream* buffer) { +CheckBool WriteVector(const std::vector<T, A>& items, SinkStream* buffer) { size_t count = items.size(); - buffer->WriteSizeVarint32(count); - for (size_t i = 0; i < count; ++i) { + bool ok = buffer->WriteSizeVarint32(count); + for (size_t i = 0; ok && i < count; ++i) { COMPILE_ASSERT(sizeof(T) <= sizeof(uint32), // NOLINT T_must_fit_in_uint32); - buffer->WriteSizeVarint32(items[i]); + ok = buffer->WriteSizeVarint32(items[i]); } + return ok; } template<typename T, typename A> @@ -62,6 +63,7 @@ bool ReadVector(std::vector<T, A>* items, SourceStream* buffer) { uint32 item; if (!buffer->ReadVarint32(&item)) return false; + // TODO(tommi): Handle errors. items->push_back(static_cast<T>(item)); } @@ -70,26 +72,29 @@ bool ReadVector(std::vector<T, A>* items, SourceStream* buffer) { // Serializes a vector, using delta coding followed by Varint32 coding. template<typename A> -void WriteU32Delta(const std::vector<uint32, A>& set, SinkStream* buffer) { +CheckBool WriteU32Delta(const std::vector<uint32, A>& set, SinkStream* buffer) { size_t count = set.size(); - buffer->WriteSizeVarint32(count); + bool ok = buffer->WriteSizeVarint32(count); uint32 prev = 0; - for (size_t i = 0; i < count; ++i) { + for (size_t i = 0; ok && i < count; ++i) { uint32 current = set[i]; uint32 delta = current - prev; - buffer->WriteVarint32(delta); + ok = buffer->WriteVarint32(delta); prev = current; } + return ok; } template <typename A> -static bool ReadU32Delta(std::vector<uint32, A>* set, SourceStream* buffer) { +static CheckBool ReadU32Delta(std::vector<uint32, A>* set, + SourceStream* buffer) { uint32 count; if (!buffer->ReadVarint32(&count)) return false; set->clear(); + // TODO(tommi): Handle errors. set->reserve(count); uint32 prev = 0; @@ -98,28 +103,31 @@ static bool ReadU32Delta(std::vector<uint32, A>* set, SourceStream* buffer) { if (!buffer->ReadVarint32(&delta)) return false; uint32 current = prev + delta; + // TODO(tommi): handle errors set->push_back(current); prev = current; } + // TODO(tommi): Handle errors. return true; } // Write a vector as the byte representation of the contents. // // (This only really makes sense for a type T that has sizeof(T)==1, otherwise -// serilized representation is not endian-agnositic. But it is useful to keep +// serialized representation is not endian-agnositic. But it is useful to keep // the possibility of a greater size for experiments comparing Varint32 encoding // of a vector of larger integrals vs a plain form.) // template<typename T, typename A> -void WriteVectorU8(const std::vector<T, A>& items, SinkStream* buffer) { +CheckBool WriteVectorU8(const std::vector<T, A>& items, SinkStream* buffer) { size_t count = items.size(); - buffer->WriteSizeVarint32(count); - if (count != 0) { + bool ok = buffer->WriteSizeVarint32(count); + if (count != 0 && ok) { size_t byte_count = count * sizeof(T); - buffer->Write(static_cast<const void*>(&items[0]), byte_count); + ok = buffer->Write(static_cast<const void*>(&items[0]), byte_count); } + return ok; } template<typename T, typename A> @@ -129,6 +137,7 @@ bool ReadVectorU8(std::vector<T, A>* items, SourceStream* buffer) { return false; items->clear(); + // TODO(tommi): check error items->resize(count); if (count != 0) { size_t byte_count = count * sizeof(T); @@ -139,26 +148,29 @@ bool ReadVectorU8(std::vector<T, A>* items, SourceStream* buffer) { //////////////////////////////////////////////////////////////////////////////// -void EncodedProgram::DefineRel32Label(int index, RVA value) { - DefineLabelCommon(&rel32_rva_, index, value); +CheckBool EncodedProgram::DefineRel32Label(int index, RVA value) { + return DefineLabelCommon(&rel32_rva_, index, value); } -void EncodedProgram::DefineAbs32Label(int index, RVA value) { - DefineLabelCommon(&abs32_rva_, index, value); +CheckBool EncodedProgram::DefineAbs32Label(int index, RVA value) { + return DefineLabelCommon(&abs32_rva_, index, value); } static const RVA kUnassignedRVA = static_cast<RVA>(-1); -void EncodedProgram::DefineLabelCommon(RvaVector* rvas, - int index, - RVA rva) { +CheckBool EncodedProgram::DefineLabelCommon(RvaVector* rvas, + int index, + RVA rva) { if (static_cast<int>(rvas->size()) <= index) { + // TODO(tommi): handle error rvas->resize(index + 1, kUnassignedRVA); } if ((*rvas)[index] != kUnassignedRVA) { NOTREACHED() << "DefineLabel double assigned " << index; } (*rvas)[index] = rva; + // TODO(tommi): Handle errors + return true; } void EncodedProgram::EndLabels() { @@ -181,12 +193,15 @@ void EncodedProgram::FinishLabelsCommon(RvaVector* rvas) { } } -void EncodedProgram::AddOrigin(RVA origin) { +CheckBool EncodedProgram::AddOrigin(RVA origin) { + //TODO(tommi): Handle errors ops_.push_back(ORIGIN); origins_.push_back(origin); + return true; } -void EncodedProgram::AddCopy(uint32 count, const void* bytes) { +CheckBool EncodedProgram::AddCopy(uint32 count, const void* bytes) { + //TODO(tommi): Handle errors const uint8* source = static_cast<const uint8*>(bytes); // Fold adjacent COPY instructions into one. This nearly halves the size of @@ -205,7 +220,7 @@ void EncodedProgram::AddCopy(uint32 count, const void* bytes) { for (uint32 i = 0; i < count; ++i) { copy_bytes_.push_back(source[i]); } - return; + return true; } } @@ -219,20 +234,28 @@ void EncodedProgram::AddCopy(uint32 count, const void* bytes) { copy_bytes_.push_back(source[i]); } } + + return true; } -void EncodedProgram::AddAbs32(int label_index) { +CheckBool EncodedProgram::AddAbs32(int label_index) { + //TODO(tommi): Handle errors ops_.push_back(ABS32); abs32_ix_.push_back(label_index); + return true; } -void EncodedProgram::AddRel32(int label_index) { +CheckBool EncodedProgram::AddRel32(int label_index) { + //TODO(tommi): Handle errors ops_.push_back(REL32); rel32_ix_.push_back(label_index); + return true; } -void EncodedProgram::AddMakeRelocs() { +CheckBool EncodedProgram::AddMakeRelocs() { + //TODO(tommi): Handle errors ops_.push_back(MAKE_BASE_RELOCATION_TABLE); + return true; } void EncodedProgram::DebuggingSummary() { @@ -279,7 +302,7 @@ static FieldSelect GetFieldSelect() { return static_cast<FieldSelect>(~0); } -void EncodedProgram::WriteTo(SinkStreamSet* streams) { +CheckBool EncodedProgram::WriteTo(SinkStreamSet* streams) { FieldSelect select = GetFieldSelect(); // The order of fields must be consistent in WriteTo and ReadFrom, regardless @@ -293,28 +316,46 @@ void EncodedProgram::WriteTo(SinkStreamSet* streams) { if (select & INCLUDE_MISC) { // TODO(sra): write 64 bits. - streams->stream(kStreamMisc)->WriteVarint32( - static_cast<uint32>(image_base_)); + if (!streams->stream(kStreamMisc)->WriteVarint32( + static_cast<uint32>(image_base_))) { + return false; + } + } + + bool success = true; + + if (select & INCLUDE_ABS32_ADDRESSES) { + success &= WriteU32Delta(abs32_rva_, + streams->stream(kStreamAbs32Addresses)); + } + + if (select & INCLUDE_REL32_ADDRESSES) { + success &= WriteU32Delta(rel32_rva_, + streams->stream(kStreamRel32Addresses)); } - if (select & INCLUDE_ABS32_ADDRESSES) - WriteU32Delta(abs32_rva_, streams->stream(kStreamAbs32Addresses)); - if (select & INCLUDE_REL32_ADDRESSES) - WriteU32Delta(rel32_rva_, streams->stream(kStreamRel32Addresses)); if (select & INCLUDE_MISC) - WriteVector(origins_, streams->stream(kStreamOriginAddresses)); + success &= WriteVector(origins_, streams->stream(kStreamOriginAddresses)); + if (select & INCLUDE_OPS) { - streams->stream(kStreamOps)->Reserve(ops_.size() + 5); // 5 for length. - WriteVector(ops_, streams->stream(kStreamOps)); + // 5 for length. + success &= streams->stream(kStreamOps)->Reserve(ops_.size() + 5); + success &= WriteVector(ops_, streams->stream(kStreamOps)); } + if (select & INCLUDE_COPY_COUNTS) - WriteVector(copy_counts_, streams->stream(kStreamCopyCounts)); + success &= WriteVector(copy_counts_, streams->stream(kStreamCopyCounts)); + if (select & INCLUDE_BYTES) - WriteVectorU8(copy_bytes_, streams->stream(kStreamBytes)); + success &= WriteVectorU8(copy_bytes_, streams->stream(kStreamBytes)); + if (select & INCLUDE_ABS32_INDEXES) - WriteVector(abs32_ix_, streams->stream(kStreamAbs32Indexes)); + success &= WriteVector(abs32_ix_, streams->stream(kStreamAbs32Indexes)); + if (select & INCLUDE_REL32_INDEXES) - WriteVector(rel32_ix_, streams->stream(kStreamRel32Indexes)); + success &= WriteVector(rel32_ix_, streams->stream(kStreamRel32Indexes)); + + return success; } bool EncodedProgram::ReadFrom(SourceStreamSet* streams) { @@ -360,7 +401,7 @@ bool VectorAt(const std::vector<T, A>& v, size_t index, T* output) { return true; } -bool EncodedProgram::AssembleTo(SinkStream* final_buffer) { +CheckBool EncodedProgram::AssembleTo(SinkStream* final_buffer) { // For the most part, the assembly process walks the various tables. // ix_mumble is the index into the mumble table. size_t ix_origins = 0; @@ -402,7 +443,8 @@ bool EncodedProgram::AssembleTo(SinkStream* final_buffer) { if (!VectorAt(copy_bytes_, ix_copy_bytes, &b)) return false; ++ix_copy_bytes; - output->Write(&b, 1); + if (!output->Write(&b, 1)) + return false; } current_rva += count; break; @@ -413,7 +455,8 @@ bool EncodedProgram::AssembleTo(SinkStream* final_buffer) { if (!VectorAt(copy_bytes_, ix_copy_bytes, &b)) return false; ++ix_copy_bytes; - output->Write(&b, 1); + if (!output->Write(&b, 1)) + return false; current_rva += 1; break; } @@ -427,7 +470,8 @@ bool EncodedProgram::AssembleTo(SinkStream* final_buffer) { if (!VectorAt(rel32_rva_, index, &rva)) return false; uint32 offset = (rva - (current_rva + 4)); - output->Write(&offset, 4); + if (!output->Write(&offset, 4)) + return false; current_rva += 4; break; } @@ -442,7 +486,8 @@ bool EncodedProgram::AssembleTo(SinkStream* final_buffer) { return false; uint32 abs32 = static_cast<uint32>(rva + image_base_); abs32_relocs_.push_back(current_rva); - output->Write(&abs32, 4); + if (!output->Write(&abs32, 4)) + return false; current_rva += 4; break; } @@ -471,8 +516,9 @@ bool EncodedProgram::AssembleTo(SinkStream* final_buffer) { } if (pending_base_relocation_table) { - GenerateBaseRelocations(final_buffer); - final_buffer->Append(&bytes_following_base_relocation_table); + if (!GenerateBaseRelocations(final_buffer) || + !final_buffer->Append(&bytes_following_base_relocation_table)) + return false; } // Final verification check: did we consume all lists? @@ -488,7 +534,6 @@ bool EncodedProgram::AssembleTo(SinkStream* final_buffer) { return true; } - // RelocBlock has the layout of a block of relocations in the base relocation // table file format. // @@ -512,39 +557,45 @@ class RelocBlock { pod.block_size += 2; } - void Flush(SinkStream* buffer) { + CheckBool Flush(SinkStream* buffer) { + bool ok = true; if (pod.block_size != 8) { if (pod.block_size % 4 != 0) { // Pad to make size multiple of 4 bytes. Add(0); } - buffer->Write(&pod, pod.block_size); + ok = buffer->Write(&pod, pod.block_size); pod.block_size = 8; } + return ok; } RelocBlockPOD pod; }; -void EncodedProgram::GenerateBaseRelocations(SinkStream* buffer) { +CheckBool EncodedProgram::GenerateBaseRelocations(SinkStream* buffer) { std::sort(abs32_relocs_.begin(), abs32_relocs_.end()); RelocBlock block; - for (size_t i = 0; i < abs32_relocs_.size(); ++i) { + bool ok = true; + for (size_t i = 0; ok && i < abs32_relocs_.size(); ++i) { uint32 rva = abs32_relocs_[i]; uint32 page_rva = rva & ~0xFFF; if (page_rva != block.pod.page_rva) { - block.Flush(buffer); + ok &= block.Flush(buffer); block.pod.page_rva = page_rva; } - block.Add(0x3000 | (rva & 0xFFF)); + if (ok) + block.Add(0x3000 | (rva & 0xFFF)); } - block.Flush(buffer); + ok &= block.Flush(buffer); + return ok; } //////////////////////////////////////////////////////////////////////////////// Status WriteEncodedProgram(EncodedProgram* encoded, SinkStreamSet* sink) { - encoded->WriteTo(sink); + if (!encoded->WriteTo(sink)) + return C_STREAM_ERROR; return C_OK; } diff --git a/courgette/encoded_program.h b/courgette/encoded_program.h index 5662f2e..6d2f440 100644 --- a/courgette/encoded_program.h +++ b/courgette/encoded_program.h @@ -32,27 +32,27 @@ class EncodedProgram { void set_image_base(uint64 base) { image_base_ = base; } // (2) Address tables and indexes defined first. - void DefineRel32Label(int index, RVA address); - void DefineAbs32Label(int index, RVA address); + CheckBool DefineRel32Label(int index, RVA address); + CheckBool DefineAbs32Label(int index, RVA address); void EndLabels(); // (3) Add instructions in the order needed to generate bytes of file. - void AddOrigin(RVA rva); - void AddCopy(uint32 count, const void* bytes); - void AddRel32(int label_index); - void AddAbs32(int label_index); - void AddMakeRelocs(); + CheckBool AddOrigin(RVA rva); + CheckBool AddCopy(uint32 count, const void* bytes); + CheckBool AddRel32(int label_index); + CheckBool AddAbs32(int label_index); + CheckBool AddMakeRelocs(); // (3) Serialize binary assembly language tables to a set of streams. - void WriteTo(SinkStreamSet *streams); + CheckBool WriteTo(SinkStreamSet* streams); // Using an EncodedProgram to generate a byte stream: // // (4) Deserializes a fresh EncodedProgram from a set of streams. - bool ReadFrom(SourceStreamSet *streams); + bool ReadFrom(SourceStreamSet* streams); // (5) Assembles the 'binary assembly language' into final file. - bool AssembleTo(SinkStream *buffer); + CheckBool AssembleTo(SinkStream* buffer); private: // Binary assembly language operations. @@ -74,8 +74,8 @@ class EncodedProgram { typedef std::vector<OP, MemoryAllocator<OP> > OPVector; void DebuggingSummary(); - void GenerateBaseRelocations(SinkStream *buffer); - void DefineLabelCommon(RvaVector*, int, RVA); + CheckBool GenerateBaseRelocations(SinkStream *buffer); + CheckBool DefineLabelCommon(RvaVector*, int, RVA); void FinishLabelsCommon(RvaVector* addresses); // Binary assembly language tables. diff --git a/courgette/encoded_program_unittest.cc b/courgette/encoded_program_unittest.cc index fb3fd17..e1d7698 100644 --- a/courgette/encoded_program_unittest.cc +++ b/courgette/encoded_program_unittest.cc @@ -17,18 +17,18 @@ TEST(EncodedProgramTest, Test) { uint32 base = 0x00900000; program->set_image_base(base); - program->DefineRel32Label(5, 0); // REL32 index 5 == base + 0 - program->DefineAbs32Label(7, 4); // ABS32 index 7 == base + 4 + EXPECT_TRUE(program->DefineRel32Label(5, 0)); // REL32 index 5 == base + 0 + EXPECT_TRUE(program->DefineAbs32Label(7, 4)); // ABS32 index 7 == base + 4 program->EndLabels(); - program->AddOrigin(0); // Start at base. - program->AddAbs32(7); - program->AddRel32(5); + EXPECT_TRUE(program->AddOrigin(0)); // Start at base. + EXPECT_TRUE(program->AddAbs32(7)); + EXPECT_TRUE(program->AddRel32(5)); // Serialize and deserialize. courgette::SinkStreamSet sinks; - program->WriteTo(&sinks); + EXPECT_TRUE(program->WriteTo(&sinks)); courgette::SinkStream sink; bool can_collect = sinks.CopyTo(&sink); diff --git a/courgette/ensemble_apply.cc b/courgette/ensemble_apply.cc index c865bfa..9621f30 100644 --- a/courgette/ensemble_apply.cc +++ b/courgette/ensemble_apply.cc @@ -220,10 +220,13 @@ Status EnsemblePatchApplication::TransformDown( SinkStream* basic_elements) { // Construct blob of original input followed by reformed elements. - basic_elements->Reserve(final_patch_input_size_prediction_); + if (!basic_elements->Reserve(final_patch_input_size_prediction_)) { + return C_STREAM_ERROR; + } // The original input: - basic_elements->Write(base_region_.start(), base_region_.length()); + if (!basic_elements->Write(base_region_.start(), base_region_.length())) + return C_STREAM_ERROR; for (size_t i = 0; i < patchers_.size(); ++i) { SourceStreamSet single_corrected_element; diff --git a/courgette/memory_allocator.h b/courgette/memory_allocator.h index fce1e25..0b2f376 100644 --- a/courgette/memory_allocator.h +++ b/courgette/memory_allocator.h @@ -12,6 +12,47 @@ #include "base/logging.h" #include "base/platform_file.h" +#ifndef NDEBUG + +// A helper class to track down call sites that are not handling error cases. +template<class T> +class CheckReturnValue { + public: + // Not marked explicit on purpose. + CheckReturnValue(T value) : value_(value), checked_(false) { // NOLINT + } + CheckReturnValue(const CheckReturnValue& other) + : value_(other.value_), checked_(other.checked_) { + other.checked_ = true; + } + + CheckReturnValue& operator=(const CheckReturnValue& other) { + if (this != &other) { + DCHECK(checked_); + value_ = other.value_; + checked_ = other.checked_; + other.checked_ = true; + } + } + + ~CheckReturnValue() { + DCHECK(checked_); + } + + operator const T&() const { + checked_ = true; + return value_; + } + + private: + T value_; + mutable bool checked_; +}; +typedef CheckReturnValue<bool> CheckBool; +#else +typedef bool CheckBool; +#endif + namespace courgette { #ifdef OS_WIN diff --git a/courgette/streams.cc b/courgette/streams.cc index 32dbf6b..ef81ded 100644 --- a/courgette/streams.cc +++ b/courgette/streams.cc @@ -181,37 +181,43 @@ bool SourceStream::Skip(size_t byte_count) { return true; } -void SinkStream::Write(const void* data, size_t byte_count) { +CheckBool SinkStream::Write(const void* data, size_t byte_count) { buffer_.append(static_cast<const char*>(data), byte_count); + //TODO(tommi): return error on failure. + return true; } -void SinkStream::WriteVarint32(uint32 value) { +CheckBool SinkStream::WriteVarint32(uint32 value) { uint8 buffer[Varint::kMax32]; uint8* end = Varint::Encode32(buffer, value); - Write(buffer, end - buffer); + return Write(buffer, end - buffer); } -void SinkStream::WriteVarint32Signed(int32 value) { +CheckBool SinkStream::WriteVarint32Signed(int32 value) { // Encode signed numbers so that numbers nearer zero have shorter // varint encoding. // 0000xxxx encoded as 000xxxx0. // 1111xxxx encoded as 000yyyy1 where yyyy is complement of xxxx. + bool ret; if (value < 0) - WriteVarint32(~value * 2 + 1); + ret = WriteVarint32(~value * 2 + 1); else - WriteVarint32(value * 2); + ret = WriteVarint32(value * 2); + return ret; } -void SinkStream::WriteSizeVarint32(size_t value) { +CheckBool SinkStream::WriteSizeVarint32(size_t value) { uint32 narrowed_value = static_cast<uint32>(value); // On 32-bit, the compiler should figure out this test always fails. LOG_ASSERT(value == narrowed_value); - WriteVarint32(narrowed_value); + return WriteVarint32(narrowed_value); } -void SinkStream::Append(SinkStream* other) { - Write(other->buffer_.c_str(), other->buffer_.size()); - other->Retire(); +CheckBool SinkStream::Append(SinkStream* other) { + bool ret = Write(other->buffer_.c_str(), other->buffer_.size()); + if (ret) + other->Retire(); + return ret; } void SinkStream::Retire() { @@ -326,35 +332,41 @@ void SinkStreamSet::Init(size_t stream_index_limit) { // The header for a stream set for N streams is serialized as // <version><N><length1><length2>...<lengthN> -void SinkStreamSet::CopyHeaderTo(SinkStream* header) { - header->WriteVarint32(kStreamsSerializationFormatVersion); - header->WriteSizeVarint32(count_); - for (size_t i = 0; i < count_; ++i) { - header->WriteSizeVarint32(stream(i)->Length()); +CheckBool SinkStreamSet::CopyHeaderTo(SinkStream* header) { + bool ret = header->WriteVarint32(kStreamsSerializationFormatVersion); + if (ret) { + ret = header->WriteSizeVarint32(count_); + for (size_t i = 0; ret && i < count_; ++i) { + ret = header->WriteSizeVarint32(stream(i)->Length()); + } } + return ret; } // Writes |this| to |combined_stream|. See SourceStreamSet::Init for the layout // of the stream metadata and contents. -bool SinkStreamSet::CopyTo(SinkStream *combined_stream) { +CheckBool SinkStreamSet::CopyTo(SinkStream *combined_stream) { SinkStream header; - CopyHeaderTo(&header); + bool ret = CopyHeaderTo(&header); + if (!ret) + return ret; // Reserve the correct amount of storage. size_t length = header.Length(); for (size_t i = 0; i < count_; ++i) { length += stream(i)->Length(); } - combined_stream->Reserve(length); - - combined_stream->Append(&header); - for (size_t i = 0; i < count_; ++i) { - combined_stream->Append(stream(i)); + ret = combined_stream->Reserve(length); + if (ret) { + ret = combined_stream->Append(&header); + for (size_t i = 0; ret && i < count_; ++i) { + ret = combined_stream->Append(stream(i)); + } } - return true; + return ret; } -bool SinkStreamSet::WriteSet(SinkStreamSet* set) { +CheckBool SinkStreamSet::WriteSet(SinkStreamSet* set) { uint32 lengths[kMaxStreams]; // 'stream_count' includes all non-empty streams and all empty stream numbered // lower than a non-empty stream. @@ -367,15 +379,15 @@ bool SinkStreamSet::WriteSet(SinkStreamSet* set) { } SinkStream* control_stream = this->stream(0); - control_stream->WriteSizeVarint32(stream_count); - for (size_t i = 0; i < stream_count; ++i) { - control_stream->WriteSizeVarint32(lengths[i]); + bool ret = control_stream->WriteSizeVarint32(stream_count); + for (size_t i = 0; ret && i < stream_count; ++i) { + ret = control_stream->WriteSizeVarint32(lengths[i]); } - for (size_t i = 0; i < stream_count; ++i) { - this->stream(i)->Append(set->stream(i)); + for (size_t i = 0; ret && i < stream_count; ++i) { + ret = this->stream(i)->Append(set->stream(i)); } - return true; + return ret; } } // namespace diff --git a/courgette/streams.h b/courgette/streams.h index 7be28a5..a778185 100644 --- a/courgette/streams.h +++ b/courgette/streams.h @@ -4,11 +4,11 @@ // Streams classes. // -// These memory-resident streams are used for serialzing data into a sequential +// These memory-resident streams are used for serializing data into a sequential // region of memory. // Streams are divided into SourceStreams for reading and SinkStreams for // writing. Streams are aggregated into Sets which allows several streams to be -// used at once. Example: we can write A1, B1, A2, B2 but achive the memory +// used at once. Example: we can write A1, B1, A2, B2 but achieve the memory // layout A1 A2 B1 B2 by writing 'A's to one stream and 'B's to another. #ifndef COURGETTE_STREAMS_H_ #define COURGETTE_STREAMS_H_ @@ -17,10 +17,12 @@ #include <string> #include "base/basictypes.h" +#include "base/compiler_specific.h" #include "courgette/memory_allocator.h" #include "courgette/region.h" + namespace courgette { class SourceStream; @@ -109,7 +111,7 @@ class SourceStream { }; // A SinkStream accumulates writes into a buffer that it owns. The stream is -// initialy in an 'accumulating' state where writes are permitted. Accessing +// initially in an 'accumulating' state where writes are permitted. Accessing // the buffer moves the stream into a 'locked' state where no more writes are // permitted. The stream may also be in a 'retired' state where the buffer // contents are no longer available. @@ -119,21 +121,21 @@ class SinkStream { ~SinkStream() {} // Appends |byte_count| bytes from |data| to the stream. - void Write(const void* data, size_t byte_count); + CheckBool Write(const void* data, size_t byte_count) WARN_UNUSED_RESULT; // Appends the 'varint32' encoding of |value| to the stream. - void WriteVarint32(uint32 value); + CheckBool WriteVarint32(uint32 value) WARN_UNUSED_RESULT; // Appends the 'varint32' encoding of |value| to the stream. - void WriteVarint32Signed(int32 value); + CheckBool WriteVarint32Signed(int32 value) WARN_UNUSED_RESULT; // Appends the 'varint32' encoding of |value| to the stream. // On platforms where sizeof(size_t) != sizeof(int32), do a safety check. - void WriteSizeVarint32(size_t value); + CheckBool WriteSizeVarint32(size_t value) WARN_UNUSED_RESULT; // Contents of |other| are appended to |this| stream. The |other| stream // becomes retired. - void Append(SinkStream* other); + CheckBool Append(SinkStream* other) WARN_UNUSED_RESULT; // Returns the number of bytes in this SinkStream size_t Length() const { return buffer_.size(); } @@ -146,7 +148,12 @@ class SinkStream { } // Hints that the stream will grow by an additional |length| bytes. - void Reserve(size_t length) { buffer_.reserve(length + buffer_.length()); } + // Caller must be prepared to handle memory allocation problems. + CheckBool Reserve(size_t length) WARN_UNUSED_RESULT { + buffer_.reserve(length + buffer_.length()); + //TODO(tommi): return false when allocation fails. + return true; + } // Finished with this stream and any storage it has. void Retire(); @@ -215,15 +222,15 @@ class SinkStreamSet { // CopyTo serializes the streams in this SinkStreamSet into a single target // stream. The serialized format may be re-read by initializing a // SourceStreamSet with a buffer containing the data. - bool CopyTo(SinkStream* combined_stream); + CheckBool CopyTo(SinkStream* combined_stream); // Writes the streams of |set| into the corresponding streams of |this|. // Stream zero first has some metadata written to it. |set| becomes retired. // Partner to SourceStreamSet::ReadSet. - bool WriteSet(SinkStreamSet* set); + CheckBool WriteSet(SinkStreamSet* set); private: - void CopyHeaderTo(SinkStream* stream); + CheckBool CopyHeaderTo(SinkStream* stream); size_t count_; SinkStream streams_[kMaxStreams]; diff --git a/courgette/streams_unittest.cc b/courgette/streams_unittest.cc index 64c12dd..d0903c9 100644 --- a/courgette/streams_unittest.cc +++ b/courgette/streams_unittest.cc @@ -12,7 +12,7 @@ TEST(StreamsTest, SimpleWriteRead) { const unsigned int kValue1 = 12345; courgette::SinkStream sink; - sink.WriteVarint32(kValue1); + EXPECT_TRUE(sink.WriteVarint32(kValue1)); const uint8* sink_buffer = sink.Buffer(); size_t length = sink.Length(); @@ -30,7 +30,7 @@ TEST(StreamsTest, SimpleWriteRead) { TEST(StreamsTest, SimpleWriteRead2) { courgette::SinkStream sink; - sink.Write("Hello", 5); + EXPECT_TRUE(sink.Write("Hello", 5)); const uint8* sink_buffer = sink.Buffer(); size_t sink_length = sink.Length(); @@ -51,11 +51,11 @@ TEST(StreamsTest, StreamSetWriteRead) { const unsigned int kValue1 = 12345; - out.stream(3)->WriteVarint32(kValue1); + EXPECT_TRUE(out.stream(3)->WriteVarint32(kValue1)); courgette::SinkStream collected; - out.CopyTo(&collected); + EXPECT_TRUE(out.CopyTo(&collected)); const uint8* collected_buffer = collected.Buffer(); size_t collected_length = collected.Length(); @@ -90,12 +90,12 @@ TEST(StreamsTest, StreamSetWriteRead2) { for (size_t i = 0; data[i] != kEnd; i += 2) { size_t id = data[i]; size_t datum = data[i + 1]; - out.stream(id)->WriteVarint32(datum); + EXPECT_TRUE(out.stream(id)->WriteVarint32(datum)); } courgette::SinkStream collected; - out.CopyTo(&collected); + EXPECT_TRUE(out.CopyTo(&collected)); courgette::SourceStreamSet in; bool can_init = in.Init(collected.Buffer(), collected.Length()); @@ -129,9 +129,9 @@ TEST(StreamsTest, SignedVarint32) { for (size_t i = 0; i < sizeof(data)/sizeof(data[0]); ++i) { int32 basis = data[i]; for (int delta = -4; delta <= 4; ++delta) { - out.WriteVarint32Signed(basis + delta); + EXPECT_TRUE(out.WriteVarint32Signed(basis + delta)); values.push_back(basis + delta); - out.WriteVarint32Signed(-basis + delta); + EXPECT_TRUE(out.WriteVarint32Signed(-basis + delta)); values.push_back(-basis + delta); } } @@ -155,18 +155,18 @@ TEST(StreamsTest, StreamSetReadWrite) { { // Local scope for temporary stream sets. courgette::SinkStreamSet subset1; - subset1.stream(3)->WriteVarint32(30000); - subset1.stream(5)->WriteVarint32(50000); - out.WriteSet(&subset1); + EXPECT_TRUE(subset1.stream(3)->WriteVarint32(30000)); + EXPECT_TRUE(subset1.stream(5)->WriteVarint32(50000)); + EXPECT_TRUE(out.WriteSet(&subset1)); courgette::SinkStreamSet subset2; - subset2.stream(2)->WriteVarint32(20000); - subset2.stream(6)->WriteVarint32(60000); - out.WriteSet(&subset2); + EXPECT_TRUE(subset2.stream(2)->WriteVarint32(20000)); + EXPECT_TRUE(subset2.stream(6)->WriteVarint32(60000)); + EXPECT_TRUE(out.WriteSet(&subset2)); } courgette::SinkStream collected; - out.CopyTo(&collected); + EXPECT_TRUE(out.CopyTo(&collected)); courgette::SourceStreamSet in; bool can_init_in = in.Init(collected.Buffer(), collected.Length()); EXPECT_TRUE(can_init_in); diff --git a/courgette/third_party/bsdiff_apply.cc b/courgette/third_party/bsdiff_apply.cc index cc9bb50..762c12c 100644 --- a/courgette/third_party/bsdiff_apply.cc +++ b/courgette/third_party/bsdiff_apply.cc @@ -77,7 +77,8 @@ BSDiffStatus MBS_ApplyPatch(const MBSPatchHeader *header, const uint8* old_position = old_start; - new_stream->Reserve(header->dlen); + if (header->dlen && !new_stream->Reserve(header->dlen)) + return MEM_ERROR; uint32 pending_diff_zeros = 0; if (!diff_skips->ReadVarint32(&pending_diff_zeros)) @@ -114,7 +115,8 @@ BSDiffStatus MBS_ApplyPatch(const MBSPatchHeader *header, return UNEXPECTED_ERROR; } uint8 byte = old_position[i] + diff_byte; - new_stream->Write(&byte, 1); + if (!new_stream->Write(&byte, 1)) + return MEM_ERROR; } old_position += copy_count; @@ -122,7 +124,9 @@ BSDiffStatus MBS_ApplyPatch(const MBSPatchHeader *header, if (extra_count > static_cast<size_t>(extra_end - extra_position)) return UNEXPECTED_ERROR; - new_stream->Write(extra_position, extra_count); + if (!new_stream->Write(extra_position, extra_count)) + return MEM_ERROR; + extra_position += extra_count; // "seek" forwards (or backwards) in oldfile. diff --git a/courgette/third_party/bsdiff_create.cc b/courgette/third_party/bsdiff_create.cc index 111c9f0..b05b070 100644 --- a/courgette/third_party/bsdiff_create.cc +++ b/courgette/third_party/bsdiff_create.cc @@ -191,11 +191,12 @@ search(PagedArray<int>& I,const unsigned char *old,int oldsize, // End of 'verbatim' code. // ------------------------------------------------------------------------ -static void WriteHeader(SinkStream* stream, MBSPatchHeader* header) { - stream->Write(header->tag, sizeof(header->tag)); - stream->WriteVarint32(header->slen); - stream->WriteVarint32(header->scrc32); - stream->WriteVarint32(header->dlen); +static CheckBool WriteHeader(SinkStream* stream, MBSPatchHeader* header) { + bool ok = stream->Write(header->tag, sizeof(header->tag)); + ok &= stream->WriteVarint32(header->slen); + ok &= stream->WriteVarint32(header->scrc32); + ok &= stream->WriteVarint32(header->dlen); + return ok; } BSDiffStatus CreateBinaryPatch(SourceStream* old_stream, @@ -375,16 +376,20 @@ BSDiffStatus CreateBinaryPatch(SourceStream* old_stream, uint8 diff_byte = newbuf[lastscan + i] - old[lastpos + i]; if (diff_byte) { ++diff_bytes_nonzero; - diff_skips->WriteVarint32(pending_diff_zeros); + if (!diff_skips->WriteVarint32(pending_diff_zeros)) + return MEM_ERROR; pending_diff_zeros = 0; - diff_bytes->Write(&diff_byte, 1); + if (!diff_bytes->Write(&diff_byte, 1)) + return MEM_ERROR; } else { ++pending_diff_zeros; } } int gap = (scan - lenb) - (lastscan + lenf); - for (int i = 0; i < gap; i++) - extra_bytes->Write(&newbuf[lastscan + lenf + i], 1); + for (int i = 0; i < gap; i++) { + if (!extra_bytes->Write(&newbuf[lastscan + lenf + i], 1)) + return MEM_ERROR; + } diff_bytes_length += lenf; extra_bytes_length += gap; @@ -393,9 +398,12 @@ BSDiffStatus CreateBinaryPatch(SourceStream* old_stream, uint32 extra_count = gap; int32 seek_adjustment = ((pos - lenb) - (lastpos + lenf)); - control_stream_copy_counts->WriteVarint32(copy_count); - control_stream_extra_counts->WriteVarint32(extra_count); - control_stream_seeks->WriteVarint32Signed(seek_adjustment); + if (!control_stream_copy_counts->WriteVarint32(copy_count) || + !control_stream_extra_counts->WriteVarint32(extra_count) || + !control_stream_seeks->WriteVarint32Signed(seek_adjustment)) { + return MEM_ERROR; + } + ++control_length; #ifdef DEBUG_bsmedberg VLOG(1) << StringPrintf("Writing a block: copy: %-8u extra: %-8u seek: " @@ -409,7 +417,8 @@ BSDiffStatus CreateBinaryPatch(SourceStream* old_stream, } } - diff_skips->WriteVarint32(pending_diff_zeros); + if (!diff_skips->WriteVarint32(pending_diff_zeros)) + return MEM_ERROR; I.clear(); @@ -422,10 +431,12 @@ BSDiffStatus CreateBinaryPatch(SourceStream* old_stream, header.scrc32 = CalculateCrc(old, oldsize); header.dlen = newsize; - WriteHeader(patch_stream, &header); + if (!WriteHeader(patch_stream, &header)) + return MEM_ERROR; size_t diff_skips_length = diff_skips->Length(); - patch_streams.CopyTo(patch_stream); + if (!patch_streams.CopyTo(patch_stream)) + return MEM_ERROR; VLOG(1) << "Control tuples: " << control_length << " copy bytes: " << diff_bytes_length |