diff options
Diffstat (limited to 'base/win')
-rw-r--r-- | base/win/pe_image.cc | 570 | ||||
-rw-r--r-- | base/win/pe_image.h | 264 | ||||
-rw-r--r-- | base/win/pe_image_unittest.cc | 219 | ||||
-rw-r--r-- | base/win/registry.cc | 387 | ||||
-rw-r--r-- | base/win/registry.h | 171 | ||||
-rw-r--r-- | base/win/registry_unittest.cc | 65 |
6 files changed, 1676 insertions, 0 deletions
diff --git a/base/win/pe_image.cc b/base/win/pe_image.cc new file mode 100644 index 0000000..76fdbcd --- /dev/null +++ b/base/win/pe_image.cc @@ -0,0 +1,570 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file implements PEImage, a generic class to manipulate PE files. +// This file was adapted from GreenBorder's Code. + +#include "base/win/pe_image.h" + +namespace base { +namespace win { + +#if defined(_WIN64) && !defined(NACL_WIN64) +// TODO(rvargas): Bug 27218. Make sure this is ok. +#error This code is not tested on x64. Please make sure all the base unit tests\ + pass before doing any real work. The current unit tests don't test the\ + differences between 32- and 64-bits implementations. Bugs may slip through.\ + You need to improve the coverage before continuing. +#endif + +// Structure to perform imports enumerations. +struct EnumAllImportsStorage { + PEImage::EnumImportsFunction callback; + PVOID cookie; +}; + +namespace { + + // Compare two strings byte by byte on an unsigned basis. + // if s1 == s2, return 0 + // if s1 < s2, return negative + // if s1 > s2, return positive + // Exception if inputs are invalid. + int StrCmpByByte(LPCSTR s1, LPCSTR s2) { + while (*s1 != '\0' && *s1 == *s2) { + ++s1; + ++s2; + } + + return (*reinterpret_cast<const unsigned char*>(s1) - + *reinterpret_cast<const unsigned char*>(s2)); + } + +} // namespace + +// Callback used to enumerate imports. See EnumImportChunksFunction. +bool ProcessImportChunk(const PEImage &image, LPCSTR module, + PIMAGE_THUNK_DATA name_table, + PIMAGE_THUNK_DATA iat, PVOID cookie) { + EnumAllImportsStorage &storage = *reinterpret_cast<EnumAllImportsStorage*>( + cookie); + + return image.EnumOneImportChunk(storage.callback, module, name_table, iat, + storage.cookie); +} + +// Callback used to enumerate delay imports. See EnumDelayImportChunksFunction. +bool ProcessDelayImportChunk(const PEImage &image, + PImgDelayDescr delay_descriptor, + LPCSTR module, PIMAGE_THUNK_DATA name_table, + PIMAGE_THUNK_DATA iat, PIMAGE_THUNK_DATA bound_iat, + PIMAGE_THUNK_DATA unload_iat, PVOID cookie) { + EnumAllImportsStorage &storage = *reinterpret_cast<EnumAllImportsStorage*>( + cookie); + + return image.EnumOneDelayImportChunk(storage.callback, delay_descriptor, + module, name_table, iat, bound_iat, + unload_iat, storage.cookie); +} + +void PEImage::set_module(HMODULE module) { + module_ = module; +} + +PIMAGE_DOS_HEADER PEImage::GetDosHeader() const { + return reinterpret_cast<PIMAGE_DOS_HEADER>(module_); +} + +PIMAGE_NT_HEADERS PEImage::GetNTHeaders() const { + PIMAGE_DOS_HEADER dos_header = GetDosHeader(); + + return reinterpret_cast<PIMAGE_NT_HEADERS>( + reinterpret_cast<char*>(dos_header) + dos_header->e_lfanew); +} + +PIMAGE_SECTION_HEADER PEImage::GetSectionHeader(UINT section) const { + PIMAGE_NT_HEADERS nt_headers = GetNTHeaders(); + PIMAGE_SECTION_HEADER first_section = IMAGE_FIRST_SECTION(nt_headers); + + if (section < nt_headers->FileHeader.NumberOfSections) + return first_section + section; + else + return NULL; +} + +WORD PEImage::GetNumSections() const { + return GetNTHeaders()->FileHeader.NumberOfSections; +} + +DWORD PEImage::GetImageDirectoryEntrySize(UINT directory) const { + PIMAGE_NT_HEADERS nt_headers = GetNTHeaders(); + + return nt_headers->OptionalHeader.DataDirectory[directory].Size; +} + +PVOID PEImage::GetImageDirectoryEntryAddr(UINT directory) const { + PIMAGE_NT_HEADERS nt_headers = GetNTHeaders(); + + return RVAToAddr( + nt_headers->OptionalHeader.DataDirectory[directory].VirtualAddress); +} + +PIMAGE_SECTION_HEADER PEImage::GetImageSectionFromAddr(PVOID address) const { + PBYTE target = reinterpret_cast<PBYTE>(address); + PIMAGE_SECTION_HEADER section; + + for (UINT i = 0; NULL != (section = GetSectionHeader(i)); i++) { + // Don't use the virtual RVAToAddr. + PBYTE start = reinterpret_cast<PBYTE>( + PEImage::RVAToAddr(section->VirtualAddress)); + + DWORD size = section->Misc.VirtualSize; + + if ((start <= target) && (start + size > target)) + return section; + } + + return NULL; +} + +PIMAGE_SECTION_HEADER PEImage::GetImageSectionHeaderByName( + LPCSTR section_name) const { + if (NULL == section_name) + return NULL; + + PIMAGE_SECTION_HEADER ret = NULL; + int num_sections = GetNumSections(); + + for (int i = 0; i < num_sections; i++) { + PIMAGE_SECTION_HEADER section = GetSectionHeader(i); + if (0 == _strnicmp(reinterpret_cast<LPCSTR>(section->Name), section_name, + sizeof(section->Name))) { + ret = section; + break; + } + } + + return ret; +} + +PDWORD PEImage::GetExportEntry(LPCSTR name) const { + PIMAGE_EXPORT_DIRECTORY exports = GetExportDirectory(); + + if (NULL == exports) + return NULL; + + WORD ordinal = 0; + if (!GetProcOrdinal(name, &ordinal)) + return NULL; + + PDWORD functions = reinterpret_cast<PDWORD>( + RVAToAddr(exports->AddressOfFunctions)); + + return functions + ordinal - exports->Base; +} + +FARPROC PEImage::GetProcAddress(LPCSTR function_name) const { + PDWORD export_entry = GetExportEntry(function_name); + if (NULL == export_entry) + return NULL; + + PBYTE function = reinterpret_cast<PBYTE>(RVAToAddr(*export_entry)); + + PBYTE exports = reinterpret_cast<PBYTE>( + GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_EXPORT)); + DWORD size = GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_EXPORT); + + // Check for forwarded exports as a special case. + if (exports <= function && exports + size > function) +#pragma warning(push) +#pragma warning(disable: 4312) + // This cast generates a warning because it is 32 bit specific. + return reinterpret_cast<FARPROC>(0xFFFFFFFF); +#pragma warning(pop) + + return reinterpret_cast<FARPROC>(function); +} + +bool PEImage::GetProcOrdinal(LPCSTR function_name, WORD *ordinal) const { + if (NULL == ordinal) + return false; + + PIMAGE_EXPORT_DIRECTORY exports = GetExportDirectory(); + + if (NULL == exports) + return false; + + if (IsOrdinal(function_name)) { + *ordinal = ToOrdinal(function_name); + } else { + PDWORD names = reinterpret_cast<PDWORD>(RVAToAddr(exports->AddressOfNames)); + PDWORD lower = names; + PDWORD upper = names + exports->NumberOfNames; + int cmp = -1; + + // Binary Search for the name. + while (lower != upper) { + PDWORD middle = lower + (upper - lower) / 2; + LPCSTR name = reinterpret_cast<LPCSTR>(RVAToAddr(*middle)); + + // This may be called by sandbox before MSVCRT dll loads, so can't use + // CRT function here. + cmp = StrCmpByByte(function_name, name); + + if (cmp == 0) { + lower = middle; + break; + } + + if (cmp > 0) + lower = middle + 1; + else + upper = middle; + } + + if (cmp != 0) + return false; + + + PWORD ordinals = reinterpret_cast<PWORD>( + RVAToAddr(exports->AddressOfNameOrdinals)); + + *ordinal = ordinals[lower - names] + static_cast<WORD>(exports->Base); + } + + return true; +} + +bool PEImage::EnumSections(EnumSectionsFunction callback, PVOID cookie) const { + PIMAGE_NT_HEADERS nt_headers = GetNTHeaders(); + UINT num_sections = nt_headers->FileHeader.NumberOfSections; + PIMAGE_SECTION_HEADER section = GetSectionHeader(0); + + for (UINT i = 0; i < num_sections; i++, section++) { + PVOID section_start = RVAToAddr(section->VirtualAddress); + DWORD size = section->Misc.VirtualSize; + + if (!callback(*this, section, section_start, size, cookie)) + return false; + } + + return true; +} + +bool PEImage::EnumExports(EnumExportsFunction callback, PVOID cookie) const { + PVOID directory = GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_EXPORT); + DWORD size = GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_EXPORT); + + // Check if there are any exports at all. + if (NULL == directory || 0 == size) + return true; + + PIMAGE_EXPORT_DIRECTORY exports = reinterpret_cast<PIMAGE_EXPORT_DIRECTORY>( + directory); + UINT ordinal_base = exports->Base; + UINT num_funcs = exports->NumberOfFunctions; + UINT num_names = exports->NumberOfNames; + PDWORD functions = reinterpret_cast<PDWORD>(RVAToAddr( + exports->AddressOfFunctions)); + PDWORD names = reinterpret_cast<PDWORD>(RVAToAddr(exports->AddressOfNames)); + PWORD ordinals = reinterpret_cast<PWORD>(RVAToAddr( + exports->AddressOfNameOrdinals)); + + for (UINT count = 0; count < num_funcs; count++) { + PVOID func = RVAToAddr(functions[count]); + if (NULL == func) + continue; + + // Check for a name. + LPCSTR name = NULL; + UINT hint; + for (hint = 0; hint < num_names; hint++) { + if (ordinals[hint] == count) { + name = reinterpret_cast<LPCSTR>(RVAToAddr(names[hint])); + break; + } + } + + if (name == NULL) + hint = 0; + + // Check for forwarded exports. + LPCSTR forward = NULL; + if (reinterpret_cast<char*>(func) >= reinterpret_cast<char*>(directory) && + reinterpret_cast<char*>(func) <= reinterpret_cast<char*>(directory) + + size) { + forward = reinterpret_cast<LPCSTR>(func); + func = 0; + } + + if (!callback(*this, ordinal_base + count, hint, name, func, forward, + cookie)) + return false; + } + + return true; +} + +bool PEImage::EnumRelocs(EnumRelocsFunction callback, PVOID cookie) const { + PVOID directory = GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_BASERELOC); + DWORD size = GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_BASERELOC); + PIMAGE_BASE_RELOCATION base = reinterpret_cast<PIMAGE_BASE_RELOCATION>( + directory); + + if (directory == NULL || size < sizeof(IMAGE_BASE_RELOCATION)) + return true; + + while (base->SizeOfBlock) { + PWORD reloc = reinterpret_cast<PWORD>(base + 1); + UINT num_relocs = (base->SizeOfBlock - sizeof(IMAGE_BASE_RELOCATION)) / + sizeof(WORD); + + for (UINT i = 0; i < num_relocs; i++, reloc++) { + WORD type = *reloc >> 12; + PVOID address = RVAToAddr(base->VirtualAddress + (*reloc & 0x0FFF)); + + if (!callback(*this, type, address, cookie)) + return false; + } + + base = reinterpret_cast<PIMAGE_BASE_RELOCATION>( + reinterpret_cast<char*>(base) + base->SizeOfBlock); + } + + return true; +} + +bool PEImage::EnumImportChunks(EnumImportChunksFunction callback, + PVOID cookie) const { + DWORD size = GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_IMPORT); + PIMAGE_IMPORT_DESCRIPTOR import = GetFirstImportChunk(); + + if (import == NULL || size < sizeof(IMAGE_IMPORT_DESCRIPTOR)) + return true; + + for (; import->FirstThunk; import++) { + LPCSTR module_name = reinterpret_cast<LPCSTR>(RVAToAddr(import->Name)); + PIMAGE_THUNK_DATA name_table = reinterpret_cast<PIMAGE_THUNK_DATA>( + RVAToAddr(import->OriginalFirstThunk)); + PIMAGE_THUNK_DATA iat = reinterpret_cast<PIMAGE_THUNK_DATA>( + RVAToAddr(import->FirstThunk)); + + if (!callback(*this, module_name, name_table, iat, cookie)) + return false; + } + + return true; +} + +bool PEImage::EnumOneImportChunk(EnumImportsFunction callback, + LPCSTR module_name, + PIMAGE_THUNK_DATA name_table, + PIMAGE_THUNK_DATA iat, PVOID cookie) const { + if (NULL == name_table) + return false; + + for (; name_table && name_table->u1.Ordinal; name_table++, iat++) { + LPCSTR name = NULL; + WORD ordinal = 0; + WORD hint = 0; + + if (IMAGE_SNAP_BY_ORDINAL(name_table->u1.Ordinal)) { + ordinal = static_cast<WORD>(IMAGE_ORDINAL32(name_table->u1.Ordinal)); + } else { + PIMAGE_IMPORT_BY_NAME import = reinterpret_cast<PIMAGE_IMPORT_BY_NAME>( + RVAToAddr(name_table->u1.ForwarderString)); + + hint = import->Hint; + name = reinterpret_cast<LPCSTR>(&import->Name); + } + + if (!callback(*this, module_name, ordinal, name, hint, iat, cookie)) + return false; + } + + return true; +} + +bool PEImage::EnumAllImports(EnumImportsFunction callback, PVOID cookie) const { + EnumAllImportsStorage temp = { callback, cookie }; + return EnumImportChunks(ProcessImportChunk, &temp); +} + +bool PEImage::EnumDelayImportChunks(EnumDelayImportChunksFunction callback, + PVOID cookie) const { + PVOID directory = GetImageDirectoryEntryAddr( + IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT); + DWORD size = GetImageDirectoryEntrySize(IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT); + PImgDelayDescr delay_descriptor = reinterpret_cast<PImgDelayDescr>(directory); + + if (directory == NULL || size == 0) + return true; + + for (; delay_descriptor->rvaHmod; delay_descriptor++) { + PIMAGE_THUNK_DATA name_table; + PIMAGE_THUNK_DATA iat; + PIMAGE_THUNK_DATA bound_iat; // address of the optional bound IAT + PIMAGE_THUNK_DATA unload_iat; // address of optional copy of original IAT + LPCSTR module_name; + + // check if VC7-style imports, using RVAs instead of + // VC6-style addresses. + bool rvas = (delay_descriptor->grAttrs & dlattrRva) != 0; + + if (rvas) { + module_name = reinterpret_cast<LPCSTR>( + RVAToAddr(delay_descriptor->rvaDLLName)); + name_table = reinterpret_cast<PIMAGE_THUNK_DATA>( + RVAToAddr(delay_descriptor->rvaINT)); + iat = reinterpret_cast<PIMAGE_THUNK_DATA>( + RVAToAddr(delay_descriptor->rvaIAT)); + bound_iat = reinterpret_cast<PIMAGE_THUNK_DATA>( + RVAToAddr(delay_descriptor->rvaBoundIAT)); + unload_iat = reinterpret_cast<PIMAGE_THUNK_DATA>( + RVAToAddr(delay_descriptor->rvaUnloadIAT)); + } else { +#pragma warning(push) +#pragma warning(disable: 4312) + // These casts generate warnings because they are 32 bit specific. + module_name = reinterpret_cast<LPCSTR>(delay_descriptor->rvaDLLName); + name_table = reinterpret_cast<PIMAGE_THUNK_DATA>( + delay_descriptor->rvaINT); + iat = reinterpret_cast<PIMAGE_THUNK_DATA>(delay_descriptor->rvaIAT); + bound_iat = reinterpret_cast<PIMAGE_THUNK_DATA>( + delay_descriptor->rvaBoundIAT); + unload_iat = reinterpret_cast<PIMAGE_THUNK_DATA>( + delay_descriptor->rvaUnloadIAT); +#pragma warning(pop) + } + + if (!callback(*this, delay_descriptor, module_name, name_table, iat, + bound_iat, unload_iat, cookie)) + return false; + } + + return true; +} + +bool PEImage::EnumOneDelayImportChunk(EnumImportsFunction callback, + PImgDelayDescr delay_descriptor, + LPCSTR module_name, + PIMAGE_THUNK_DATA name_table, + PIMAGE_THUNK_DATA iat, + PIMAGE_THUNK_DATA bound_iat, + PIMAGE_THUNK_DATA unload_iat, + PVOID cookie) const { + UNREFERENCED_PARAMETER(bound_iat); + UNREFERENCED_PARAMETER(unload_iat); + + for (; name_table->u1.Ordinal; name_table++, iat++) { + LPCSTR name = NULL; + WORD ordinal = 0; + WORD hint = 0; + + if (IMAGE_SNAP_BY_ORDINAL(name_table->u1.Ordinal)) { + ordinal = static_cast<WORD>(IMAGE_ORDINAL32(name_table->u1.Ordinal)); + } else { + PIMAGE_IMPORT_BY_NAME import; + bool rvas = (delay_descriptor->grAttrs & dlattrRva) != 0; + + if (rvas) { + import = reinterpret_cast<PIMAGE_IMPORT_BY_NAME>( + RVAToAddr(name_table->u1.ForwarderString)); + } else { +#pragma warning(push) +#pragma warning(disable: 4312) + // This cast generates a warning because it is 32 bit specific. + import = reinterpret_cast<PIMAGE_IMPORT_BY_NAME>( + name_table->u1.ForwarderString); +#pragma warning(pop) + } + + hint = import->Hint; + name = reinterpret_cast<LPCSTR>(&import->Name); + } + + if (!callback(*this, module_name, ordinal, name, hint, iat, cookie)) + return false; + } + + return true; +} + +bool PEImage::EnumAllDelayImports(EnumImportsFunction callback, + PVOID cookie) const { + EnumAllImportsStorage temp = { callback, cookie }; + return EnumDelayImportChunks(ProcessDelayImportChunk, &temp); +} + +bool PEImage::VerifyMagic() const { + PIMAGE_DOS_HEADER dos_header = GetDosHeader(); + + if (dos_header->e_magic != IMAGE_DOS_SIGNATURE) + return false; + + PIMAGE_NT_HEADERS nt_headers = GetNTHeaders(); + + if (nt_headers->Signature != IMAGE_NT_SIGNATURE) + return false; + + if (nt_headers->FileHeader.SizeOfOptionalHeader != + sizeof(IMAGE_OPTIONAL_HEADER)) + return false; + + if (nt_headers->OptionalHeader.Magic != IMAGE_NT_OPTIONAL_HDR_MAGIC) + return false; + + return true; +} + +bool PEImage::ImageRVAToOnDiskOffset(DWORD rva, DWORD *on_disk_offset) const { + LPVOID address = RVAToAddr(rva); + return ImageAddrToOnDiskOffset(address, on_disk_offset); +} + +bool PEImage::ImageAddrToOnDiskOffset(LPVOID address, + DWORD *on_disk_offset) const { + if (NULL == address) + return false; + + // Get the section that this address belongs to. + PIMAGE_SECTION_HEADER section_header = GetImageSectionFromAddr(address); + if (NULL == section_header) + return false; + +#pragma warning(push) +#pragma warning(disable: 4311) + // These casts generate warnings because they are 32 bit specific. + // Don't follow the virtual RVAToAddr, use the one on the base. + DWORD offset_within_section = reinterpret_cast<DWORD>(address) - + reinterpret_cast<DWORD>(PEImage::RVAToAddr( + section_header->VirtualAddress)); +#pragma warning(pop) + + *on_disk_offset = section_header->PointerToRawData + offset_within_section; + return true; +} + +PVOID PEImage::RVAToAddr(DWORD rva) const { + if (rva == 0) + return NULL; + + return reinterpret_cast<char*>(module_) + rva; +} + +PVOID PEImageAsData::RVAToAddr(DWORD rva) const { + if (rva == 0) + return NULL; + + PVOID in_memory = PEImage::RVAToAddr(rva); + DWORD dummy; + + if (!ImageAddrToOnDiskOffset(in_memory, &dummy)) + return NULL; + + return in_memory; +} + +} // namespace win +} // namespace base diff --git a/base/win/pe_image.h b/base/win/pe_image.h new file mode 100644 index 0000000..e1205e0 --- /dev/null +++ b/base/win/pe_image.h @@ -0,0 +1,264 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file was adapted from GreenBorder's Code. +// To understand what this class is about (for other than well known functions +// as GetProcAddress), a good starting point is "An In-Depth Look into the +// Win32 Portable Executable File Format" by Matt Pietrek: +// http://msdn.microsoft.com/msdnmag/issues/02/02/PE/default.aspx + +#ifndef BASE_WIN_PE_IMAGE_H_ +#define BASE_WIN_PE_IMAGE_H_ +#pragma once + +#include <windows.h> +#include <DelayIMP.h> + +namespace base { +namespace win { + +// This class is a wrapper for the Portable Executable File Format (PE). +// It's main purpose is to provide an easy way to work with imports and exports +// from a file, mapped in memory as image. +class PEImage { + public: + // Callback to enumerate sections. + // cookie is the value passed to the enumerate method. + // Returns true to continue the enumeration. + typedef bool (*EnumSectionsFunction)(const PEImage &image, + PIMAGE_SECTION_HEADER header, + PVOID section_start, DWORD section_size, + PVOID cookie); + + // Callback to enumerate exports. + // function is the actual address of the symbol. If forward is not null, it + // contains the dll and symbol to forward this export to. cookie is the value + // passed to the enumerate method. + // Returns true to continue the enumeration. + typedef bool (*EnumExportsFunction)(const PEImage &image, DWORD ordinal, + DWORD hint, LPCSTR name, PVOID function, + LPCSTR forward, PVOID cookie); + + // Callback to enumerate import blocks. + // name_table and iat point to the imports name table and address table for + // this block. cookie is the value passed to the enumerate method. + // Returns true to continue the enumeration. + typedef bool (*EnumImportChunksFunction)(const PEImage &image, LPCSTR module, + PIMAGE_THUNK_DATA name_table, + PIMAGE_THUNK_DATA iat, PVOID cookie); + + // Callback to enumerate imports. + // module is the dll that exports this symbol. cookie is the value passed to + // the enumerate method. + // Returns true to continue the enumeration. + typedef bool (*EnumImportsFunction)(const PEImage &image, LPCSTR module, + DWORD ordinal, LPCSTR name, DWORD hint, + PIMAGE_THUNK_DATA iat, PVOID cookie); + + // Callback to enumerate dalayed import blocks. + // module is the dll that exports this block of symbols. cookie is the value + // passed to the enumerate method. + // Returns true to continue the enumeration. + typedef bool (*EnumDelayImportChunksFunction)(const PEImage &image, + PImgDelayDescr delay_descriptor, + LPCSTR module, + PIMAGE_THUNK_DATA name_table, + PIMAGE_THUNK_DATA iat, + PIMAGE_THUNK_DATA bound_iat, + PIMAGE_THUNK_DATA unload_iat, + PVOID cookie); + + // Callback to enumerate relocations. + // cookie is the value passed to the enumerate method. + // Returns true to continue the enumeration. + typedef bool (*EnumRelocsFunction)(const PEImage &image, WORD type, + PVOID address, PVOID cookie); + + explicit PEImage(HMODULE module) : module_(module) {} + explicit PEImage(const void* module) { + module_ = reinterpret_cast<HMODULE>(const_cast<void*>(module)); + } + + // Gets the HMODULE for this object. + HMODULE module() const; + + // Sets this object's HMODULE. + void set_module(HMODULE module); + + // Checks if this symbol is actually an ordinal. + static bool IsOrdinal(LPCSTR name); + + // Converts a named symbol to the corresponding ordinal. + static WORD ToOrdinal(LPCSTR name); + + // Returns the DOS_HEADER for this PE. + PIMAGE_DOS_HEADER GetDosHeader() const; + + // Returns the NT_HEADER for this PE. + PIMAGE_NT_HEADERS GetNTHeaders() const; + + // Returns number of sections of this PE. + WORD GetNumSections() const; + + // Returns the header for a given section. + // returns NULL if there is no such section. + PIMAGE_SECTION_HEADER GetSectionHeader(UINT section) const; + + // Returns the size of a given directory entry. + DWORD GetImageDirectoryEntrySize(UINT directory) const; + + // Returns the address of a given directory entry. + PVOID GetImageDirectoryEntryAddr(UINT directory) const; + + // Returns the section header for a given address. + // Use: s = image.GetImageSectionFromAddr(a); + // Post: 's' is the section header of the section that contains 'a' + // or NULL if there is no such section. + PIMAGE_SECTION_HEADER GetImageSectionFromAddr(PVOID address) const; + + // Returns the section header for a given section. + PIMAGE_SECTION_HEADER GetImageSectionHeaderByName(LPCSTR section_name) const; + + // Returns the first block of imports. + PIMAGE_IMPORT_DESCRIPTOR GetFirstImportChunk() const; + + // Returns the exports directory. + PIMAGE_EXPORT_DIRECTORY GetExportDirectory() const; + + // Returns a given export entry. + // Use: e = image.GetExportEntry(f); + // Pre: 'f' is either a zero terminated string or ordinal + // Post: 'e' is a pointer to the export directory entry + // that contains 'f's export RVA, or NULL if 'f' + // is not exported from this image + PDWORD GetExportEntry(LPCSTR name) const; + + // Returns the address for a given exported symbol. + // Use: p = image.GetProcAddress(f); + // Pre: 'f' is either a zero terminated string or ordinal. + // Post: if 'f' is a non-forwarded export from image, 'p' is + // the exported function. If 'f' is a forwarded export + // then p is the special value 0xFFFFFFFF. In this case + // RVAToAddr(*GetExportEntry) can be used to resolve + // the string that describes the forward. + FARPROC GetProcAddress(LPCSTR function_name) const; + + // Retrieves the ordinal for a given exported symbol. + // Returns true if the symbol was found. + bool GetProcOrdinal(LPCSTR function_name, WORD *ordinal) const; + + // Enumerates PE sections. + // cookie is a generic cookie to pass to the callback. + // Returns true on success. + bool EnumSections(EnumSectionsFunction callback, PVOID cookie) const; + + // Enumerates PE exports. + // cookie is a generic cookie to pass to the callback. + // Returns true on success. + bool EnumExports(EnumExportsFunction callback, PVOID cookie) const; + + // Enumerates PE imports. + // cookie is a generic cookie to pass to the callback. + // Returns true on success. + bool EnumAllImports(EnumImportsFunction callback, PVOID cookie) const; + + // Enumerates PE import blocks. + // cookie is a generic cookie to pass to the callback. + // Returns true on success. + bool EnumImportChunks(EnumImportChunksFunction callback, PVOID cookie) const; + + // Enumerates the imports from a single PE import block. + // cookie is a generic cookie to pass to the callback. + // Returns true on success. + bool EnumOneImportChunk(EnumImportsFunction callback, LPCSTR module_name, + PIMAGE_THUNK_DATA name_table, PIMAGE_THUNK_DATA iat, + PVOID cookie) const; + + + // Enumerates PE delay imports. + // cookie is a generic cookie to pass to the callback. + // Returns true on success. + bool EnumAllDelayImports(EnumImportsFunction callback, PVOID cookie) const; + + // Enumerates PE delay import blocks. + // cookie is a generic cookie to pass to the callback. + // Returns true on success. + bool EnumDelayImportChunks(EnumDelayImportChunksFunction callback, + PVOID cookie) const; + + // Enumerates imports from a single PE delay import block. + // cookie is a generic cookie to pass to the callback. + // Returns true on success. + bool EnumOneDelayImportChunk(EnumImportsFunction callback, + PImgDelayDescr delay_descriptor, + LPCSTR module_name, + PIMAGE_THUNK_DATA name_table, + PIMAGE_THUNK_DATA iat, + PIMAGE_THUNK_DATA bound_iat, + PIMAGE_THUNK_DATA unload_iat, + PVOID cookie) const; + + // Enumerates PE relocation entries. + // cookie is a generic cookie to pass to the callback. + // Returns true on success. + bool EnumRelocs(EnumRelocsFunction callback, PVOID cookie) const; + + // Verifies the magic values on the PE file. + // Returns true if all values are correct. + bool VerifyMagic() const; + + // Converts an rva value to the appropriate address. + virtual PVOID RVAToAddr(DWORD rva) const; + + // Converts an rva value to an offset on disk. + // Returns true on success. + bool ImageRVAToOnDiskOffset(DWORD rva, DWORD *on_disk_offset) const; + + // Converts an address to an offset on disk. + // Returns true on success. + bool ImageAddrToOnDiskOffset(LPVOID address, DWORD *on_disk_offset) const; + + private: + HMODULE module_; +}; + +// This class is an extension to the PEImage class that allows working with PE +// files mapped as data instead of as image file. +class PEImageAsData : public PEImage { + public: + explicit PEImageAsData(HMODULE hModule) : PEImage(hModule) {} + + virtual PVOID RVAToAddr(DWORD rva) const; +}; + +inline bool PEImage::IsOrdinal(LPCSTR name) { +#pragma warning(push) +#pragma warning(disable: 4311) + // This cast generates a warning because it is 32 bit specific. + return reinterpret_cast<DWORD>(name) <= 0xFFFF; +#pragma warning(pop) +} + +inline WORD PEImage::ToOrdinal(LPCSTR name) { + return reinterpret_cast<WORD>(name); +} + +inline HMODULE PEImage::module() const { + return module_; +} + +inline PIMAGE_IMPORT_DESCRIPTOR PEImage::GetFirstImportChunk() const { + return reinterpret_cast<PIMAGE_IMPORT_DESCRIPTOR>( + GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_IMPORT)); +} + +inline PIMAGE_EXPORT_DIRECTORY PEImage::GetExportDirectory() const { + return reinterpret_cast<PIMAGE_EXPORT_DIRECTORY>( + GetImageDirectoryEntryAddr(IMAGE_DIRECTORY_ENTRY_EXPORT)); +} + +} // namespace win +} // namespace base + +#endif // BASE_WIN_PE_IMAGE_H_ diff --git a/base/win/pe_image_unittest.cc b/base/win/pe_image_unittest.cc new file mode 100644 index 0000000..899ce94 --- /dev/null +++ b/base/win/pe_image_unittest.cc @@ -0,0 +1,219 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +// This file contains unit tests for PEImage. + +#include "testing/gtest/include/gtest/gtest.h" +#include "base/win/pe_image.h" +#include "base/win/windows_version.h" + +namespace base { +namespace win { + +// Just counts the number of invocations. +bool ExportsCallback(const PEImage &image, + DWORD ordinal, + DWORD hint, + LPCSTR name, + PVOID function, + LPCSTR forward, + PVOID cookie) { + int* count = reinterpret_cast<int*>(cookie); + (*count)++; + return true; +} + +// Just counts the number of invocations. +bool ImportsCallback(const PEImage &image, + LPCSTR module, + DWORD ordinal, + LPCSTR name, + DWORD hint, + PIMAGE_THUNK_DATA iat, + PVOID cookie) { + int* count = reinterpret_cast<int*>(cookie); + (*count)++; + return true; +} + +// Just counts the number of invocations. +bool SectionsCallback(const PEImage &image, + PIMAGE_SECTION_HEADER header, + PVOID section_start, + DWORD section_size, + PVOID cookie) { + int* count = reinterpret_cast<int*>(cookie); + (*count)++; + return true; +} + +// Just counts the number of invocations. +bool RelocsCallback(const PEImage &image, + WORD type, + PVOID address, + PVOID cookie) { + int* count = reinterpret_cast<int*>(cookie); + (*count)++; + return true; +} + +// Just counts the number of invocations. +bool ImportChunksCallback(const PEImage &image, + LPCSTR module, + PIMAGE_THUNK_DATA name_table, + PIMAGE_THUNK_DATA iat, + PVOID cookie) { + int* count = reinterpret_cast<int*>(cookie); + (*count)++; + return true; +} + +// Just counts the number of invocations. +bool DelayImportChunksCallback(const PEImage &image, + PImgDelayDescr delay_descriptor, + LPCSTR module, + PIMAGE_THUNK_DATA name_table, + PIMAGE_THUNK_DATA iat, + PIMAGE_THUNK_DATA bound_iat, + PIMAGE_THUNK_DATA unload_iat, + PVOID cookie) { + int* count = reinterpret_cast<int*>(cookie); + (*count)++; + return true; +} + +// We'll be using some known values for the tests. +enum Value { + sections = 0, + imports_dlls, + delay_dlls, + exports, + imports, + delay_imports, + relocs +}; + +// Retrieves the expected value from advapi32.dll based on the OS. +int GetExpectedValue(Value value, DWORD os) { + const int xp_delay_dlls = 2; + const int xp_exports = 675; + const int xp_imports = 422; + const int xp_delay_imports = 8; + const int xp_relocs = 9180; + const int vista_delay_dlls = 4; + const int vista_exports = 799; + const int vista_imports = 476; + const int vista_delay_imports = 24; + const int vista_relocs = 10188; + const int w2k_delay_dlls = 0; + const int w2k_exports = 566; + const int w2k_imports = 357; + const int w2k_delay_imports = 0; + const int w2k_relocs = 7388; + const int win7_delay_dlls = 7; + const int win7_exports = 806; + const int win7_imports = 568; + const int win7_delay_imports = 71; + const int win7_relocs = 7812; + + // Contains the expected value, for each enumerated property (Value), and the + // OS version: [Value][os_version] + const int expected[][4] = { + {4, 4, 4, 4}, + {3, 3, 3, 13}, + {w2k_delay_dlls, xp_delay_dlls, vista_delay_dlls, win7_delay_dlls}, + {w2k_exports, xp_exports, vista_exports, win7_exports}, + {w2k_imports, xp_imports, vista_imports, win7_imports}, + {w2k_delay_imports, xp_delay_imports, + vista_delay_imports, win7_delay_imports}, + {w2k_relocs, xp_relocs, vista_relocs, win7_relocs} + }; + + if (value > relocs) + return 0; + if (50 == os) + os = 0; // 5.0 + else if (51 == os || 52 == os) + os = 1; + else if (os == 60) + os = 2; // 6.x + else if (os >= 61) + os = 3; + else + return 0; + + return expected[value][os]; +} + +// Tests that we are able to enumerate stuff from a PE file, and that +// the actual number of items found is within the expected range. +TEST(PEImageTest, EnumeratesPE) { + // Windows Server 2003 is not supported as a test environment for this test. + if (base::win::GetVersion() == base::win::VERSION_SERVER_2003) + return; + HMODULE module = LoadLibrary(L"advapi32.dll"); + ASSERT_TRUE(NULL != module); + + PEImage pe(module); + int count = 0; + EXPECT_TRUE(pe.VerifyMagic()); + + DWORD os = pe.GetNTHeaders()->OptionalHeader.MajorOperatingSystemVersion; + os = os * 10 + pe.GetNTHeaders()->OptionalHeader.MinorOperatingSystemVersion; + + pe.EnumSections(SectionsCallback, &count); + EXPECT_EQ(GetExpectedValue(sections, os), count); + + count = 0; + pe.EnumImportChunks(ImportChunksCallback, &count); + EXPECT_EQ(GetExpectedValue(imports_dlls, os), count); + + count = 0; + pe.EnumDelayImportChunks(DelayImportChunksCallback, &count); + EXPECT_EQ(GetExpectedValue(delay_dlls, os), count); + + count = 0; + pe.EnumExports(ExportsCallback, &count); + EXPECT_GT(count, GetExpectedValue(exports, os) - 20); + EXPECT_LT(count, GetExpectedValue(exports, os) + 100); + + count = 0; + pe.EnumAllImports(ImportsCallback, &count); + EXPECT_GT(count, GetExpectedValue(imports, os) - 20); + EXPECT_LT(count, GetExpectedValue(imports, os) + 100); + + count = 0; + pe.EnumAllDelayImports(ImportsCallback, &count); + EXPECT_GT(count, GetExpectedValue(delay_imports, os) - 2); + EXPECT_LT(count, GetExpectedValue(delay_imports, os) + 8); + + count = 0; + pe.EnumRelocs(RelocsCallback, &count); + EXPECT_GT(count, GetExpectedValue(relocs, os) - 150); + EXPECT_LT(count, GetExpectedValue(relocs, os) + 1500); + + FreeLibrary(module); +} + +// Tests that we can locate an specific exported symbol, by name and by ordinal. +TEST(PEImageTest, RetrievesExports) { + HMODULE module = LoadLibrary(L"advapi32.dll"); + ASSERT_TRUE(NULL != module); + + PEImage pe(module); + WORD ordinal; + + EXPECT_TRUE(pe.GetProcOrdinal("RegEnumKeyExW", &ordinal)); + + FARPROC address1 = pe.GetProcAddress("RegEnumKeyExW"); + FARPROC address2 = pe.GetProcAddress(reinterpret_cast<char*>(ordinal)); + EXPECT_TRUE(address1 != NULL); + EXPECT_TRUE(address2 != NULL); + EXPECT_TRUE(address1 == address2); + + FreeLibrary(module); +} + +} // namespace win +} // namespace base diff --git a/base/win/registry.cc b/base/win/registry.cc new file mode 100644 index 0000000..545c337 --- /dev/null +++ b/base/win/registry.cc @@ -0,0 +1,387 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/win/registry.h" + +#include <shlwapi.h> + +#include "base/logging.h" + +#pragma comment(lib, "shlwapi.lib") // for SHDeleteKey + +namespace base { +namespace win { + +RegistryValueIterator::RegistryValueIterator(HKEY root_key, + const wchar_t* folder_key) { + LONG result = RegOpenKeyEx(root_key, folder_key, 0, KEY_READ, &key_); + if (result != ERROR_SUCCESS) { + key_ = NULL; + } else { + DWORD count = 0; + result = ::RegQueryInfoKey(key_, NULL, 0, NULL, NULL, NULL, NULL, &count, + NULL, NULL, NULL, NULL); + + if (result != ERROR_SUCCESS) { + ::RegCloseKey(key_); + key_ = NULL; + } else { + index_ = count - 1; + } + } + + Read(); +} + +RegistryValueIterator::~RegistryValueIterator() { + if (key_) + ::RegCloseKey(key_); +} + +bool RegistryValueIterator::Valid() const { + return key_ != NULL && index_ >= 0; +} + +void RegistryValueIterator::operator++() { + --index_; + Read(); +} + +bool RegistryValueIterator::Read() { + if (Valid()) { + DWORD ncount = arraysize(name_); + value_size_ = sizeof(value_); + LRESULT r = ::RegEnumValue(key_, index_, name_, &ncount, NULL, &type_, + reinterpret_cast<BYTE*>(value_), &value_size_); + if (ERROR_SUCCESS == r) + return true; + } + + name_[0] = '\0'; + value_[0] = '\0'; + value_size_ = 0; + return false; +} + +DWORD RegistryValueIterator::ValueCount() const { + DWORD count = 0; + HRESULT result = ::RegQueryInfoKey(key_, NULL, 0, NULL, NULL, NULL, NULL, + &count, NULL, NULL, NULL, NULL); + + if (result != ERROR_SUCCESS) + return 0; + + return count; +} + +RegistryKeyIterator::RegistryKeyIterator(HKEY root_key, + const wchar_t* folder_key) { + LONG result = RegOpenKeyEx(root_key, folder_key, 0, KEY_READ, &key_); + if (result != ERROR_SUCCESS) { + key_ = NULL; + } else { + DWORD count = 0; + HRESULT result = ::RegQueryInfoKey(key_, NULL, 0, NULL, &count, NULL, NULL, + NULL, NULL, NULL, NULL, NULL); + + if (result != ERROR_SUCCESS) { + ::RegCloseKey(key_); + key_ = NULL; + } else { + index_ = count - 1; + } + } + + Read(); +} + +RegistryKeyIterator::~RegistryKeyIterator() { + if (key_) + ::RegCloseKey(key_); +} + +bool RegistryKeyIterator::Valid() const { + return key_ != NULL && index_ >= 0; +} + +void RegistryKeyIterator::operator++() { + --index_; + Read(); +} + +bool RegistryKeyIterator::Read() { + if (Valid()) { + DWORD ncount = arraysize(name_); + FILETIME written; + LRESULT r = ::RegEnumKeyEx(key_, index_, name_, &ncount, NULL, NULL, + NULL, &written); + if (ERROR_SUCCESS == r) + return true; + } + + name_[0] = '\0'; + return false; +} + +DWORD RegistryKeyIterator::SubkeyCount() const { + DWORD count = 0; + HRESULT result = ::RegQueryInfoKey(key_, NULL, 0, NULL, &count, NULL, NULL, + NULL, NULL, NULL, NULL, NULL); + + if (result != ERROR_SUCCESS) + return 0; + + return count; +} + +RegKey::RegKey() + : key_(NULL), + watch_event_(0) { +} + +RegKey::RegKey(HKEY rootkey, const wchar_t* subkey, REGSAM access) + : key_(NULL), + watch_event_(0) { + if (rootkey) { + if (access & (KEY_SET_VALUE | KEY_CREATE_SUB_KEY | KEY_CREATE_LINK)) + Create(rootkey, subkey, access); + else + Open(rootkey, subkey, access); + } else { + DCHECK(!subkey); + } +} + +RegKey::~RegKey() { + Close(); +} + +void RegKey::Close() { + StopWatching(); + if (key_) { + ::RegCloseKey(key_); + key_ = NULL; + } +} + +bool RegKey::Create(HKEY rootkey, const wchar_t* subkey, REGSAM access) { + DWORD disposition_value; + return CreateWithDisposition(rootkey, subkey, &disposition_value, access); +} + +bool RegKey::CreateWithDisposition(HKEY rootkey, const wchar_t* subkey, + DWORD* disposition, REGSAM access) { + DCHECK(rootkey && subkey && access && disposition); + Close(); + + LONG result = RegCreateKeyEx(rootkey, + subkey, + 0, + NULL, + REG_OPTION_NON_VOLATILE, + access, + NULL, + &key_, + disposition); + if (result != ERROR_SUCCESS) { + key_ = NULL; + return false; + } + + return true; +} + +bool RegKey::Open(HKEY rootkey, const wchar_t* subkey, REGSAM access) { + DCHECK(rootkey && subkey && access); + Close(); + + LONG result = RegOpenKeyEx(rootkey, subkey, 0, access, &key_); + if (result != ERROR_SUCCESS) { + key_ = NULL; + return false; + } + return true; +} + +bool RegKey::CreateKey(const wchar_t* name, REGSAM access) { + DCHECK(name && access); + + HKEY subkey = NULL; + LONG result = RegCreateKeyEx(key_, name, 0, NULL, REG_OPTION_NON_VOLATILE, + access, NULL, &subkey, NULL); + Close(); + + key_ = subkey; + return (result == ERROR_SUCCESS); +} + +bool RegKey::OpenKey(const wchar_t* name, REGSAM access) { + DCHECK(name && access); + + HKEY subkey = NULL; + LONG result = RegOpenKeyEx(key_, name, 0, access, &subkey); + + Close(); + + key_ = subkey; + return (result == ERROR_SUCCESS); +} + +DWORD RegKey::ValueCount() { + DWORD count = 0; + HRESULT result = RegQueryInfoKey(key_, NULL, 0, NULL, NULL, NULL, + NULL, &count, NULL, NULL, NULL, NULL); + return (result != ERROR_SUCCESS) ? 0 : count; +} + +bool RegKey::ReadName(int index, std::wstring* name) { + wchar_t buf[256]; + DWORD bufsize = arraysize(buf); + LRESULT r = ::RegEnumValue(key_, index, buf, &bufsize, NULL, NULL, + NULL, NULL); + if (r != ERROR_SUCCESS) + return false; + if (name) + *name = buf; + return true; +} + +bool RegKey::ValueExists(const wchar_t* name) { + if (!key_) + return false; + HRESULT result = RegQueryValueEx(key_, name, 0, NULL, NULL, NULL); + return (result == ERROR_SUCCESS); +} + +bool RegKey::ReadValue(const wchar_t* name, void* data, + DWORD* dsize, DWORD* dtype) { + if (!key_) + return false; + HRESULT result = RegQueryValueEx(key_, name, 0, dtype, + reinterpret_cast<LPBYTE>(data), dsize); + return (result == ERROR_SUCCESS); +} + +bool RegKey::ReadValue(const wchar_t* name, std::wstring* value) { + DCHECK(value); + const size_t kMaxStringLength = 1024; // This is after expansion. + // Use the one of the other forms of ReadValue if 1024 is too small for you. + wchar_t raw_value[kMaxStringLength]; + DWORD type = REG_SZ, size = sizeof(raw_value); + if (ReadValue(name, raw_value, &size, &type)) { + if (type == REG_SZ) { + *value = raw_value; + } else if (type == REG_EXPAND_SZ) { + wchar_t expanded[kMaxStringLength]; + size = ExpandEnvironmentStrings(raw_value, expanded, kMaxStringLength); + // Success: returns the number of wchar_t's copied + // Fail: buffer too small, returns the size required + // Fail: other, returns 0 + if (size == 0 || size > kMaxStringLength) + return false; + *value = expanded; + } else { + // Not a string. Oops. + return false; + } + return true; + } + + return false; +} + +bool RegKey::ReadValueDW(const wchar_t* name, DWORD* value) { + DCHECK(value); + DWORD type = REG_DWORD; + DWORD size = sizeof(DWORD); + DWORD result = 0; + if (ReadValue(name, &result, &size, &type) && + (type == REG_DWORD || type == REG_BINARY) && + size == sizeof(DWORD)) { + *value = result; + return true; + } + + return false; +} + +bool RegKey::WriteValue(const wchar_t* name, const void * data, + DWORD dsize, DWORD dtype) { + DCHECK(data); + + if (!key_) + return false; + + HRESULT result = RegSetValueEx( + key_, + name, + 0, + dtype, + reinterpret_cast<LPBYTE>(const_cast<void*>(data)), + dsize); + return (result == ERROR_SUCCESS); +} + +bool RegKey::WriteValue(const wchar_t * name, const wchar_t* value) { + return WriteValue(name, value, + static_cast<DWORD>(sizeof(*value) * (wcslen(value) + 1)), REG_SZ); +} + +bool RegKey::WriteValue(const wchar_t* name, DWORD value) { + return WriteValue(name, &value, + static_cast<DWORD>(sizeof(value)), REG_DWORD); +} + +bool RegKey::DeleteKey(const wchar_t* name) { + return (!key_) ? false : (ERROR_SUCCESS == SHDeleteKey(key_, name)); +} + +bool RegKey::DeleteValue(const wchar_t* value_name) { + DCHECK(value_name); + HRESULT result = RegDeleteValue(key_, value_name); + return (result == ERROR_SUCCESS); +} + +bool RegKey::StartWatching() { + if (!watch_event_) + watch_event_ = CreateEvent(NULL, TRUE, FALSE, NULL); + + DWORD filter = REG_NOTIFY_CHANGE_NAME | + REG_NOTIFY_CHANGE_ATTRIBUTES | + REG_NOTIFY_CHANGE_LAST_SET | + REG_NOTIFY_CHANGE_SECURITY; + + // Watch the registry key for a change of value. + HRESULT result = RegNotifyChangeKeyValue(key_, TRUE, filter, + watch_event_, TRUE); + if (SUCCEEDED(result)) { + return true; + } else { + CloseHandle(watch_event_); + watch_event_ = 0; + return false; + } +} + +bool RegKey::StopWatching() { + if (watch_event_) { + CloseHandle(watch_event_); + watch_event_ = 0; + return true; + } + return false; +} + +bool RegKey::HasChanged() { + if (watch_event_) { + if (WaitForSingleObject(watch_event_, 0) == WAIT_OBJECT_0) { + StartWatching(); + return true; + } + } + return false; +} + +} // namespace win +} // namespace base diff --git a/base/win/registry.h b/base/win/registry.h new file mode 100644 index 0000000..d1ef25b --- /dev/null +++ b/base/win/registry.h @@ -0,0 +1,171 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef BASE_WIN_REGISTRY_H_ +#define BASE_WIN_REGISTRY_H_ +#pragma once + +#include <windows.h> +#include <string> + +#include "base/basictypes.h" + +namespace base { +namespace win { + +// Utility class to read, write and manipulate the Windows Registry. +// Registry vocabulary primer: a "key" is like a folder, in which there +// are "values", which are <name, data> pairs, with an associated data type. +class RegKey { + public: + RegKey(); + RegKey(HKEY rootkey, const wchar_t* subkey, REGSAM access); + ~RegKey(); + + bool Create(HKEY rootkey, const wchar_t* subkey, REGSAM access); + + bool CreateWithDisposition(HKEY rootkey, const wchar_t* subkey, + DWORD* disposition, REGSAM access); + + bool Open(HKEY rootkey, const wchar_t* subkey, REGSAM access); + + // Creates a subkey or open it if it already exists. + bool CreateKey(const wchar_t* name, REGSAM access); + + // Opens a subkey + bool OpenKey(const wchar_t* name, REGSAM access); + + void Close(); + + DWORD ValueCount(); + + // Determine the nth value's name. + bool ReadName(int index, std::wstring* name); + + // True while the key is valid. + bool Valid() const { return key_ != NULL; } + + // Kill a key and everything that live below it; please be careful when using + // it. + bool DeleteKey(const wchar_t* name); + + // Deletes a single value within the key. + bool DeleteValue(const wchar_t* name); + + bool ValueExists(const wchar_t* name); + + bool ReadValue(const wchar_t* name, void* data, DWORD* dsize, DWORD* dtype); + bool ReadValue(const wchar_t* name, std::wstring* value); + bool ReadValueDW(const wchar_t* name, DWORD* value); + + bool WriteValue(const wchar_t* name, const void* data, DWORD dsize, + DWORD dtype); + bool WriteValue(const wchar_t* name, const wchar_t* value); + bool WriteValue(const wchar_t* name, DWORD value); + + // Starts watching the key to see if any of its values have changed. + // The key must have been opened with the KEY_NOTIFY access privelege. + bool StartWatching(); + + // If StartWatching hasn't been called, always returns false. + // Otherwise, returns true if anything under the key has changed. + // This can't be const because the |watch_event_| may be refreshed. + bool HasChanged(); + + // Will automatically be called by destructor if not manually called + // beforehand. Returns true if it was watching, false otherwise. + bool StopWatching(); + + inline bool IsWatching() const { return watch_event_ != 0; } + HANDLE watch_event() const { return watch_event_; } + HKEY Handle() const { return key_; } + + private: + HKEY key_; // The registry key being iterated. + HANDLE watch_event_; + + DISALLOW_COPY_AND_ASSIGN(RegKey); +}; + +// Iterates the entries found in a particular folder on the registry. +// For this application I happen to know I wont need data size larger +// than MAX_PATH, but in real life this wouldn't neccessarily be +// adequate. +class RegistryValueIterator { + public: + RegistryValueIterator(HKEY root_key, const wchar_t* folder_key); + + ~RegistryValueIterator(); + + DWORD ValueCount() const; + + // True while the iterator is valid. + bool Valid() const; + + // Advances to the next registry entry. + void operator++(); + + const wchar_t* Name() const { return name_; } + const wchar_t* Value() const { return value_; } + DWORD ValueSize() const { return value_size_; } + DWORD Type() const { return type_; } + + int Index() const { return index_; } + + private: + // Read in the current values. + bool Read(); + + // The registry key being iterated. + HKEY key_; + + // Current index of the iteration. + int index_; + + // Current values. + wchar_t name_[MAX_PATH]; + wchar_t value_[MAX_PATH]; + DWORD value_size_; + DWORD type_; + + DISALLOW_COPY_AND_ASSIGN(RegistryValueIterator); +}; + +class RegistryKeyIterator { + public: + RegistryKeyIterator(HKEY root_key, const wchar_t* folder_key); + + ~RegistryKeyIterator(); + + DWORD SubkeyCount() const; + + // True while the iterator is valid. + bool Valid() const; + + // Advances to the next entry in the folder. + void operator++(); + + const wchar_t* Name() const { return name_; } + + int Index() const { return index_; } + + private: + // Read in the current values. + bool Read(); + + // The registry key being iterated. + HKEY key_; + + // Current index of the iteration. + int index_; + + wchar_t name_[MAX_PATH]; + + DISALLOW_COPY_AND_ASSIGN(RegistryKeyIterator); +}; + +} // namespace win +} // namespace base + +#endif // BASE_WIN_REGISTRY_H_ diff --git a/base/win/registry_unittest.cc b/base/win/registry_unittest.cc new file mode 100644 index 0000000..524612a --- /dev/null +++ b/base/win/registry_unittest.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2010 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "base/win/registry.h" +#include "testing/gtest/include/gtest/gtest.h" + +namespace base { +namespace win { + +namespace { + +const wchar_t kRootKey[] = L"Base_Registry_Unittest"; + +class RegistryTest : public testing::Test { + public: + RegistryTest() {} + + protected: + virtual void SetUp() { + // Create a temporary key. + RegKey key(HKEY_CURRENT_USER, L"", KEY_ALL_ACCESS); + key.DeleteKey(kRootKey); + ASSERT_FALSE(key.Open(HKEY_CURRENT_USER, kRootKey, KEY_READ)); + ASSERT_TRUE(key.Create(HKEY_CURRENT_USER, kRootKey, KEY_READ)); + } + + virtual void TearDown() { + // Clean up the temporary key. + RegKey key(HKEY_CURRENT_USER, L"", KEY_SET_VALUE); + ASSERT_TRUE(key.DeleteKey(kRootKey)); + } + + private: + DISALLOW_COPY_AND_ASSIGN(RegistryTest); +}; + +TEST_F(RegistryTest, ValueTest) { + RegKey key; + + std::wstring foo_key(kRootKey); + foo_key += L"\\Foo"; + ASSERT_TRUE(key.Create(HKEY_CURRENT_USER, foo_key.c_str(), KEY_READ)); + + { + ASSERT_TRUE(key.Open(HKEY_CURRENT_USER, foo_key.c_str(), + KEY_READ | KEY_SET_VALUE)); + + const wchar_t* kName = L"Bar"; + const wchar_t* kValue = L"bar"; + EXPECT_TRUE(key.WriteValue(kName, kValue)); + EXPECT_TRUE(key.ValueExists(kName)); + std::wstring out_value; + EXPECT_TRUE(key.ReadValue(kName, &out_value)); + EXPECT_NE(out_value, L""); + EXPECT_STREQ(out_value.c_str(), kValue); + EXPECT_EQ(1U, key.ValueCount()); + EXPECT_TRUE(key.DeleteValue(kName)); + } +} + +} // namespace + +} // namespace win +} // namespace base |