diff options
author | brettw@chromium.org <brettw@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-09-11 21:30:56 +0000 |
---|---|---|
committer | brettw@chromium.org <brettw@chromium.org@0039d316-1c4b-4281-b951-d872f2087c98> | 2009-09-11 21:30:56 +0000 |
commit | e5ffd0e471417e75ddcd5af20c3254c0ec2f1f5d (patch) | |
tree | 60e8d7c1de4ee33cc063cfa98a168e8fe9fcf27b /app/sql | |
parent | 92c3dc6b3fffbaf9fa2fa409120ca051bf317234 (diff) | |
download | chromium_src-e5ffd0e471417e75ddcd5af20c3254c0ec2f1f5d.zip chromium_src-e5ffd0e471417e75ddcd5af20c3254c0ec2f1f5d.tar.gz chromium_src-e5ffd0e471417e75ddcd5af20c3254c0ec2f1f5d.tar.bz2 |
Add a new wrapper for sqlite. This is mostly a large cleanup of the existing
one, combined with the statement cache in a nice way. It is designed to
entirely wrap sqlite so that we can catch corrupt errors in the future and
"do something" when we get them without having to change all the calling code.
There is also a new meta_table file which is almost exactly like the old one
but which uses the new sql interface.
This patch changes Chrome's history TextDatabase to use this new wrapper as a
proof of concept, because this usage is relatively well-confined.
Review URL: http://codereview.chromium.org/199047
git-svn-id: svn://svn.chromium.org/chrome/trunk/src@26022 0039d316-1c4b-4281-b951-d872f2087c98
Diffstat (limited to 'app/sql')
-rw-r--r-- | app/sql/connection.cc | 304 | ||||
-rw-r--r-- | app/sql/connection.h | 310 | ||||
-rw-r--r-- | app/sql/connection_unittest.cc | 113 | ||||
-rw-r--r-- | app/sql/meta_table.cc | 143 | ||||
-rw-r--r-- | app/sql/meta_table.h | 87 | ||||
-rw-r--r-- | app/sql/statement.cc | 212 | ||||
-rw-r--r-- | app/sql/statement.h | 132 | ||||
-rw-r--r-- | app/sql/statement_unittest.cc | 80 | ||||
-rw-r--r-- | app/sql/transaction.cc | 51 | ||||
-rw-r--r-- | app/sql/transaction.h | 56 | ||||
-rw-r--r-- | app/sql/transaction_unittest.cc | 139 |
11 files changed, 1627 insertions, 0 deletions
diff --git a/app/sql/connection.cc b/app/sql/connection.cc new file mode 100644 index 0000000..4878e1b --- /dev/null +++ b/app/sql/connection.cc @@ -0,0 +1,304 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "app/sql/connection.h" + +#include <string.h> + +#include "app/sql/statement.h" +#include "base/file_path.h" +#include "base/logging.h" +#include "base/string_util.h" +#include "third_party/sqlite/preprocessed/sqlite3.h" + +namespace sql { + +bool StatementID::operator<(const StatementID& other) const { + if (number_ != other.number_) + return number_ < other.number_; + return strcmp(str_, other.str_) < 0; +} + +Connection::StatementRef::StatementRef() + : connection_(NULL), + stmt_(NULL) { +} + +Connection::StatementRef::StatementRef(Connection* connection, + sqlite3_stmt* stmt) + : connection_(connection), + stmt_(stmt) { + connection_->StatementRefCreated(this); +} + +Connection::StatementRef::~StatementRef() { + if (connection_) + connection_->StatementRefDeleted(this); + Close(); +} + +void Connection::StatementRef::Close() { + if (stmt_) { + sqlite3_finalize(stmt_); + stmt_ = NULL; + } + connection_ = NULL; // The connection may be getting deleted. +} + +Connection::Connection() + : db_(NULL), + page_size_(0), + cache_size_(0), + exclusive_locking_(false), + transaction_nesting_(0), + needs_rollback_(false) { +} + +Connection::~Connection() { + Close(); +} + +bool Connection::Init(const FilePath& path) { +#if defined(OS_WIN) + // We want the default encoding to always be UTF-8, so we use the + // 8-bit version of open(). + int err = sqlite3_open(WideToUTF8(path.value()).c_str(), &db_); +#elif defined(OS_POSIX) + int err = sqlite3_open(path.value().c_str(), &db_); +#endif + + if (err != SQLITE_OK) { + db_ = NULL; + return false; + } + + if (page_size_ != 0) { + if (!Execute(StringPrintf("PRAGMA page_size=%d", page_size_).c_str())) + NOTREACHED() << "Could not set page size"; + } + + if (cache_size_ != 0) { + if (!Execute(StringPrintf("PRAGMA cache_size=%d", cache_size_).c_str())) + NOTREACHED() << "Could not set page size"; + } + + if (exclusive_locking_) { + if (!Execute("PRAGMA locking_mode=EXCLUSIVE")) + NOTREACHED() << "Could not set locking mode."; + } + + return true; +} + +void Connection::Close() { + statement_cache_.clear(); + DCHECK(open_statements_.empty()); + if (db_) { + sqlite3_close(db_); + db_ = NULL; + } +} + +void Connection::Preload() { + if (!db_) { + NOTREACHED(); + return; + } + + // A statement must be open for the preload command to work. If the meta + // table doesn't exist, it probably means this is a new database and there + // is nothing to preload (so it's OK we do nothing). + if (!DoesTableExist("meta")) + return; + Statement dummy(GetUniqueStatement("SELECT * FROM meta")); + if (!dummy || !dummy.Run()) + return; + + sqlite3Preload(db_); +} + +bool Connection::BeginTransaction() { + if (needs_rollback_) { + DCHECK(transaction_nesting_ > 0); + + // When we're going to rollback, fail on this begin and don't actually + // mark us as entering the nested transaction. + return false; + } + + bool success = true; + if (!transaction_nesting_) { + needs_rollback_ = false; + + Statement begin(GetCachedStatement(SQL_FROM_HERE, "BEGIN TRANSACTION")); + if (!begin || !begin.Run()) + return false; + } + transaction_nesting_++; + return success; +} + +void Connection::RollbackTransaction() { + if (!transaction_nesting_) { + NOTREACHED() << "Rolling back a nonexistant transaction"; + return; + } + + transaction_nesting_--; + + if (transaction_nesting_ > 0) { + // Mark the outermost transaction as needing rollback. + needs_rollback_ = true; + return; + } + + DoRollback(); +} + +bool Connection::CommitTransaction() { + if (!transaction_nesting_) { + NOTREACHED() << "Rolling back a nonexistant transaction"; + return false; + } + transaction_nesting_--; + + if (transaction_nesting_ > 0) { + // Mark any nested transactions as failing after we've already got one. + return !needs_rollback_; + } + + if (needs_rollback_) { + DoRollback(); + return false; + } + + Statement commit(GetCachedStatement(SQL_FROM_HERE, "COMMIT")); + if (!commit) + return false; + return commit.Run(); +} + +bool Connection::Execute(const char* sql) { + if (!db_) + return false; + return sqlite3_exec(db_, sql, NULL, NULL, NULL) == SQLITE_OK; +} + +bool Connection::HasCachedStatement(const StatementID& id) const { + return statement_cache_.find(id) != statement_cache_.end(); +} + +scoped_refptr<Connection::StatementRef> Connection::GetCachedStatement( + const StatementID& id, + const char* sql) { + CachedStatementMap::iterator i = statement_cache_.find(id); + if (i != statement_cache_.end()) { + // Statement is in the cache. It should still be active (we're the only + // one invalidating cached statements, and we'll remove it from the cache + // if we do that. Make sure we reset it before giving out the cached one in + // case it still has some stuff bound. + DCHECK(i->second->is_valid()); + sqlite3_reset(i->second->stmt()); + return i->second; + } + + scoped_refptr<StatementRef> statement = GetUniqueStatement(sql); + if (statement->is_valid()) + statement_cache_[id] = statement; // Only cache valid statements. + return statement; +} + +scoped_refptr<Connection::StatementRef> Connection::GetUniqueStatement( + const char* sql) { + if (!db_) + return new StatementRef(this, NULL); // Return inactive statement. + + sqlite3_stmt* stmt = NULL; + if (sqlite3_prepare_v2(db_, sql, -1, &stmt, NULL) != SQLITE_OK) { + // Treat this as non-fatal, it can occur in a number of valid cases, and + // callers should be doing their own error handling. + DLOG(WARNING) << "SQL compile error " << GetErrorMessage(); + return new StatementRef(this, NULL); + } + return new StatementRef(this, stmt); +} + +bool Connection::DoesTableExist(const char* table_name) { + Statement statement(GetUniqueStatement( + "SELECT name FROM sqlite_master " + "WHERE type='table' AND name=?")); + if (!statement) + return false; + statement.BindString(0, table_name); + return statement.Step(); // Table exists if any row was returned. +} + +bool Connection::DoesColumnExist(const char* table_name, + const char* column_name) { + std::string sql("PRAGMA TABLE_INFO("); + sql.append(table_name); + sql.append(")"); + + Statement statement(GetUniqueStatement(sql.c_str())); + if (!statement) + return false; + + while (statement.Step()) { + if (!statement.ColumnString(1).compare(column_name)) + return true; + } + return false; +} + +int64 Connection::GetLastInsertRowId() const { + if (!db_) { + NOTREACHED(); + return 0; + } + return sqlite3_last_insert_rowid(db_); +} + +int Connection::GetErrorCode() const { + if (!db_) + return SQLITE_ERROR; + return sqlite3_errcode(db_); +} + +const char* Connection::GetErrorMessage() const { + if (!db_) + return "sql::Connection has no connection."; + return sqlite3_errmsg(db_); +} + +void Connection::DoRollback() { + Statement rollback(GetCachedStatement(SQL_FROM_HERE, "ROLLBACK")); + if (rollback) + rollback.Run(); +} + +void Connection::StatementRefCreated(StatementRef* ref) { + DCHECK(open_statements_.find(ref) == open_statements_.end()); + open_statements_.insert(ref); +} + +void Connection::StatementRefDeleted(StatementRef* ref) { + StatementRefSet::iterator i = open_statements_.find(ref); + if (i == open_statements_.end()) + NOTREACHED(); + else + open_statements_.erase(i); +} + +void Connection::ClearCache() { + statement_cache_.clear(); + + // The cache clear will get most statements. There may be still be references + // to some statements that are held by others (including one-shot statements). + // This will deactivate them so they can't be used again. + for (StatementRefSet::iterator i = open_statements_.begin(); + i != open_statements_.end(); ++i) + (*i)->Close(); +} + +} // namespace sql diff --git a/app/sql/connection.h b/app/sql/connection.h new file mode 100644 index 0000000..2948980 --- /dev/null +++ b/app/sql/connection.h @@ -0,0 +1,310 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef APP_SQL_CONNECTION_H_ +#define APP_SQL_CONNECTION_H_ + +#include <map> +#include <set> + +#include "base/basictypes.h" +#include "base/ref_counted.h" + +class FilePath; +struct sqlite3; +struct sqlite3_stmt; + +namespace sql { + +class Statement; + +// Uniquely identifies a statement. There are two modes of operation: +// +// - In the most common mode, you will use the source file and line number to +// identify your statement. This is a convienient way to get uniqueness for +// a statement that is only used in one place. Use the SQL_FROM_HERE macro +// to generate a StatementID. +// +// - In the "custom" mode you may use the statement from different places or +// need to manage it yourself for whatever reason. In this case, you should +// make up your own unique name and pass it to the StatementID. This name +// must be a static string, since this object only deals with pointers and +// assumes the underlying string doesn't change or get deleted. +// +// This object is copyable and assignable using the compiler-generated +// operator= and copy constructor. +class StatementID { + public: + // Creates a uniquely named statement with the given file ane line number. + // Normally you will use SQL_FROM_HERE instead of calling yourself. + StatementID(const char* file, int line) + : number_(line), + str_(file) { + } + + // Creates a uniquely named statement with the given user-defined name. + explicit StatementID(const char* unique_name) + : number_(-1), + str_(unique_name) { + } + + // This constructor is unimplemented and will generate a linker error if + // called. It is intended to try to catch people dynamically generating + // a statement name that will be deallocated and will cause a crash later. + // All strings must be static and unchanging! + explicit StatementID(const std::string& dont_ever_do_this); + + // We need this to insert into our map. + bool operator<(const StatementID& other) const; + + private: + int number_; + const char* str_; +}; + +#define SQL_FROM_HERE sql::StatementID(__FILE__, __LINE__) + +class Connection { + private: + class StatementRef; // Forward declaration, see real one below. + + public: + // The database is opened by calling Init(). Any uncommitted transactions + // will be rolled back when this object is deleted. + Connection(); + ~Connection(); + + // Pre-init configuration ---------------------------------------------------- + + // Sets the page size that will be used when creating a new adtabase. This + // must be called before Init(), and will only have an effect on new + // databases. + // + // From sqlite.org: "The page size must be a power of two greater than or + // equal to 512 and less than or equal to SQLITE_MAX_PAGE_SIZE. The maximum + // value for SQLITE_MAX_PAGE_SIZE is 32768." + void set_page_size(int page_size) { page_size_ = page_size; } + + // Sets the number of pages that will be cached in memory by sqlite. The + // total cache size in bytes will be page_size * cache_size. This must be + // called before Init() to have an effect. + void set_cache_size(int cache_size) { cache_size_ = cache_size; } + + // Call to put the database in exclusive locking mode. There is no "back to + // normal" flag because of some additional requirements sqlite puts on this + // transaition (requires another access to the DB) and because we don't + // actually need it. + // + // Exclusive mode means that the database is not unlocked at the end of each + // transaction, which means there may be less time spent initializing the + // next transaction because it doesn't have to re-aquire locks. + // + // This must be called before Init() to have an effect. + void set_exclusive_locking() { exclusive_locking_ = true; } + + // Initialization ------------------------------------------------------------ + + // Initializes the SQL connection for the given file, returning true if the + // file could be opened. + bool Init(const FilePath& path); + + // Closes the database. This is automatically performed on destruction for + // you, but this allows you to close the database early. You must not call + // any other functions after closing it. It is permissable to call Close on + // an uninitialized or already-closed database. + void Close(); + + // Pre-loads the first <cache-size> pages into the cache from the file. + // If you expect to soon use a substantial portion of the database, this + // is much more efficient than allowing the pages to be populated organically + // since there is no per-page hard drive seeking. If the file is larger than + // the cache, the last part that doesn't fit in the cache will be brought in + // organically. + // + // This function assumes your class is using a meta table on the current + // database, as it openes a transaction on the meta table to force the + // database to be initialized. You should feel free to initialize the meta + // table after calling preload since the meta table will already be in the + // database if it exists, and if it doesn't exist, the database won't + // generally exist either. + void Preload(); + + // Transactions -------------------------------------------------------------- + + // Transaction management. We maintain a virtual transaction stack to emulate + // nested transactions since sqlite can't do nested transactions. The + // limitation is you can't roll back a sub transaction: if any transaction + // fails, all transactions open will also be rolled back. Any nested + // transactions after one has rolled back will return fail for Begin(). If + // Begin() fails, you must not call Commit or Rollback(). + // + // Normally you should use sql::Transaction to manage a transaction, which + // will scope it to a C++ context. + bool BeginTransaction(); + void RollbackTransaction(); + bool CommitTransaction(); + + // Returns the current transaction nesting, which will be 0 if there are + // no open transactions. + int transaction_nesting() const { return transaction_nesting_; } + + // Statements ---------------------------------------------------------------- + + // Executes the given SQL string, returning true on success. This is + // normally used for simple, 1-off statements that don't take any bound + // parameters and don't return any data (e.g. CREATE TABLE). + bool Execute(const char* sql); + + // Returns true if we have a statement with the given identifier already + // cached. This is normally not necessary to call, but can be useful if the + // caller has to dynamically build up SQL to avoid doing so if it's already + // cached. + bool HasCachedStatement(const StatementID& id) const; + + // Returns a statement for the given SQL using the statement cache. It can + // take a nontrivial amount of work to parse and compile a statement, so + // keeping commonly-used ones around for future use is important for + // performance. + // + // The SQL may have an error, so the caller must check validity of the + // statement before using it. + // + // The StatementID and the SQL must always correspond to one-another. The + // ID is the lookup into the cache, so crazy things will happen if you use + // different SQL with the same ID. + // + // You will normally use the SQL_FROM_HERE macro to generate a statement + // ID associated with the current line of code. This gives uniqueness without + // you having to manage unique names. See StatementID above for more. + // + // Example: + // sql::Statement stmt = connection_.GetCachedStatement( + // SQL_FROM_HERE, "SELECT * FROM foo"); + // if (!stmt) + // return false; // Error creating statement. + scoped_refptr<StatementRef> GetCachedStatement(const StatementID& id, + const char* sql); + + // Returns a non-cached statement for the given SQL. Use this for SQL that + // is only executed once or only rarely (there is overhead associated with + // keeping a statement cached). + // + // See GetCachedStatement above for examples and error information. + scoped_refptr<StatementRef> GetUniqueStatement(const char* sql); + + // Info querying ------------------------------------------------------------- + + // Returns true if the given table exists. + bool DoesTableExist( const char* table_name); + + // Returns true if a column with the given name exists in the given table. + bool DoesColumnExist(const char* table_name, const char* column_name); + + // Returns sqlite's internal ID for the last inserted row. Valid only + // immediately after an insert. + int64 GetLastInsertRowId() const; + + // Errors -------------------------------------------------------------------- + + // Returns the error code associated with the last sqlite operation. + int GetErrorCode() const; + + // Returns a pointer to a statically allocated string associated with the + // last sqlite operation. + const char* GetErrorMessage() const; + + private: + // Statement access StatementRef which we don't want to expose to erverybody + // (they should go through Statement). + friend class Statement; + + // A StatementRef is a refcounted wrapper around a sqlite statement pointer. + // Refcounting allows us to give these statements out to sql::Statement + // objects while also optionally maintaining a cache of compiled statements + // by just keeping a refptr to these objects. + // + // A statement ref can be valid, in which case it can be used, or invalid to + // indicate that the statement hasn't been created yet, has an error, or has + // been destroyed. + // + // The Connection may revoke a StatementRef in some error cases, so callers + // should always check validity before using. + class StatementRef : public base::RefCounted<StatementRef> { + public: + // Default constructor initializes to an invalid statement. + StatementRef(); + StatementRef(Connection* connection, sqlite3_stmt* stmt); + ~StatementRef(); + + // When true, the statement can be used. + bool is_valid() const { return !!stmt_; } + + // If we've not been linked to a connection, this will be NULL. Guaranteed + // non-NULL when is_valid(). + Connection* connection() const { return connection_; } + + // Returns the sqlite statement if any. If the statement is not active, + // this will return NULL. + sqlite3_stmt* stmt() const { return stmt_; } + + // Destroys the compiled statement and marks it NULL. The statement will + // no longer be active. + void Close(); + + private: + Connection* connection_; + sqlite3_stmt* stmt_; + + DISALLOW_COPY_AND_ASSIGN(StatementRef); + }; + friend class StatementRef; + + // Executes a rollback statement, ignoring all transaction state. Used + // internally in the transaction management code. + void DoRollback(); + + // Called by a StatementRef when it's being created or destroyed. See + // open_statements_ below. + void StatementRefCreated(StatementRef* ref); + void StatementRefDeleted(StatementRef* ref); + + // Frees all cached statements from statement_cache_. + void ClearCache(); + + // The actual sqlite database. Will be NULL before Init has been called or if + // Init resulted in an error. + sqlite3* db_; + + // Parameters we'll configure in sqlite before doing anything else. Zero means + // use the default value. + int page_size_; + int cache_size_; + bool exclusive_locking_; + + // All cached statements. Keeping a reference to these statements means that + // they'll remain active. + typedef std::map<StatementID, scoped_refptr<StatementRef> > + CachedStatementMap; + CachedStatementMap statement_cache_; + + // A list of all StatementRefs we've given out. Each ref must register with + // us when it's created or destroyed. This allows us to potentially close + // any open statements when we encounter an error. + typedef std::set<StatementRef*> StatementRefSet; + StatementRefSet open_statements_; + + // Number of currently-nested transactions. + int transaction_nesting_; + + // True if any of the currently nested transactions have been rolled back. + // When we get to the outermost transaction, this will determine if we do + // a rollback instead of a commit. + bool needs_rollback_; + + DISALLOW_COPY_AND_ASSIGN(Connection); +}; + +} // namespace sql + +#endif // APP_SQL_CONNECTION_H_ diff --git a/app/sql/connection_unittest.cc b/app/sql/connection_unittest.cc new file mode 100644 index 0000000..70c9ffc --- /dev/null +++ b/app/sql/connection_unittest.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "app/sql/connection.h" +#include "app/sql/statement.h" +#include "base/file_path.h" +#include "base/file_util.h" +#include "base/path_service.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "third_party/sqlite/preprocessed/sqlite3.h" + +class SQLConnectionTest : public testing::Test { + public: + SQLConnectionTest() {} + + void SetUp() { + ASSERT_TRUE(PathService::Get(base::DIR_TEMP, &path_)); + path_ = path_.AppendASCII("SQLConnectionTest.db"); + file_util::Delete(path_, false); + ASSERT_TRUE(db_.Init(path_)); + } + + void TearDown() { + db_.Close(); + + // If this fails something is going on with cleanup and later tests may + // fail, so we want to identify problems right away. + ASSERT_TRUE(file_util::Delete(path_, false)); + } + + sql::Connection& db() { return db_; } + + private: + FilePath path_; + sql::Connection db_; +}; + +TEST_F(SQLConnectionTest, Execute) { + // Valid statement should return true. + ASSERT_TRUE(db().Execute("CREATE TABLE foo (a, b)")); + EXPECT_EQ(SQLITE_OK, db().GetErrorCode()); + + // Invalid statement should fail. + ASSERT_FALSE(db().Execute("CREATE TAB foo (a, b")); + EXPECT_EQ(SQLITE_ERROR, db().GetErrorCode()); +} + +TEST_F(SQLConnectionTest, CachedStatement) { + sql::StatementID id1("foo", 12); + + ASSERT_TRUE(db().Execute("CREATE TABLE foo (a, b)")); + ASSERT_TRUE(db().Execute("INSERT INTO foo(a, b) VALUES (12, 13)")); + + // Create a new cached statement. + { + sql::Statement s(db().GetCachedStatement(id1, "SELECT a FROM foo")); + ASSERT_FALSE(!s); // Test ! operator for validity. + + ASSERT_TRUE(s.Step()); + EXPECT_EQ(12, s.ColumnInt(0)); + } + + // The statement should be cached still. + EXPECT_TRUE(db().HasCachedStatement(id1)); + + { + // Get the same statement using different SQL. This should ignore our + // SQL and use the cached one (so it will be valid). + sql::Statement s(db().GetCachedStatement(id1, "something invalid(")); + ASSERT_FALSE(!s); // Test ! operator for validity. + + ASSERT_TRUE(s.Step()); + EXPECT_EQ(12, s.ColumnInt(0)); + } + + // Make sure other statements aren't marked as cached. + EXPECT_FALSE(db().HasCachedStatement(SQL_FROM_HERE)); +} + +TEST_F(SQLConnectionTest, DoesStuffExist) { + // Test DoesTableExist. + EXPECT_FALSE(db().DoesTableExist("foo")); + ASSERT_TRUE(db().Execute("CREATE TABLE foo (a, b)")); + EXPECT_TRUE(db().DoesTableExist("foo")); + + // Should be case sensitive. + EXPECT_FALSE(db().DoesTableExist("FOO")); + + // Test DoesColumnExist. + EXPECT_FALSE(db().DoesColumnExist("foo", "bar")); + EXPECT_TRUE(db().DoesColumnExist("foo", "a")); + + // Testing for a column on a nonexistant table. + EXPECT_FALSE(db().DoesColumnExist("bar", "b")); +} + +TEST_F(SQLConnectionTest, GetLastInsertRowId) { + ASSERT_TRUE(db().Execute("CREATE TABLE foo (id INTEGER PRIMARY KEY, value)")); + + ASSERT_TRUE(db().Execute("INSERT INTO foo (value) VALUES (12)")); + + // Last insert row ID should be valid. + int64 row = db().GetLastInsertRowId(); + EXPECT_LT(0, row); + + // It should be the primary key of the row we just inserted. + sql::Statement s(db().GetUniqueStatement("SELECT value FROM foo WHERE id=?")); + s.BindInt64(0, row); + ASSERT_TRUE(s.Step()); + EXPECT_EQ(12, s.ColumnInt(0)); +} + diff --git a/app/sql/meta_table.cc b/app/sql/meta_table.cc new file mode 100644 index 0000000..03a1245 --- /dev/null +++ b/app/sql/meta_table.cc @@ -0,0 +1,143 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "app/sql/meta_table.h" + +#include "app/sql/connection.h" +#include "app/sql/statement.h" +#include "base/logging.h" +#include "base/string_util.h" + +namespace sql { + +// Key used in our meta table for version numbers. +static const char kVersionKey[] = "version"; +static const char kCompatibleVersionKey[] = "last_compatible_version"; + +MetaTable::MetaTable() : db_(NULL) { +} + +MetaTable::~MetaTable() { +} + +bool MetaTable::Init(Connection* db, int version, int compatible_version) { + DCHECK(!db_ && db); + db_ = db; + if (!db_->DoesTableExist("meta")) { + if (!db_->Execute("CREATE TABLE meta" + "(key LONGVARCHAR NOT NULL UNIQUE PRIMARY KEY," + "value LONGVARCHAR)")) + return false; + + // Note: there is no index over the meta table. We currently only have a + // couple of keys, so it doesn't matter. If we start storing more stuff in + // there, we should create an index. + SetVersionNumber(version); + SetCompatibleVersionNumber(compatible_version); + } + return true; +} + +bool MetaTable::SetValue(const char* key, const std::string& value) { + Statement s; + if (!PrepareSetStatement(&s, key)) + return false; + s.BindString(1, value); + return s.Run(); +} + +bool MetaTable::GetValue(const char* key, std::string* value) { + Statement s; + if (!PrepareGetStatement(&s, key)) + return false; + + *value = s.ColumnString(0); + return true; +} + +bool MetaTable::SetValue(const char* key, int value) { + Statement s; + if (!PrepareSetStatement(&s, key)) + return false; + + s.BindInt(1, value); + return s.Run(); +} + +bool MetaTable::GetValue(const char* key, int* value) { + Statement s; + if (!PrepareGetStatement(&s, key)) + return false; + + *value = s.ColumnInt(0); + return true; +} + +bool MetaTable::SetValue(const char* key, int64 value) { + Statement s; + if (!PrepareSetStatement(&s, key)) + return false; + s.BindInt64(1, value); + return s.Run(); +} + +bool MetaTable::GetValue(const char* key, int64* value) { + Statement s; + if (!PrepareGetStatement(&s, key)) + return false; + + *value = s.ColumnInt64(0); + return true; +} + +void MetaTable::SetVersionNumber(int version) { + SetValue(kVersionKey, version); +} + +int MetaTable::GetVersionNumber() { + int version = 0; + if (!GetValue(kVersionKey, &version)) + return 0; + return version; +} + +void MetaTable::SetCompatibleVersionNumber(int version) { + SetValue(kCompatibleVersionKey, version); +} + +int MetaTable::GetCompatibleVersionNumber() { + int version = 0; + if (!GetValue(kCompatibleVersionKey, &version)) + return 0; + return version; +} + +bool MetaTable::PrepareSetStatement(Statement* statement, const char* key) { + DCHECK(db_ && statement); + statement->Assign(db_->GetCachedStatement(SQL_FROM_HERE, + "INSERT OR REPLACE INTO meta (key,value) VALUES (?,?)")); + if (!*statement) { + NOTREACHED() << db_->GetErrorMessage(); + return false; + } + statement->BindCString(0, key); + return true; +} + +bool MetaTable::PrepareGetStatement(Statement* statement, const char* key) { + DCHECK(db_ && statement); + statement->Assign(db_->GetCachedStatement(SQL_FROM_HERE, + "SELECT value FROM meta WHERE key=?")); + if (!*statement) { + NOTREACHED() << db_->GetErrorMessage(); + return false; + } + statement->BindCString(0, key); + if (!statement->Step()) + return false; + return true; +} + +} // namespace sql + diff --git a/app/sql/meta_table.h b/app/sql/meta_table.h new file mode 100644 index 0000000..6ccee17 --- /dev/null +++ b/app/sql/meta_table.h @@ -0,0 +1,87 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef APP_SQL_META_TABLE_H_ +#define APP_SQL_META_TABLE_H_ + +#include <string> + +#include "base/basictypes.h" + +namespace sql { + +class Connection; +class Statement; + +class MetaTable { + public: + MetaTable(); + ~MetaTable(); + + // Initializes the MetaTableHelper, creating the meta table if necessary. For + // new tables, it will initialize the version number to |version| and the + // compatible version number to |compatible_version|. + // + // The name of the database in the sqlite connection (for tables named with + // the "db_name.table_name" scheme is given in |db_name|. If empty, it is + // assumed there is no database name. + bool Init(Connection* db, + int version, + int compatible_version); + + // The version number of the database. This should be the version number of + // the creator of the file. The version number will be 0 if there is no + // previously set version number. + // + // See also Get/SetCompatibleVersionNumber(). + void SetVersionNumber(int version); + int GetVersionNumber(); + + // The compatible version number is the lowest version of the code that this + // database can be read by. If there are minor changes or additions, old + // versions of the code can still work with the database without failing. + // + // For example, if an optional column is added to a table in version 3, the + // new code will set the version to 3, and the compatible version to 2, since + // the code expecting version 2 databases can still read and write the table. + // + // Rule of thumb: check the version number when you're upgrading, but check + // the compatible version number to see if you can read the file at all. If + // it's larger than you code is expecting, fail. + // + // The compatible version number will be 0 if there is no previously set + // compatible version number. + void SetCompatibleVersionNumber(int version); + int GetCompatibleVersionNumber(); + + // Set the given arbitrary key with the given data. Returns true on success. + bool SetValue(const char* key, const std::string& value); + bool SetValue(const char* key, int value); + bool SetValue(const char* key, int64 value); + + // Retrieves the value associated with the given key. This will use sqlite's + // type conversion rules. It will return true on success. + bool GetValue(const char* key, std::string* value); + bool GetValue(const char* key, int* value); + bool GetValue(const char* key, int64* value); + + private: + // Conveniences to prepare the two types of statements used by + // MetaTableHelper. + bool PrepareSetStatement(Statement* statement, const char* key); + bool PrepareGetStatement(Statement* statement, const char* key); + + Connection* db_; + + // Name of the database within the connection, if there is one. When empty, + // there is no special database name and the table name can be used + // unqualified. + std::string db_name_; + + DISALLOW_COPY_AND_ASSIGN(MetaTable); +}; + +} // namespace sql + +#endif // APP_SQL_META_TABLE_H_ diff --git a/app/sql/statement.cc b/app/sql/statement.cc new file mode 100644 index 0000000..0b419ba --- /dev/null +++ b/app/sql/statement.cc @@ -0,0 +1,212 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "app/sql/statement.h" + +#include "base/logging.h" +#include "third_party/sqlite/preprocessed/sqlite3.h" + +namespace sql { + +// This empty constructor initializes our reference with an empty one so that +// we don't have to NULL-check the ref_ to see if the statement is valid: we +// only have to check the ref's validity bit. +Statement::Statement() + : ref_(new Connection::StatementRef), + succeeded_(false) { +} + +Statement::Statement(scoped_refptr<Connection::StatementRef> ref) + : ref_(ref), + succeeded_(false) { +} + +Statement::~Statement() { + // Free the resources associated with this statement. We assume there's only + // one statement active for a given sqlite3_stmt at any time, so this won't + // mess with anything. + Reset(); +} + +void Statement::Assign(scoped_refptr<Connection::StatementRef> ref) { + ref_ = ref; +} + +bool Statement::Run() { + if (!is_valid()) + return false; + return CheckError(sqlite3_step(ref_->stmt())) == SQLITE_DONE; +} + +bool Statement::Step() { + if (!is_valid()) + return false; + return CheckError(sqlite3_step(ref_->stmt())) == SQLITE_ROW; +} + +void Statement::Reset() { + if (is_valid()) + CheckError(sqlite3_reset(ref_->stmt())); + succeeded_ = false; +} + +bool Statement::Succeeded() const { + if (!is_valid()) + return false; + return succeeded_; +} + +bool Statement::BindNull(int col) { + if (is_valid()) { + int err = CheckError(sqlite3_bind_null(ref_->stmt(), col + 1)); + DCHECK(err == SQLITE_OK) << ref_->connection()->GetErrorMessage(); + return err == SQLITE_OK; + } + return false; +} + +bool Statement::BindInt(int col, int val) { + if (is_valid()) { + int err = CheckError(sqlite3_bind_int(ref_->stmt(), col + 1, val)); + DCHECK(err == SQLITE_OK) << ref_->connection()->GetErrorMessage(); + return err == SQLITE_OK; + } + return false; +} + +bool Statement::BindInt64(int col, int64 val) { + if (is_valid()) { + int err = CheckError(sqlite3_bind_int64(ref_->stmt(), col + 1, val)); + DCHECK(err == SQLITE_OK) << ref_->connection()->GetErrorMessage(); + return err == SQLITE_OK; + } + return false; +} + +bool Statement::BindDouble(int col, double val) { + if (is_valid()) { + int err = CheckError(sqlite3_bind_double(ref_->stmt(), col + 1, val)); + DCHECK(err == SQLITE_OK) << ref_->connection()->GetErrorMessage(); + return err == SQLITE_OK; + } + return false; +} + +bool Statement::BindCString(int col, const char* val) { + if (is_valid()) { + int err = CheckError(sqlite3_bind_text(ref_->stmt(), col + 1, val, -1, + SQLITE_TRANSIENT)); + DCHECK(err == SQLITE_OK) << ref_->connection()->GetErrorMessage(); + return err == SQLITE_OK; + } + return false; +} + +bool Statement::BindString(int col, const std::string& val) { + if (is_valid()) { + int err = CheckError(sqlite3_bind_text(ref_->stmt(), col + 1, val.data(), + val.size(), SQLITE_TRANSIENT)); + DCHECK(err == SQLITE_OK) << ref_->connection()->GetErrorMessage(); + return err == SQLITE_OK; + } + return false; +} + +bool Statement::BindBlob(int col, const void* val, int val_len) { + if (is_valid()) { + int err = CheckError(sqlite3_bind_blob(ref_->stmt(), col + 1, + val, val_len, SQLITE_TRANSIENT)); + DCHECK(err == SQLITE_OK) << ref_->connection()->GetErrorMessage(); + return err == SQLITE_OK; + } + return false; +} + +int Statement::ColumnCount() const { + if (!is_valid()) { + NOTREACHED(); + return 0; + } + return sqlite3_column_count(ref_->stmt()); +} + +int Statement::ColumnInt(int col) const { + if (!is_valid()) { + NOTREACHED(); + return 0; + } + return sqlite3_column_int(ref_->stmt(), col); +} + +int64 Statement::ColumnInt64(int col) const { + if (!is_valid()) { + NOTREACHED(); + return 0; + } + return sqlite3_column_int64(ref_->stmt(), col); +} + +double Statement::ColumnDouble(int col) const { + if (!is_valid()) { + NOTREACHED(); + return 0; + } + return sqlite3_column_double(ref_->stmt(), col); +} + +std::string Statement::ColumnString(int col) const { + if (!is_valid()) { + NOTREACHED(); + return 0; + } + const char* str = reinterpret_cast<const char*>( + sqlite3_column_text(ref_->stmt(), col)); + int len = sqlite3_column_bytes(ref_->stmt(), col); + + std::string result; + if (str && len > 0) + result.assign(str, len); + return result; +} + +int Statement::ColumnByteLength(int col) { + if (!is_valid()) { + NOTREACHED(); + return 0; + } + return sqlite3_column_bytes(ref_->stmt(), col); +} + +const void* Statement::ColumnBlob(int col) { + if (!is_valid()) { + NOTREACHED(); + return NULL; + } + + return sqlite3_column_blob(ref_->stmt(), col); +} + +void Statement::ColumnBlobAsVector(int col, std::vector<char>* val) { + val->clear(); + if (!is_valid()) { + NOTREACHED(); + return; + } + + const void* data = sqlite3_column_blob(ref_->stmt(), col); + int len = sqlite3_column_bytes(ref_->stmt(), col); + if (data && len > 0) { + val->resize(len); + memcpy(&(*val)[0], data, len); + } +} + +int Statement::CheckError(int err) { + succeeded_ = (err == SQLITE_OK || err == SQLITE_ROW || err == SQLITE_DONE); + + // TODO(brettw) enhance this to process the error. + return err; +} + +} // namespace sql diff --git a/app/sql/statement.h b/app/sql/statement.h new file mode 100644 index 0000000..371bc4e --- /dev/null +++ b/app/sql/statement.h @@ -0,0 +1,132 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef APP_SQL_STATEMENT_H_ +#define APP_SQL_STATEMENT_H_ + +#include <string> +#include <vector> + +#include "app/sql/connection.h" +#include "base/basictypes.h" +#include "base/ref_counted.h" + +namespace sql { + +// Normal usage: +// sql::Statement s = connection_.GetUniqueStatement(...); +// if (!s) // You should check for errors before using the statement. +// return false; +// +// s.BindInt(0, a); +// if (s.Step()) +// return s.ColumnString(0); +class Statement { + public: + // Creates an uninitialized statement. The statement will be invalid until + // you initialize it via Assign. + Statement(); + + Statement(scoped_refptr<Connection::StatementRef> ref); + ~Statement(); + + // Initializes this object with the given statement, which may or may not + // be valid. Use is_valid() to check if it's OK. + void Assign(scoped_refptr<Connection::StatementRef> ref); + + // Returns true if the statement can be executed. All functions can still + // be used if the statement is invalid, but they will return failure or some + // default value. This is because the statement can become invalid in the + // middle of executing a command if there is a serioud error and the database + // has to be reset. + bool is_valid() const { return ref_->is_valid(); } + + // These operators allow conveniently checking if the statement is valid + // or not. See the pattern above for an example. + operator bool() const { return is_valid(); } + bool operator!() const { return !is_valid(); } + + // Running ------------------------------------------------------------------- + + // Executes the statement, returning true on success. This is like Step but + // for when there is no output, like an INSERT statement. + bool Run(); + + // Executes the statement, returning true if there is a row of data returned. + // You can keep calling Step() until it returns false to iterate through all + // the rows in your result set. + // + // When Step returns false, the result is either that there is no more data + // or there is an error. This makes it most convenient for loop usage. If you + // need to disambiguate these cases, use Succeeded(). + // + // Typical example: + // while (s.Step()) { + // ... + // } + // return s.Succeeded(); + bool Step(); + + // Resets the statement to its initial condition. This includes clearing all + // the bound variables and any current result row. + void Reset(); + + // Returns true if the last executed thing in this statement succeeded. If + // there was no last executed thing or the statement is invalid, this will + // return false. + bool Succeeded() const; + + // Binding ------------------------------------------------------------------- + + // These all take a 0-based argument index and return true on failure. You + // may not always care about the return value (they'll DCHECK if they fail). + // The main thing you may want to check is when binding large blobs or + // strings there may be out of memory. + bool BindNull(int col); + bool BindInt(int col, int val); + bool BindInt64(int col, int64 val); + bool BindDouble(int col, double val); + bool BindCString(int col, const char* val); + bool BindString(int col, const std::string& val); + bool BindBlob(int col, const void* value, int value_len); + + // Retrieving ---------------------------------------------------------------- + + // Returns the number of output columns in the result. + int ColumnCount() const; + + // These all take a 0-based argument index. + int ColumnInt(int col) const; + int64 ColumnInt64(int col) const; + double ColumnDouble(int col) const; + std::string ColumnString(int col) const; + + // When reading a blob, you can get a raw pointer to the underlying data, + // along with the length, or you can just ask us to copy the blob into a + // vector. Danger! ColumnBlob may return NULL if there is no data! + int ColumnByteLength(int col); + const void* ColumnBlob(int col); + void ColumnBlobAsVector(int col, std::vector<char>* val); + + private: + // This is intended to check for serious errors and report them to the + // connection object. It takes a sqlite error code, and returns the same + // code. Currently this function just updates the succeeded flag, but will be + // enhanced in the future to do the notification. + int CheckError(int err); + + // The actual sqlite statement. This may be unique to us, or it may be cached + // by the connection, which is why it's refcounted. This pointer is + // guaranteed non-NULL. + scoped_refptr<Connection::StatementRef> ref_; + + // See Succeeded() for what this holds. + bool succeeded_; + + DISALLOW_COPY_AND_ASSIGN(Statement); +}; + +} // namespace sql + +#endif // APP_SQL_STATEMENT_H_ diff --git a/app/sql/statement_unittest.cc b/app/sql/statement_unittest.cc new file mode 100644 index 0000000..30a869c --- /dev/null +++ b/app/sql/statement_unittest.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "app/sql/connection.h" +#include "app/sql/statement.h" +#include "base/file_path.h" +#include "base/file_util.h" +#include "base/path_service.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "third_party/sqlite/preprocessed/sqlite3.h" + +class SQLStatementTest : public testing::Test { + public: + SQLStatementTest() {} + + void SetUp() { + ASSERT_TRUE(PathService::Get(base::DIR_TEMP, &path_)); + path_ = path_.AppendASCII("SQLStatementTest.db"); + file_util::Delete(path_, false); + ASSERT_TRUE(db_.Init(path_)); + } + + void TearDown() { + db_.Close(); + + // If this fails something is going on with cleanup and later tests may + // fail, so we want to identify problems right away. + ASSERT_TRUE(file_util::Delete(path_, false)); + } + + sql::Connection& db() { return db_; } + + private: + FilePath path_; + sql::Connection db_; +}; + +TEST_F(SQLStatementTest, Assign) { + sql::Statement s; + EXPECT_FALSE(s); // bool conversion operator. + EXPECT_TRUE(!s); // ! operator. + EXPECT_FALSE(s.is_valid()); + + s.Assign(db().GetUniqueStatement("CREATE TABLE foo (a, b)")); + EXPECT_TRUE(s); + EXPECT_FALSE(!s); + EXPECT_TRUE(s.is_valid()); +} + +TEST_F(SQLStatementTest, Run) { + ASSERT_TRUE(db().Execute("CREATE TABLE foo (a, b)")); + ASSERT_TRUE(db().Execute("INSERT INTO foo (a, b) VALUES (3, 12)")); + + sql::Statement s(db().GetUniqueStatement("SELECT b FROM foo WHERE a=?")); + EXPECT_FALSE(s.Succeeded()); + + // Stepping it won't work since we haven't bound the value. + EXPECT_FALSE(s.Step()); + + // Run should fail since this produces output, and we should use Step(). This + // gets a bit wonky since sqlite says this is OK so succeeded is set. + s.Reset(); + s.BindInt(0, 3); + EXPECT_FALSE(s.Run()); + EXPECT_EQ(SQLITE_ROW, db().GetErrorCode()); + EXPECT_TRUE(s.Succeeded()); + + // Resetting it should put it back to the previous state (not runnable). + s.Reset(); + EXPECT_FALSE(s.Succeeded()); + + // Binding and stepping should produce one row. + s.BindInt(0, 3); + EXPECT_TRUE(s.Step()); + EXPECT_TRUE(s.Succeeded()); + EXPECT_EQ(12, s.ColumnInt(0)); + EXPECT_FALSE(s.Step()); + EXPECT_TRUE(s.Succeeded()); +} diff --git a/app/sql/transaction.cc b/app/sql/transaction.cc new file mode 100644 index 0000000..79a198b --- /dev/null +++ b/app/sql/transaction.cc @@ -0,0 +1,51 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "app/sql/transaction.h" + +#include "app/sql/connection.h" +#include "base/logging.h" + +namespace sql { + +Transaction::Transaction(Connection* connection) + : connection_(connection), + is_open_(false) { +} + +Transaction::~Transaction() { + if (is_open_) + connection_->RollbackTransaction(); +} + +bool Transaction::Begin() { + if (is_open_) { + NOTREACHED() << "Beginning a transaction twice!"; + return false; + } + is_open_ = connection_->BeginTransaction(); + return is_open_; +} + +void Transaction::Rollback() { + if (!is_open_) { + NOTREACHED() << "Attempting to roll back a nonexistant transaction. " + << "Did you remember to call Begin() and check its return?"; + return; + } + is_open_ = false; + connection_->RollbackTransaction(); +} + +bool Transaction::Commit() { + if (!is_open_) { + NOTREACHED() << "Attempting to commit a nonexistant transaction. " + << "Did you remember to call Begin() and check its return?"; + return false; + } + is_open_ = false; + return connection_->CommitTransaction(); +} + +} // namespace sql diff --git a/app/sql/transaction.h b/app/sql/transaction.h new file mode 100644 index 0000000..1e00c6f --- /dev/null +++ b/app/sql/transaction.h @@ -0,0 +1,56 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#ifndef APP_SQL_TRANSACTION_H_ +#define APP_SQL_TRANSACTION_H_ + +#include "base/basictypes.h" + +namespace sql { + +class Connection; + +class Transaction { + public: + // Creates the scoped transaction object. You MUST call Begin() to begin the + // transaction. If you have begun a transaction and not committed it, the + // constructor will roll back the transaction. If you want to commit, you + // need to manually call Commit before this goes out of scope. + Transaction(Connection* connection); + ~Transaction(); + + // Returns true when there is a transaction that has been successfully begun. + bool is_open() const { return is_open_; } + + // Begins the transaction. This uses the default sqlite "deferred" transaction + // type, which means that the DB lock is lazily acquired the next time the + // database is accessed, not in the begin transaction command. + // + // Returns false on failure. Note that if this fails, you shouldn't do + // anything you expect to be actually transactional, because it won't be! + bool Begin(); + + // Rolls back the transaction. This will happen automatically if you do + // nothing when the transaction goes out of scope. + void Rollback(); + + // Commits the transaction, returning true on success. This will return + // false if sqlite could not commit it, or if another transaction in the + // same outermost transaction has been rolled back (which necessitates a + // rollback of all transactions in that outermost one). + bool Commit(); + + private: + Connection* connection_; + + // True when the transaction is open, false when it's already been committed + // or rolled back. + bool is_open_; + + DISALLOW_COPY_AND_ASSIGN(Transaction); +}; + +} // namespace sql + +#endif // APP_SQL_TRANSACTION_H_ diff --git a/app/sql/transaction_unittest.cc b/app/sql/transaction_unittest.cc new file mode 100644 index 0000000..0da79e3 --- /dev/null +++ b/app/sql/transaction_unittest.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2009 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "app/sql/connection.h" +#include "app/sql/statement.h" +#include "app/sql/transaction.h" +#include "base/file_path.h" +#include "base/file_util.h" +#include "base/path_service.h" +#include "testing/gtest/include/gtest/gtest.h" +#include "third_party/sqlite/preprocessed/sqlite3.h" + +class SQLTransactionTest : public testing::Test { + public: + SQLTransactionTest() {} + + void SetUp() { + ASSERT_TRUE(PathService::Get(base::DIR_TEMP, &path_)); + path_ = path_.AppendASCII("SQLStatementTest.db"); + file_util::Delete(path_, false); + ASSERT_TRUE(db_.Init(path_)); + + ASSERT_TRUE(db().Execute("CREATE TABLE foo (a, b)")); + } + + void TearDown() { + db_.Close(); + + // If this fails something is going on with cleanup and later tests may + // fail, so we want to identify problems right away. + ASSERT_TRUE(file_util::Delete(path_, false)); + } + + sql::Connection& db() { return db_; } + + // Returns the number of rows in table "foo". + int CountFoo() { + sql::Statement count(db().GetUniqueStatement("SELECT count(*) FROM foo")); + count.Step(); + return count.ColumnInt(0); + } + + private: + FilePath path_; + sql::Connection db_; +}; + +TEST_F(SQLTransactionTest, Commit) { + { + sql::Transaction t(&db()); + EXPECT_FALSE(t.is_open()); + EXPECT_TRUE(t.Begin()); + EXPECT_TRUE(t.is_open()); + + EXPECT_TRUE(db().Execute("INSERT INTO foo (a, b) VALUES (1, 2)")); + + t.Commit(); + EXPECT_FALSE(t.is_open()); + } + + EXPECT_EQ(1, CountFoo()); +} + +TEST_F(SQLTransactionTest, Rollback) { + // Test some basic initialization, and that rollback runs when you exit the + // scope. + { + sql::Transaction t(&db()); + EXPECT_FALSE(t.is_open()); + EXPECT_TRUE(t.Begin()); + EXPECT_TRUE(t.is_open()); + + EXPECT_TRUE(db().Execute("INSERT INTO foo (a, b) VALUES (1, 2)")); + } + + // Nothing should have been committed since it was implicitly rolled back. + EXPECT_EQ(0, CountFoo()); + + // Test explicit rollback. + sql::Transaction t2(&db()); + EXPECT_FALSE(t2.is_open()); + EXPECT_TRUE(t2.Begin()); + + EXPECT_TRUE(db().Execute("INSERT INTO foo (a, b) VALUES (1, 2)")); + t2.Rollback(); + EXPECT_FALSE(t2.is_open()); + + // Nothing should have been committed since it was explicitly rolled back. + EXPECT_EQ(0, CountFoo()); +} + +// Rolling back any part of a transaction should roll back all of them. +TEST_F(SQLTransactionTest, NestedRollback) { + EXPECT_EQ(0, db().transaction_nesting()); + + // Outermost transaction. + { + sql::Transaction outer(&db()); + EXPECT_TRUE(outer.Begin()); + EXPECT_EQ(1, db().transaction_nesting()); + + // The first inner one gets committed. + { + sql::Transaction inner1(&db()); + EXPECT_TRUE(inner1.Begin()); + EXPECT_TRUE(db().Execute("INSERT INTO foo (a, b) VALUES (1, 2)")); + EXPECT_EQ(2, db().transaction_nesting()); + + inner1.Commit(); + EXPECT_EQ(1, db().transaction_nesting()); + } + + // One row should have gotten inserted. + EXPECT_EQ(1, CountFoo()); + + // The second inner one gets rolled back. + { + sql::Transaction inner2(&db()); + EXPECT_TRUE(inner2.Begin()); + EXPECT_TRUE(db().Execute("INSERT INTO foo (a, b) VALUES (1, 2)")); + EXPECT_EQ(2, db().transaction_nesting()); + + inner2.Rollback(); + EXPECT_EQ(1, db().transaction_nesting()); + } + + // A third inner one will fail in Begin since one has already been rolled + // back. + EXPECT_EQ(1, db().transaction_nesting()); + { + sql::Transaction inner3(&db()); + EXPECT_FALSE(inner3.Begin()); + EXPECT_EQ(1, db().transaction_nesting()); + } + } + EXPECT_EQ(0, db().transaction_nesting()); + EXPECT_EQ(0, CountFoo()); +} |