From 8102e8997cd0f19fa3269f4ea71138fb128a666c Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Thu, 1 Jan 2026 09:03:59 +0900 Subject: [PATCH 1/3] sqlite autocommit --- Lib/dbm/gnu.py | 3 + Lib/dbm/ndbm.py | 3 + Lib/dbm/sqlite3.py | 144 ++++++++++++++ Lib/test/test_dbm_gnu.py | 222 +++++++++++++++++++++ Lib/test/test_dbm_ndbm.py | 165 ++++++++++++++++ Lib/test/test_dbm_sqlite3.py | 362 +++++++++++++++++++++++++++++++++++ crates/stdlib/src/sqlite.rs | 67 +++++++ 7 files changed, 966 insertions(+) create mode 100644 Lib/dbm/gnu.py create mode 100644 Lib/dbm/ndbm.py create mode 100644 Lib/dbm/sqlite3.py create mode 100644 Lib/test/test_dbm_gnu.py create mode 100644 Lib/test/test_dbm_ndbm.py create mode 100644 Lib/test/test_dbm_sqlite3.py diff --git a/Lib/dbm/gnu.py b/Lib/dbm/gnu.py new file mode 100644 index 00000000000..b07a1defffd --- /dev/null +++ b/Lib/dbm/gnu.py @@ -0,0 +1,3 @@ +"""Provide the _gdbm module as a dbm submodule.""" + +from _gdbm import * diff --git a/Lib/dbm/ndbm.py b/Lib/dbm/ndbm.py new file mode 100644 index 00000000000..23056a29ef2 --- /dev/null +++ b/Lib/dbm/ndbm.py @@ -0,0 +1,3 @@ +"""Provide the _dbm module as a dbm submodule.""" + +from _dbm import * diff --git a/Lib/dbm/sqlite3.py b/Lib/dbm/sqlite3.py new file mode 100644 index 00000000000..d0eed54e0f8 --- /dev/null +++ b/Lib/dbm/sqlite3.py @@ -0,0 +1,144 @@ +import os +import sqlite3 +from pathlib import Path +from contextlib import suppress, closing +from collections.abc import MutableMapping + +BUILD_TABLE = """ + CREATE TABLE IF NOT EXISTS Dict ( + key BLOB UNIQUE NOT NULL, + value BLOB NOT NULL + ) +""" +GET_SIZE = "SELECT COUNT (key) FROM Dict" +LOOKUP_KEY = "SELECT value FROM Dict WHERE key = CAST(? AS BLOB)" +STORE_KV = "REPLACE INTO Dict (key, value) VALUES (CAST(? AS BLOB), CAST(? AS BLOB))" +DELETE_KEY = "DELETE FROM Dict WHERE key = CAST(? AS BLOB)" +ITER_KEYS = "SELECT key FROM Dict" + + +class error(OSError): + pass + + +_ERR_CLOSED = "DBM object has already been closed" +_ERR_REINIT = "DBM object does not support reinitialization" + + +def _normalize_uri(path): + path = Path(path) + uri = path.absolute().as_uri() + while "//" in uri: + uri = uri.replace("//", "/") + return uri + + +class _Database(MutableMapping): + + def __init__(self, path, /, *, flag, mode): + if hasattr(self, "_cx"): + raise error(_ERR_REINIT) + + path = os.fsdecode(path) + match flag: + case "r": + flag = "ro" + case "w": + flag = "rw" + case "c": + flag = "rwc" + Path(path).touch(mode=mode, exist_ok=True) + case "n": + flag = "rwc" + Path(path).unlink(missing_ok=True) + Path(path).touch(mode=mode) + case _: + raise ValueError("Flag must be one of 'r', 'w', 'c', or 'n', " + f"not {flag!r}") + + # We use the URI format when opening the database. + uri = _normalize_uri(path) + uri = f"{uri}?mode={flag}" + if flag == "ro": + # Add immutable=1 to allow read-only SQLite access even if wal/shm missing + uri += "&immutable=1" + + try: + self._cx = sqlite3.connect(uri, autocommit=True, uri=True) + except sqlite3.Error as exc: + raise error(str(exc)) + + if flag != "ro": + # This is an optimization only; it's ok if it fails. + with suppress(sqlite3.OperationalError): + self._cx.execute("PRAGMA journal_mode = wal") + + if flag == "rwc": + self._execute(BUILD_TABLE) + + def _execute(self, *args, **kwargs): + if not self._cx: + raise error(_ERR_CLOSED) + try: + return closing(self._cx.execute(*args, **kwargs)) + except sqlite3.Error as exc: + raise error(str(exc)) + + def __len__(self): + with self._execute(GET_SIZE) as cu: + row = cu.fetchone() + return row[0] + + def __getitem__(self, key): + with self._execute(LOOKUP_KEY, (key,)) as cu: + row = cu.fetchone() + if not row: + raise KeyError(key) + return row[0] + + def __setitem__(self, key, value): + self._execute(STORE_KV, (key, value)) + + def __delitem__(self, key): + with self._execute(DELETE_KEY, (key,)) as cu: + if not cu.rowcount: + raise KeyError(key) + + def __iter__(self): + try: + with self._execute(ITER_KEYS) as cu: + for row in cu: + yield row[0] + except sqlite3.Error as exc: + raise error(str(exc)) + + def close(self): + if self._cx: + self._cx.close() + self._cx = None + + def keys(self): + return list(super().keys()) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +def open(filename, /, flag="r", mode=0o666): + """Open a dbm.sqlite3 database and return the dbm object. + + The 'filename' parameter is the name of the database file. + + The optional 'flag' parameter can be one of ...: + 'r' (default): open an existing database for read only access + 'w': open an existing database for read/write access + 'c': create a database if it does not exist; open for read/write access + 'n': always create a new, empty database; open for read/write access + + The optional 'mode' parameter is the Unix file access mode of the database; + only used when creating a new database. Default: 0o666. + """ + return _Database(filename, flag=flag, mode=mode) diff --git a/Lib/test/test_dbm_gnu.py b/Lib/test/test_dbm_gnu.py new file mode 100644 index 00000000000..66268c42a30 --- /dev/null +++ b/Lib/test/test_dbm_gnu.py @@ -0,0 +1,222 @@ +import os +import unittest +from test import support +from test.support import cpython_only, import_helper +from test.support.os_helper import (TESTFN, TESTFN_NONASCII, FakePath, + create_empty_file, temp_dir, unlink) + +gdbm = import_helper.import_module("dbm.gnu") # skip if not supported + +filename = TESTFN + +class TestGdbm(unittest.TestCase): + @staticmethod + def setUpClass(): + if support.verbose: + try: + from _gdbm import _GDBM_VERSION as version + except ImportError: + pass + else: + print(f"gdbm version: {version}") + + def setUp(self): + self.g = None + + def tearDown(self): + if self.g is not None: + self.g.close() + unlink(filename) + + @cpython_only + def test_disallow_instantiation(self): + # Ensure that the type disallows instantiation (bpo-43916) + self.g = gdbm.open(filename, 'c') + support.check_disallow_instantiation(self, type(self.g)) + + def test_key_methods(self): + self.g = gdbm.open(filename, 'c') + self.assertEqual(self.g.keys(), []) + self.g['a'] = 'b' + self.g['12345678910'] = '019237410982340912840198242' + self.g[b'bytes'] = b'data' + key_set = set(self.g.keys()) + self.assertEqual(key_set, set([b'a', b'bytes', b'12345678910'])) + self.assertIn('a', self.g) + self.assertIn(b'a', self.g) + self.assertEqual(self.g[b'bytes'], b'data') + key = self.g.firstkey() + while key: + self.assertIn(key, key_set) + key_set.remove(key) + key = self.g.nextkey(key) + # get() and setdefault() work as in the dict interface + self.assertEqual(self.g.get(b'a'), b'b') + self.assertIsNone(self.g.get(b'xxx')) + self.assertEqual(self.g.get(b'xxx', b'foo'), b'foo') + with self.assertRaises(KeyError): + self.g['xxx'] + self.assertEqual(self.g.setdefault(b'xxx', b'foo'), b'foo') + self.assertEqual(self.g[b'xxx'], b'foo') + + def test_error_conditions(self): + # Try to open a non-existent database. + unlink(filename) + self.assertRaises(gdbm.error, gdbm.open, filename, 'r') + # Try to access a closed database. + self.g = gdbm.open(filename, 'c') + self.g.close() + self.assertRaises(gdbm.error, lambda: self.g['a']) + # try pass an invalid open flag + self.assertRaises(gdbm.error, lambda: gdbm.open(filename, 'rx').close()) + + def test_flags(self): + # Test the flag parameter open() by trying all supported flag modes. + all = set(gdbm.open_flags) + # Test standard flags (presumably "crwn"). + modes = all - set('fsu') + for mode in sorted(modes): # put "c" mode first + self.g = gdbm.open(filename, mode) + self.g.close() + + # Test additional flags (presumably "fsu"). + flags = all - set('crwn') + for mode in modes: + for flag in flags: + self.g = gdbm.open(filename, mode + flag) + self.g.close() + + def test_reorganize(self): + self.g = gdbm.open(filename, 'c') + size0 = os.path.getsize(filename) + + # bpo-33901: on macOS with gdbm 1.15, an empty database uses 16 MiB + # and adding an entry of 10,000 B has no effect on the file size. + # Add size0 bytes to make sure that the file size changes. + value_size = max(size0, 10000) + self.g['x'] = 'x' * value_size + size1 = os.path.getsize(filename) + self.assertGreater(size1, size0) + + del self.g['x'] + # 'size' is supposed to be the same even after deleting an entry. + self.assertEqual(os.path.getsize(filename), size1) + + self.g.reorganize() + size2 = os.path.getsize(filename) + self.assertLess(size2, size1) + self.assertGreaterEqual(size2, size0) + + def test_context_manager(self): + with gdbm.open(filename, 'c') as db: + db["gdbm context manager"] = "context manager" + + with gdbm.open(filename, 'r') as db: + self.assertEqual(list(db.keys()), [b"gdbm context manager"]) + + with self.assertRaises(gdbm.error) as cm: + db.keys() + self.assertEqual(str(cm.exception), + "GDBM object has already been closed") + + def test_bool_empty(self): + with gdbm.open(filename, 'c') as db: + self.assertFalse(bool(db)) + + def test_bool_not_empty(self): + with gdbm.open(filename, 'c') as db: + db['a'] = 'b' + self.assertTrue(bool(db)) + + def test_bool_on_closed_db_raises(self): + with gdbm.open(filename, 'c') as db: + db['a'] = 'b' + self.assertRaises(gdbm.error, bool, db) + + def test_bytes(self): + with gdbm.open(filename, 'c') as db: + db[b'bytes key \xbd'] = b'bytes value \xbd' + with gdbm.open(filename, 'r') as db: + self.assertEqual(list(db.keys()), [b'bytes key \xbd']) + self.assertTrue(b'bytes key \xbd' in db) + self.assertEqual(db[b'bytes key \xbd'], b'bytes value \xbd') + + def test_unicode(self): + with gdbm.open(filename, 'c') as db: + db['Unicode key \U0001f40d'] = 'Unicode value \U0001f40d' + with gdbm.open(filename, 'r') as db: + self.assertEqual(list(db.keys()), ['Unicode key \U0001f40d'.encode()]) + self.assertTrue('Unicode key \U0001f40d'.encode() in db) + self.assertTrue('Unicode key \U0001f40d' in db) + self.assertEqual(db['Unicode key \U0001f40d'.encode()], + 'Unicode value \U0001f40d'.encode()) + self.assertEqual(db['Unicode key \U0001f40d'], + 'Unicode value \U0001f40d'.encode()) + + def test_write_readonly_file(self): + with gdbm.open(filename, 'c') as db: + db[b'bytes key'] = b'bytes value' + with gdbm.open(filename, 'r') as db: + with self.assertRaises(gdbm.error): + del db[b'not exist key'] + with self.assertRaises(gdbm.error): + del db[b'bytes key'] + with self.assertRaises(gdbm.error): + db[b'not exist key'] = b'not exist value' + + @unittest.skipUnless(TESTFN_NONASCII, + 'requires OS support of non-ASCII encodings') + def test_nonascii_filename(self): + filename = TESTFN_NONASCII + self.addCleanup(unlink, filename) + with gdbm.open(filename, 'c') as db: + db[b'key'] = b'value' + self.assertTrue(os.path.exists(filename)) + with gdbm.open(filename, 'r') as db: + self.assertEqual(list(db.keys()), [b'key']) + self.assertTrue(b'key' in db) + self.assertEqual(db[b'key'], b'value') + + def test_nonexisting_file(self): + nonexisting_file = 'nonexisting-file' + with self.assertRaises(gdbm.error) as cm: + gdbm.open(nonexisting_file) + self.assertIn(nonexisting_file, str(cm.exception)) + self.assertEqual(cm.exception.filename, nonexisting_file) + + def test_open_with_pathlib_path(self): + gdbm.open(FakePath(filename), "c").close() + + def test_open_with_bytes_path(self): + gdbm.open(os.fsencode(filename), "c").close() + + def test_open_with_pathlib_bytes_path(self): + gdbm.open(FakePath(os.fsencode(filename)), "c").close() + + def test_clear(self): + kvs = [('foo', 'bar'), ('1234', '5678')] + with gdbm.open(filename, 'c') as db: + for k, v in kvs: + db[k] = v + self.assertIn(k, db) + self.assertEqual(len(db), len(kvs)) + + db.clear() + for k, v in kvs: + self.assertNotIn(k, db) + self.assertEqual(len(db), 0) + + @support.run_with_locale( + 'LC_ALL', + 'fr_FR.iso88591', 'ja_JP.sjis', 'zh_CN.gbk', + 'fr_FR.utf8', 'en_US.utf8', + '', + ) + def test_localized_error(self): + with temp_dir() as d: + create_empty_file(os.path.join(d, 'test')) + self.assertRaises(gdbm.error, gdbm.open, filename, 'r') + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_dbm_ndbm.py b/Lib/test/test_dbm_ndbm.py new file mode 100644 index 00000000000..e0f31c9a9a3 --- /dev/null +++ b/Lib/test/test_dbm_ndbm.py @@ -0,0 +1,165 @@ +from test.support import import_helper +from test.support import os_helper +import_helper.import_module("dbm.ndbm") #skip if not supported +import os +import unittest +import dbm.ndbm +from dbm.ndbm import error + +class DbmTestCase(unittest.TestCase): + + def setUp(self): + self.filename = os_helper.TESTFN + self.d = dbm.ndbm.open(self.filename, 'c') + self.d.close() + + def tearDown(self): + for suffix in ['', '.pag', '.dir', '.db']: + os_helper.unlink(self.filename + suffix) + + def test_keys(self): + self.d = dbm.ndbm.open(self.filename, 'c') + self.assertEqual(self.d.keys(), []) + self.d['a'] = 'b' + self.d[b'bytes'] = b'data' + self.d['12345678910'] = '019237410982340912840198242' + self.d.keys() + self.assertIn('a', self.d) + self.assertIn(b'a', self.d) + self.assertEqual(self.d[b'bytes'], b'data') + # get() and setdefault() work as in the dict interface + self.assertEqual(self.d.get(b'a'), b'b') + self.assertIsNone(self.d.get(b'xxx')) + self.assertEqual(self.d.get(b'xxx', b'foo'), b'foo') + with self.assertRaises(KeyError): + self.d['xxx'] + self.assertEqual(self.d.setdefault(b'xxx', b'foo'), b'foo') + self.assertEqual(self.d[b'xxx'], b'foo') + self.d.close() + + def test_empty_value(self): + if dbm.ndbm.library == 'Berkeley DB': + self.skipTest("Berkeley DB doesn't distinguish the empty value " + "from the absent one") + self.d = dbm.ndbm.open(self.filename, 'c') + self.assertEqual(self.d.keys(), []) + self.d['empty'] = '' + self.assertEqual(self.d.keys(), [b'empty']) + self.assertIn(b'empty', self.d) + self.assertEqual(self.d[b'empty'], b'') + self.assertEqual(self.d.get(b'empty'), b'') + self.assertEqual(self.d.setdefault(b'empty'), b'') + self.d.close() + + def test_modes(self): + for mode in ['r', 'rw', 'w', 'n']: + try: + self.d = dbm.ndbm.open(self.filename, mode) + self.d.close() + except error: + self.fail() + + def test_context_manager(self): + with dbm.ndbm.open(self.filename, 'c') as db: + db["ndbm context manager"] = "context manager" + + with dbm.ndbm.open(self.filename, 'r') as db: + self.assertEqual(list(db.keys()), [b"ndbm context manager"]) + + with self.assertRaises(dbm.ndbm.error) as cm: + db.keys() + self.assertEqual(str(cm.exception), + "DBM object has already been closed") + + def test_bytes(self): + with dbm.ndbm.open(self.filename, 'c') as db: + db[b'bytes key \xbd'] = b'bytes value \xbd' + with dbm.ndbm.open(self.filename, 'r') as db: + self.assertEqual(list(db.keys()), [b'bytes key \xbd']) + self.assertTrue(b'bytes key \xbd' in db) + self.assertEqual(db[b'bytes key \xbd'], b'bytes value \xbd') + + def test_unicode(self): + with dbm.ndbm.open(self.filename, 'c') as db: + db['Unicode key \U0001f40d'] = 'Unicode value \U0001f40d' + with dbm.ndbm.open(self.filename, 'r') as db: + self.assertEqual(list(db.keys()), ['Unicode key \U0001f40d'.encode()]) + self.assertTrue('Unicode key \U0001f40d'.encode() in db) + self.assertTrue('Unicode key \U0001f40d' in db) + self.assertEqual(db['Unicode key \U0001f40d'.encode()], + 'Unicode value \U0001f40d'.encode()) + self.assertEqual(db['Unicode key \U0001f40d'], + 'Unicode value \U0001f40d'.encode()) + + def test_write_readonly_file(self): + with dbm.ndbm.open(self.filename, 'c') as db: + db[b'bytes key'] = b'bytes value' + with dbm.ndbm.open(self.filename, 'r') as db: + with self.assertRaises(error): + del db[b'not exist key'] + with self.assertRaises(error): + del db[b'bytes key'] + with self.assertRaises(error): + db[b'not exist key'] = b'not exist value' + + @unittest.skipUnless(os_helper.TESTFN_NONASCII, + 'requires OS support of non-ASCII encodings') + def test_nonascii_filename(self): + filename = os_helper.TESTFN_NONASCII + for suffix in ['', '.pag', '.dir', '.db']: + self.addCleanup(os_helper.unlink, filename + suffix) + with dbm.ndbm.open(filename, 'c') as db: + db[b'key'] = b'value' + self.assertTrue(any(os.path.exists(filename + suffix) + for suffix in ['', '.pag', '.dir', '.db'])) + with dbm.ndbm.open(filename, 'r') as db: + self.assertEqual(list(db.keys()), [b'key']) + self.assertTrue(b'key' in db) + self.assertEqual(db[b'key'], b'value') + + def test_nonexisting_file(self): + nonexisting_file = 'nonexisting-file' + with self.assertRaises(dbm.ndbm.error) as cm: + dbm.ndbm.open(nonexisting_file) + self.assertIn(nonexisting_file, str(cm.exception)) + self.assertEqual(cm.exception.filename, nonexisting_file) + + def test_open_with_pathlib_path(self): + dbm.ndbm.open(os_helper.FakePath(self.filename), "c").close() + + def test_open_with_bytes_path(self): + dbm.ndbm.open(os.fsencode(self.filename), "c").close() + + def test_open_with_pathlib_bytes_path(self): + dbm.ndbm.open(os_helper.FakePath(os.fsencode(self.filename)), "c").close() + + def test_bool_empty(self): + with dbm.ndbm.open(self.filename, 'c') as db: + self.assertFalse(bool(db)) + + def test_bool_not_empty(self): + with dbm.ndbm.open(self.filename, 'c') as db: + db['a'] = 'b' + self.assertTrue(bool(db)) + + def test_bool_on_closed_db_raises(self): + with dbm.ndbm.open(self.filename, 'c') as db: + db['a'] = 'b' + self.assertRaises(dbm.ndbm.error, bool, db) + + def test_clear(self): + kvs = [('foo', 'bar'), ('1234', '5678')] + with dbm.ndbm.open(self.filename, 'c') as db: + for k, v in kvs: + db[k] = v + self.assertIn(k, db) + self.assertEqual(len(db), len(kvs)) + + db.clear() + for k, v in kvs: + self.assertNotIn(k, db) + self.assertEqual(len(db), 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_dbm_sqlite3.py b/Lib/test/test_dbm_sqlite3.py new file mode 100644 index 00000000000..39eac7a35ec --- /dev/null +++ b/Lib/test/test_dbm_sqlite3.py @@ -0,0 +1,362 @@ +import os +import stat +import sys +import test.support +import unittest +from contextlib import closing +from functools import partial +from pathlib import Path +from test.support import cpython_only, import_helper, os_helper + +dbm_sqlite3 = import_helper.import_module("dbm.sqlite3") +# N.B. The test will fail on some platforms without sqlite3 +# if the sqlite3 import is above the import of dbm.sqlite3. +# This is deliberate: if the import helper managed to import dbm.sqlite3, +# we must inevitably be able to import sqlite3. Else, we have a problem. +import sqlite3 +from dbm.sqlite3 import _normalize_uri + + +root_in_posix = False +if hasattr(os, 'geteuid'): + root_in_posix = (os.geteuid() == 0) + + +class _SQLiteDbmTests(unittest.TestCase): + + def setUp(self): + self.filename = os_helper.TESTFN + db = dbm_sqlite3.open(self.filename, "c") + db.close() + + def tearDown(self): + for suffix in "", "-wal", "-shm": + os_helper.unlink(self.filename + suffix) + + +class URI(unittest.TestCase): + + def test_uri_substitutions(self): + dataset = ( + ("/absolute/////b/c", "/absolute/b/c"), + ("PRE#MID##END", "PRE%23MID%23%23END"), + ("%#?%%#", "%25%23%3F%25%25%23"), + ) + for path, normalized in dataset: + with self.subTest(path=path, normalized=normalized): + self.assertTrue(_normalize_uri(path).endswith(normalized)) + + @unittest.skipUnless(sys.platform == "win32", "requires Windows") + def test_uri_windows(self): + dataset = ( + # Relative subdir. + (r"2018\January.xlsx", + "2018/January.xlsx"), + # Absolute with drive letter. + (r"C:\Projects\apilibrary\apilibrary.sln", + "/C:/Projects/apilibrary/apilibrary.sln"), + # Relative with drive letter. + (r"C:Projects\apilibrary\apilibrary.sln", + "/C:Projects/apilibrary/apilibrary.sln"), + ) + for path, normalized in dataset: + with self.subTest(path=path, normalized=normalized): + if not Path(path).is_absolute(): + self.skipTest(f"skipping relative path: {path!r}") + self.assertTrue(_normalize_uri(path).endswith(normalized)) + + +class ReadOnly(_SQLiteDbmTests): + + def setUp(self): + super().setUp() + with dbm_sqlite3.open(self.filename, "w") as db: + db[b"key1"] = "value1" + db[b"key2"] = "value2" + self.db = dbm_sqlite3.open(self.filename, "r") + + def tearDown(self): + self.db.close() + super().tearDown() + + def test_readonly_read(self): + self.assertEqual(self.db[b"key1"], b"value1") + self.assertEqual(self.db[b"key2"], b"value2") + + def test_readonly_write(self): + with self.assertRaises(dbm_sqlite3.error): + self.db[b"new"] = "value" + + def test_readonly_delete(self): + with self.assertRaises(dbm_sqlite3.error): + del self.db[b"key1"] + + def test_readonly_keys(self): + self.assertEqual(self.db.keys(), [b"key1", b"key2"]) + + def test_readonly_iter(self): + self.assertEqual([k for k in self.db], [b"key1", b"key2"]) + + +@unittest.skipIf(root_in_posix, "test is meanless with root privilege") +class ReadOnlyFilesystem(unittest.TestCase): + + def setUp(self): + self.test_dir = os_helper.TESTFN + self.addCleanup(os_helper.rmtree, self.test_dir) + os.mkdir(self.test_dir) + self.db_path = os.path.join(self.test_dir, "test.db") + + db = dbm_sqlite3.open(self.db_path, "c") + db[b"key"] = b"value" + db.close() + + def test_readonly_file_read(self): + os.chmod(self.db_path, stat.S_IREAD) + with dbm_sqlite3.open(self.db_path, "r") as db: + self.assertEqual(db[b"key"], b"value") + + def test_readonly_file_write(self): + os.chmod(self.db_path, stat.S_IREAD) + with dbm_sqlite3.open(self.db_path, "w") as db: + with self.assertRaises(dbm_sqlite3.error): + db[b"newkey"] = b"newvalue" + + def test_readonly_dir_read(self): + os.chmod(self.test_dir, stat.S_IREAD | stat.S_IEXEC) + with dbm_sqlite3.open(self.db_path, "r") as db: + self.assertEqual(db[b"key"], b"value") + + def test_readonly_dir_write(self): + os.chmod(self.test_dir, stat.S_IREAD | stat.S_IEXEC) + with dbm_sqlite3.open(self.db_path, "w") as db: + try: + db[b"newkey"] = b"newvalue" + modified = True # on Windows and macOS + except dbm_sqlite3.error: + modified = False + with dbm_sqlite3.open(self.db_path, "r") as db: + if modified: + self.assertEqual(db[b"newkey"], b"newvalue") + else: + self.assertNotIn(b"newkey", db) + + +class ReadWrite(_SQLiteDbmTests): + + def setUp(self): + super().setUp() + self.db = dbm_sqlite3.open(self.filename, "w") + + def tearDown(self): + self.db.close() + super().tearDown() + + def db_content(self): + with closing(sqlite3.connect(self.filename)) as cx: + keys = [r[0] for r in cx.execute("SELECT key FROM Dict")] + vals = [r[0] for r in cx.execute("SELECT value FROM Dict")] + return keys, vals + + def test_readwrite_unique_key(self): + self.db["key"] = "value" + self.db["key"] = "other" + keys, vals = self.db_content() + self.assertEqual(keys, [b"key"]) + self.assertEqual(vals, [b"other"]) + + def test_readwrite_delete(self): + self.db["key"] = "value" + self.db["new"] = "other" + + del self.db[b"new"] + keys, vals = self.db_content() + self.assertEqual(keys, [b"key"]) + self.assertEqual(vals, [b"value"]) + + del self.db[b"key"] + keys, vals = self.db_content() + self.assertEqual(keys, []) + self.assertEqual(vals, []) + + def test_readwrite_null_key(self): + with self.assertRaises(dbm_sqlite3.error): + self.db[None] = "value" + + def test_readwrite_null_value(self): + with self.assertRaises(dbm_sqlite3.error): + self.db[b"key"] = None + + +class Misuse(_SQLiteDbmTests): + + def setUp(self): + super().setUp() + self.db = dbm_sqlite3.open(self.filename, "w") + + def tearDown(self): + self.db.close() + super().tearDown() + + def test_misuse_double_create(self): + self.db["key"] = "value" + with dbm_sqlite3.open(self.filename, "c") as db: + self.assertEqual(db[b"key"], b"value") + + def test_misuse_double_close(self): + self.db.close() + + def test_misuse_invalid_flag(self): + regex = "must be.*'r'.*'w'.*'c'.*'n', not 'invalid'" + with self.assertRaisesRegex(ValueError, regex): + dbm_sqlite3.open(self.filename, flag="invalid") + + def test_misuse_double_delete(self): + self.db["key"] = "value" + del self.db[b"key"] + with self.assertRaises(KeyError): + del self.db[b"key"] + + def test_misuse_invalid_key(self): + with self.assertRaises(KeyError): + self.db[b"key"] + + def test_misuse_iter_close1(self): + self.db["1"] = 1 + it = iter(self.db) + self.db.close() + with self.assertRaises(dbm_sqlite3.error): + next(it) + + def test_misuse_iter_close2(self): + self.db["1"] = 1 + self.db["2"] = 2 + it = iter(self.db) + next(it) + self.db.close() + with self.assertRaises(dbm_sqlite3.error): + next(it) + + def test_misuse_use_after_close(self): + self.db.close() + with self.assertRaises(dbm_sqlite3.error): + self.db[b"read"] + with self.assertRaises(dbm_sqlite3.error): + self.db[b"write"] = "value" + with self.assertRaises(dbm_sqlite3.error): + del self.db[b"del"] + with self.assertRaises(dbm_sqlite3.error): + len(self.db) + with self.assertRaises(dbm_sqlite3.error): + self.db.keys() + + def test_misuse_reinit(self): + with self.assertRaises(dbm_sqlite3.error): + self.db.__init__("new.db", flag="n", mode=0o666) + + def test_misuse_empty_filename(self): + for flag in "r", "w", "c", "n": + with self.assertRaises(dbm_sqlite3.error): + db = dbm_sqlite3.open("", flag="c") + + +class DataTypes(_SQLiteDbmTests): + + dataset = ( + # (raw, coerced) + (42, b"42"), + (3.14, b"3.14"), + ("string", b"string"), + (b"bytes", b"bytes"), + ) + + def setUp(self): + super().setUp() + self.db = dbm_sqlite3.open(self.filename, "w") + + def tearDown(self): + self.db.close() + super().tearDown() + + def test_datatypes_values(self): + for raw, coerced in self.dataset: + with self.subTest(raw=raw, coerced=coerced): + self.db["key"] = raw + self.assertEqual(self.db[b"key"], coerced) + + def test_datatypes_keys(self): + for raw, coerced in self.dataset: + with self.subTest(raw=raw, coerced=coerced): + self.db[raw] = "value" + self.assertEqual(self.db[coerced], b"value") + # Raw keys are silently coerced to bytes. + self.assertEqual(self.db[raw], b"value") + del self.db[raw] + + def test_datatypes_replace_coerced(self): + self.db["10"] = "value" + self.db[b"10"] = "value" + self.db[10] = "value" + self.assertEqual(self.db.keys(), [b"10"]) + + +class CorruptDatabase(_SQLiteDbmTests): + """Verify that database exceptions are raised as dbm.sqlite3.error.""" + + def setUp(self): + super().setUp() + with closing(sqlite3.connect(self.filename)) as cx: + with cx: + cx.execute("DROP TABLE IF EXISTS Dict") + cx.execute("CREATE TABLE Dict (invalid_schema)") + + def check(self, flag, fn, should_succeed=False): + with closing(dbm_sqlite3.open(self.filename, flag)) as db: + with self.assertRaises(dbm_sqlite3.error): + fn(db) + + @staticmethod + def read(db): + return db["key"] + + @staticmethod + def write(db): + db["key"] = "value" + + @staticmethod + def iter(db): + next(iter(db)) + + @staticmethod + def keys(db): + db.keys() + + @staticmethod + def del_(db): + del db["key"] + + @staticmethod + def len_(db): + len(db) + + def test_corrupt_readwrite(self): + for flag in "r", "w", "c": + with self.subTest(flag=flag): + check = partial(self.check, flag=flag) + check(fn=self.read) + check(fn=self.write) + check(fn=self.iter) + check(fn=self.keys) + check(fn=self.del_) + check(fn=self.len_) + + def test_corrupt_force_new(self): + with closing(dbm_sqlite3.open(self.filename, "n")) as db: + db["foo"] = "write" + _ = db[b"foo"] + next(iter(db)) + del db[b"foo"] + + +if __name__ == "__main__": + unittest.main() diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index 54c889ecb6b..f9ef3f7a8d2 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -160,6 +160,8 @@ mod _sqlite { const PARSE_DECLTYPES: c_int = 1; #[pyattr] const PARSE_COLNAMES: c_int = 2; + #[pyattr] + const LEGACY_TRANSACTION_CONTROL: c_int = 1; #[pyattr] use libsqlite3_sys::{ @@ -300,6 +302,46 @@ mod _sqlite { SQLITE_IOERR_CORRUPTFS ); + /// Autocommit mode setting for sqlite3 connections. + /// - Legacy (default): use isolation_level to control transactions + /// - Enabled: autocommit mode (no automatic transactions) + /// - Disabled: manual commit mode + #[derive(Clone, Copy, Debug, PartialEq, Eq)] + enum AutocommitMode { + Legacy, + Enabled, + Disabled, + } + + impl Default for AutocommitMode { + fn default() -> Self { + Self::Legacy + } + } + + impl TryFromBorrowedObject<'_> for AutocommitMode { + fn try_from_borrowed_object(vm: &VirtualMachine, obj: &PyObject) -> PyResult { + if obj.is(&vm.ctx.true_value) { + Ok(Self::Enabled) + } else if obj.is(&vm.ctx.false_value) { + Ok(Self::Disabled) + } else if let Ok(val) = obj.try_to_value::(vm) { + if val == LEGACY_TRANSACTION_CONTROL { + Ok(Self::Legacy) + } else { + Err(vm.new_value_error(format!( + "autocommit must be True, False, or sqlite3.LEGACY_TRANSACTION_CONTROL, not {val}" + ))) + } + } else { + Err(vm.new_type_error(format!( + "autocommit must be True, False, or sqlite3.LEGACY_TRANSACTION_CONTROL, not {}", + obj.class().name() + ))) + } + } + } + #[derive(FromArgs)] struct ConnectArgs { #[pyarg(any)] @@ -320,6 +362,8 @@ mod _sqlite { cached_statements: c_int, #[pyarg(any, default = false)] uri: bool, + #[pyarg(any, default)] + autocommit: AutocommitMode, } unsafe impl Traverse for ConnectArgs { @@ -841,6 +885,7 @@ mod _sqlite { thread_ident: PyMutex, // TODO: Use atomic row_factory: PyAtomicRef>, text_factory: PyAtomicRef, + autocommit: PyMutex, } impl Debug for Connection { @@ -878,6 +923,7 @@ mod _sqlite { thread_ident: PyMutex::new(std::thread::current().id()), row_factory: PyAtomicRef::from(None), text_factory: PyAtomicRef::from(text_factory), + autocommit: PyMutex::new(args.autocommit), }) } } @@ -919,12 +965,14 @@ mod _sqlite { detect_types, isolation_level, check_same_thread, + autocommit, .. } = args; zelf.detect_types.store(detect_types, Ordering::Relaxed); zelf.check_same_thread .store(check_same_thread, Ordering::Relaxed); + *zelf.autocommit.lock() = autocommit; *zelf.thread_ident.lock() = std::thread::current().id(); let _ = unsafe { zelf.isolation_level.swap(isolation_level) }; @@ -1465,6 +1513,21 @@ mod _sqlite { } } + #[pygetset] + fn autocommit(&self, vm: &VirtualMachine) -> PyObjectRef { + match *self.autocommit.lock() { + AutocommitMode::Enabled => vm.ctx.true_value.clone().into(), + AutocommitMode::Disabled => vm.ctx.false_value.clone().into(), + AutocommitMode::Legacy => vm.ctx.new_int(LEGACY_TRANSACTION_CONTROL).into(), + } + } + #[pygetset(setter)] + fn set_autocommit(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { + let mode = AutocommitMode::try_from_borrowed_object(vm, &val)?; + *self.autocommit.lock() = mode; + Ok(()) + } + #[pygetset] fn text_factory(&self) -> PyObjectRef { self.text_factory.to_owned() @@ -1622,9 +1685,11 @@ mod _sqlite { let db = zelf.connection.db_lock(vm)?; + // Start implicit transaction for DML statements unless in autocommit mode if stmt.is_dml && db.is_autocommit() && zelf.connection.isolation_level.deref().is_some() + && *zelf.connection.autocommit.lock() != AutocommitMode::Enabled { db.begin_transaction( zelf.connection @@ -1715,9 +1780,11 @@ mod _sqlite { let db = zelf.connection.db_lock(vm)?; + // Start implicit transaction for DML statements unless in autocommit mode if stmt.is_dml && db.is_autocommit() && zelf.connection.isolation_level.deref().is_some() + && *zelf.connection.autocommit.lock() != AutocommitMode::Enabled { db.begin_transaction( zelf.connection From 16666bd03f0c279828de4bc70cd9fbb214348404 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Thu, 1 Jan 2026 08:40:47 +0900 Subject: [PATCH 2/3] Update dbm --- Lib/dbm/__init__.py | 21 ++-- Lib/dbm/dumb.py | 11 +- Lib/test/test_dbm.py | 180 ++++++++++++++++++----------- Lib/test/test_dbm_dumb.py | 86 +++++++++++++- Lib/test/test_dbm_gnu.py | 222 ------------------------------------ Lib/test/test_dbm_ndbm.py | 165 --------------------------- crates/stdlib/src/sqlite.rs | 9 +- 7 files changed, 221 insertions(+), 473 deletions(-) delete mode 100644 Lib/test/test_dbm_gnu.py delete mode 100644 Lib/test/test_dbm_ndbm.py diff --git a/Lib/dbm/__init__.py b/Lib/dbm/__init__.py index f65da521af4..4fdbc54e74c 100644 --- a/Lib/dbm/__init__.py +++ b/Lib/dbm/__init__.py @@ -5,7 +5,7 @@ import dbm d = dbm.open(file, 'w', 0o666) -The returned object is a dbm.gnu, dbm.ndbm or dbm.dumb object, dependent on the +The returned object is a dbm.sqlite3, dbm.gnu, dbm.ndbm or dbm.dumb database object, dependent on the type of database being opened (determined by the whichdb function) in the case of an existing dbm. If the dbm does not exist and the create or new flag ('c' or 'n') was specified, the dbm type will be determined by the availability of @@ -38,7 +38,7 @@ class error(Exception): pass -_names = ['dbm.gnu', 'dbm.ndbm', 'dbm.dumb'] +_names = ['dbm.sqlite3', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb'] _defaultmod = None _modules = {} @@ -109,17 +109,18 @@ def whichdb(filename): """ # Check for ndbm first -- this has a .pag and a .dir file + filename = os.fsencode(filename) try: - f = io.open(filename + ".pag", "rb") + f = io.open(filename + b".pag", "rb") f.close() - f = io.open(filename + ".dir", "rb") + f = io.open(filename + b".dir", "rb") f.close() return "dbm.ndbm" except OSError: # some dbm emulations based on Berkeley DB generate a .db file # some do not, but they should be caught by the bsd checks try: - f = io.open(filename + ".db", "rb") + f = io.open(filename + b".db", "rb") f.close() # guarantee we can actually open the file using dbm # kind of overkill, but since we are dealing with emulations @@ -134,12 +135,12 @@ def whichdb(filename): # Check for dumbdbm next -- this has a .dir and a .dat file try: # First check for presence of files - os.stat(filename + ".dat") - size = os.stat(filename + ".dir").st_size + os.stat(filename + b".dat") + size = os.stat(filename + b".dir").st_size # dumbdbm files with no keys are empty if size == 0: return "dbm.dumb" - f = io.open(filename + ".dir", "rb") + f = io.open(filename + b".dir", "rb") try: if f.read(1) in (b"'", b'"'): return "dbm.dumb" @@ -163,6 +164,10 @@ def whichdb(filename): if len(s) != 4: return "" + # Check for SQLite3 header string. + if s16 == b"SQLite format 3\0": + return "dbm.sqlite3" + # Convert to 4-byte int in native byte order -- return "" if impossible try: (magic,) = struct.unpack("=l", s) diff --git a/Lib/dbm/dumb.py b/Lib/dbm/dumb.py index 864ad371ec9..def120ffc37 100644 --- a/Lib/dbm/dumb.py +++ b/Lib/dbm/dumb.py @@ -46,6 +46,7 @@ class _Database(collections.abc.MutableMapping): _io = _io # for _commit() def __init__(self, filebasename, mode, flag='c'): + filebasename = self._os.fsencode(filebasename) self._mode = mode self._readonly = (flag == 'r') @@ -54,14 +55,14 @@ def __init__(self, filebasename, mode, flag='c'): # where key is the string key, pos is the offset into the dat # file of the associated value's first byte, and siz is the number # of bytes in the associated value. - self._dirfile = filebasename + '.dir' + self._dirfile = filebasename + b'.dir' # The data file is a binary file pointed into by the directory # file, and holds the values associated with keys. Each value # begins at a _BLOCKSIZE-aligned byte offset, and is a raw # binary 8-bit string value. - self._datfile = filebasename + '.dat' - self._bakfile = filebasename + '.bak' + self._datfile = filebasename + b'.dat' + self._bakfile = filebasename + b'.bak' # The index is an in-memory dict, mirroring the directory file. self._index = None # maps keys to (pos, siz) pairs @@ -97,7 +98,8 @@ def _update(self, flag): except OSError: if flag not in ('c', 'n'): raise - self._modified = True + with self._io.open(self._dirfile, 'w', encoding="Latin-1") as f: + self._chmod(self._dirfile) else: with f: for line in f: @@ -133,6 +135,7 @@ def _commit(self): # position; UTF-8, though, does care sometimes. entry = "%r, %r\n" % (key.decode('Latin-1'), pos_and_siz_pair) f.write(entry) + self._modified = False sync = _commit diff --git a/Lib/test/test_dbm.py b/Lib/test/test_dbm.py index e615d284cd3..6785aa273ac 100644 --- a/Lib/test/test_dbm.py +++ b/Lib/test/test_dbm.py @@ -1,23 +1,28 @@ """Test script for the dbm.open function based on testdumbdbm.py""" import unittest -import glob -import test.support -from test.support import os_helper, import_helper +import dbm +import os +from test.support import import_helper +from test.support import os_helper + + +try: + from dbm import sqlite3 as dbm_sqlite3 +except ImportError: + dbm_sqlite3 = None -# Skip tests if dbm module doesn't exist. -dbm = import_helper.import_module('dbm') try: from dbm import ndbm except ImportError: ndbm = None -_fname = os_helper.TESTFN +dirname = os_helper.TESTFN +_fname = os.path.join(dirname, os_helper.TESTFN) # -# Iterates over every database module supported by dbm currently available, -# setting dbm to use each in turn, and yielding that module +# Iterates over every database module supported by dbm currently available. # def dbm_iterator(): for name in dbm._names: @@ -31,11 +36,12 @@ def dbm_iterator(): # # Clean up all scratch databases we might have created during testing # -def delete_files(): - # we don't know the precise name the underlying database uses - # so we use glob to locate all names - for f in glob.glob(glob.escape(_fname) + "*"): - os_helper.unlink(f) +def cleaunup_test_dir(): + os_helper.rmtree(dirname) + +def setup_test_dir(): + cleaunup_test_dir() + os.mkdir(dirname) class AnyDBMTestCase: @@ -129,85 +135,127 @@ def test_anydbm_access(self): assert(f[key] == b"Python:") f.close() + def test_open_with_bytes(self): + dbm.open(os.fsencode(_fname), "c").close() + + def test_open_with_pathlib_path(self): + dbm.open(os_helper.FakePath(_fname), "c").close() + + def test_open_with_pathlib_path_bytes(self): + dbm.open(os_helper.FakePath(os.fsencode(_fname)), "c").close() + def read_helper(self, f): keys = self.keys_helper(f) for key in self._dict: self.assertEqual(self._dict[key], f[key.encode("ascii")]) - def tearDown(self): - delete_files() + def test_keys(self): + with dbm.open(_fname, 'c') as d: + self.assertEqual(d.keys(), []) + a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')] + for k, v in a: + d[k] = v + self.assertEqual(sorted(d.keys()), sorted(k for (k, v) in a)) + for k, v in a: + self.assertIn(k, d) + self.assertEqual(d[k], v) + self.assertNotIn(b'xxx', d) + self.assertRaises(KeyError, lambda: d[b'xxx']) + + def test_clear(self): + with dbm.open(_fname, 'c') as d: + self.assertEqual(d.keys(), []) + a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')] + for k, v in a: + d[k] = v + for k, _ in a: + self.assertIn(k, d) + self.assertEqual(len(d), len(a)) + + d.clear() + self.assertEqual(len(d), 0) + for k, _ in a: + self.assertNotIn(k, d) def setUp(self): + self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod) dbm._defaultmod = self.module - delete_files() + self.addCleanup(cleaunup_test_dir) + setup_test_dir() class WhichDBTestCase(unittest.TestCase): def test_whichdb(self): + self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod) + _bytes_fname = os.fsencode(_fname) + fnames = [_fname, os_helper.FakePath(_fname), + _bytes_fname, os_helper.FakePath(_bytes_fname)] for module in dbm_iterator(): # Check whether whichdb correctly guesses module name # for databases opened with "module" module. - # Try with empty files first name = module.__name__ - if name == 'dbm.dumb': - continue # whichdb can't support dbm.dumb - delete_files() - f = module.open(_fname, 'c') - f.close() - self.assertEqual(name, self.dbm.whichdb(_fname)) + setup_test_dir() + dbm._defaultmod = module + # Try with empty files first + with module.open(_fname, 'c'): pass + for path in fnames: + self.assertEqual(name, self.dbm.whichdb(path)) # Now add a key - f = module.open(_fname, 'w') - f[b"1"] = b"1" - # and test that we can find it - self.assertIn(b"1", f) - # and read it - self.assertEqual(f[b"1"], b"1") - f.close() - self.assertEqual(name, self.dbm.whichdb(_fname)) + with module.open(_fname, 'w') as f: + f[b"1"] = b"1" + # and test that we can find it + self.assertIn(b"1", f) + # and read it + self.assertEqual(f[b"1"], b"1") + for path in fnames: + self.assertEqual(name, self.dbm.whichdb(path)) @unittest.skipUnless(ndbm, reason='Test requires ndbm') def test_whichdb_ndbm(self): # Issue 17198: check that ndbm which is referenced in whichdb is defined - db_file = '{}_ndbm.db'.format(_fname) - with open(db_file, 'w'): - self.addCleanup(os_helper.unlink, db_file) - self.assertIsNone(self.dbm.whichdb(db_file[:-3])) + with open(_fname + '.db', 'wb') as f: + f.write(b'spam') + _bytes_fname = os.fsencode(_fname) + fnames = [_fname, os_helper.FakePath(_fname), + _bytes_fname, os_helper.FakePath(_bytes_fname)] + for path in fnames: + self.assertIsNone(self.dbm.whichdb(path)) + + @unittest.skipUnless(dbm_sqlite3, reason='Test requires dbm.sqlite3') + def test_whichdb_sqlite3(self): + # Databases created by dbm.sqlite3 are detected correctly. + with dbm_sqlite3.open(_fname, "c") as db: + db["key"] = "value" + self.assertEqual(self.dbm.whichdb(_fname), "dbm.sqlite3") + + @unittest.skipUnless(dbm_sqlite3, reason='Test requires dbm.sqlite3') + def test_whichdb_sqlite3_existing_db(self): + # Existing sqlite3 databases are detected correctly. + sqlite3 = import_helper.import_module("sqlite3") + try: + # Create an empty database. + with sqlite3.connect(_fname) as cx: + cx.execute("CREATE TABLE dummy(database)") + cx.commit() + finally: + cx.close() + self.assertEqual(self.dbm.whichdb(_fname), "dbm.sqlite3") - def tearDown(self): - delete_files() def setUp(self): - delete_files() - self.filename = os_helper.TESTFN - self.d = dbm.open(self.filename, 'c') - self.d.close() + self.addCleanup(cleaunup_test_dir) + setup_test_dir() self.dbm = import_helper.import_fresh_module('dbm') - def test_keys(self): - self.d = dbm.open(self.filename, 'c') - self.assertEqual(self.d.keys(), []) - a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')] - for k, v in a: - self.d[k] = v - self.assertEqual(sorted(self.d.keys()), sorted(k for (k, v) in a)) - for k, v in a: - self.assertIn(k, self.d) - self.assertEqual(self.d[k], v) - self.assertNotIn(b'xxx', self.d) - self.assertRaises(KeyError, lambda: self.d[b'xxx']) - self.d.close() - - -def load_tests(loader, tests, pattern): - classes = [] - for mod in dbm_iterator(): - classes.append(type("TestCase-" + mod.__name__, - (AnyDBMTestCase, unittest.TestCase), - {'module': mod})) - suites = [unittest.makeSuite(c) for c in classes] - - tests.addTests(suites) - return tests + +for mod in dbm_iterator(): + assert mod.__name__.startswith('dbm.') + suffix = mod.__name__[4:] + testname = f'TestCase_{suffix}' + globals()[testname] = type(testname, + (AnyDBMTestCase, unittest.TestCase), + {'module': mod}) + if __name__ == "__main__": unittest.main() diff --git a/Lib/test/test_dbm_dumb.py b/Lib/test/test_dbm_dumb.py index 0dc489362b2..672f9092207 100644 --- a/Lib/test/test_dbm_dumb.py +++ b/Lib/test/test_dbm_dumb.py @@ -15,6 +15,7 @@ _fname = os_helper.TESTFN + def _delete_files(): for ext in [".dir", ".dat", ".bak"]: try: @@ -41,6 +42,7 @@ def test_dumbdbm_creation(self): self.read_helper(f) @unittest.skipUnless(hasattr(os, 'umask'), 'test needs os.umask()') + @os_helper.skip_unless_working_chmod def test_dumbdbm_creation_mode(self): try: old_umask = os.umask(0o002) @@ -231,7 +233,7 @@ def test_create_new(self): self.assertEqual(f.keys(), []) def test_eval(self): - with open(_fname + '.dir', 'w') as stream: + with open(_fname + '.dir', 'w', encoding="utf-8") as stream: stream.write("str(print('Hacked!')), 0\n") with support.captured_stdout() as stdout: with self.assertRaises(ValueError): @@ -244,9 +246,27 @@ def test_missing_data(self): _delete_files() with self.assertRaises(FileNotFoundError): dumbdbm.open(_fname, value) + self.assertFalse(os.path.exists(_fname + '.dat')) self.assertFalse(os.path.exists(_fname + '.dir')) self.assertFalse(os.path.exists(_fname + '.bak')) + for value in ('c', 'n'): + _delete_files() + with dumbdbm.open(_fname, value) as f: + self.assertTrue(os.path.exists(_fname + '.dat')) + self.assertTrue(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + self.assertFalse(os.path.exists(_fname + '.bak')) + + for value in ('c', 'n'): + _delete_files() + with dumbdbm.open(_fname, value) as f: + f['key'] = 'value' + self.assertTrue(os.path.exists(_fname + '.dat')) + self.assertTrue(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + self.assertTrue(os.path.exists(_fname + '.bak')) + def test_missing_index(self): with dumbdbm.open(_fname, 'n') as f: pass @@ -257,6 +277,60 @@ def test_missing_index(self): self.assertFalse(os.path.exists(_fname + '.dir')) self.assertFalse(os.path.exists(_fname + '.bak')) + for value in ('c', 'n'): + with dumbdbm.open(_fname, value) as f: + self.assertTrue(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + self.assertFalse(os.path.exists(_fname + '.bak')) + os.unlink(_fname + '.dir') + + for value in ('c', 'n'): + with dumbdbm.open(_fname, value) as f: + f['key'] = 'value' + self.assertTrue(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + self.assertTrue(os.path.exists(_fname + '.bak')) + os.unlink(_fname + '.dir') + os.unlink(_fname + '.bak') + + def test_sync_empty_unmodified(self): + with dumbdbm.open(_fname, 'n') as f: + pass + os.unlink(_fname + '.dir') + for value in ('c', 'n'): + with dumbdbm.open(_fname, value) as f: + self.assertTrue(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + f.sync() + self.assertTrue(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + os.unlink(_fname + '.dir') + f.sync() + self.assertFalse(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + self.assertFalse(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + + def test_sync_nonempty_unmodified(self): + with dumbdbm.open(_fname, 'n') as f: + pass + os.unlink(_fname + '.dir') + for value in ('c', 'n'): + with dumbdbm.open(_fname, value) as f: + f['key'] = 'value' + self.assertTrue(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + f.sync() + self.assertTrue(os.path.exists(_fname + '.dir')) + self.assertTrue(os.path.exists(_fname + '.bak')) + os.unlink(_fname + '.dir') + os.unlink(_fname + '.bak') + f.sync() + self.assertFalse(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + self.assertFalse(os.path.exists(_fname + '.dir')) + self.assertFalse(os.path.exists(_fname + '.bak')) + def test_invalid_flag(self): for flag in ('x', 'rf', None): with self.assertRaisesRegex(ValueError, @@ -264,6 +338,7 @@ def test_invalid_flag(self): "'r', 'w', 'c', or 'n'"): dumbdbm.open(_fname, flag) + @os_helper.skip_unless_working_chmod def test_readonly_files(self): with os_helper.temp_dir() as dir: fname = os.path.join(dir, 'db') @@ -293,6 +368,15 @@ def test_nonascii_filename(self): self.assertTrue(b'key' in db) self.assertEqual(db[b'key'], b'value') + def test_open_with_pathlib_path(self): + dumbdbm.open(os_helper.FakePath(_fname), "c").close() + + def test_open_with_bytes_path(self): + dumbdbm.open(os.fsencode(_fname), "c").close() + + def test_open_with_pathlib_bytes_path(self): + dumbdbm.open(os_helper.FakePath(os.fsencode(_fname)), "c").close() + def tearDown(self): _delete_files() diff --git a/Lib/test/test_dbm_gnu.py b/Lib/test/test_dbm_gnu.py deleted file mode 100644 index 66268c42a30..00000000000 --- a/Lib/test/test_dbm_gnu.py +++ /dev/null @@ -1,222 +0,0 @@ -import os -import unittest -from test import support -from test.support import cpython_only, import_helper -from test.support.os_helper import (TESTFN, TESTFN_NONASCII, FakePath, - create_empty_file, temp_dir, unlink) - -gdbm = import_helper.import_module("dbm.gnu") # skip if not supported - -filename = TESTFN - -class TestGdbm(unittest.TestCase): - @staticmethod - def setUpClass(): - if support.verbose: - try: - from _gdbm import _GDBM_VERSION as version - except ImportError: - pass - else: - print(f"gdbm version: {version}") - - def setUp(self): - self.g = None - - def tearDown(self): - if self.g is not None: - self.g.close() - unlink(filename) - - @cpython_only - def test_disallow_instantiation(self): - # Ensure that the type disallows instantiation (bpo-43916) - self.g = gdbm.open(filename, 'c') - support.check_disallow_instantiation(self, type(self.g)) - - def test_key_methods(self): - self.g = gdbm.open(filename, 'c') - self.assertEqual(self.g.keys(), []) - self.g['a'] = 'b' - self.g['12345678910'] = '019237410982340912840198242' - self.g[b'bytes'] = b'data' - key_set = set(self.g.keys()) - self.assertEqual(key_set, set([b'a', b'bytes', b'12345678910'])) - self.assertIn('a', self.g) - self.assertIn(b'a', self.g) - self.assertEqual(self.g[b'bytes'], b'data') - key = self.g.firstkey() - while key: - self.assertIn(key, key_set) - key_set.remove(key) - key = self.g.nextkey(key) - # get() and setdefault() work as in the dict interface - self.assertEqual(self.g.get(b'a'), b'b') - self.assertIsNone(self.g.get(b'xxx')) - self.assertEqual(self.g.get(b'xxx', b'foo'), b'foo') - with self.assertRaises(KeyError): - self.g['xxx'] - self.assertEqual(self.g.setdefault(b'xxx', b'foo'), b'foo') - self.assertEqual(self.g[b'xxx'], b'foo') - - def test_error_conditions(self): - # Try to open a non-existent database. - unlink(filename) - self.assertRaises(gdbm.error, gdbm.open, filename, 'r') - # Try to access a closed database. - self.g = gdbm.open(filename, 'c') - self.g.close() - self.assertRaises(gdbm.error, lambda: self.g['a']) - # try pass an invalid open flag - self.assertRaises(gdbm.error, lambda: gdbm.open(filename, 'rx').close()) - - def test_flags(self): - # Test the flag parameter open() by trying all supported flag modes. - all = set(gdbm.open_flags) - # Test standard flags (presumably "crwn"). - modes = all - set('fsu') - for mode in sorted(modes): # put "c" mode first - self.g = gdbm.open(filename, mode) - self.g.close() - - # Test additional flags (presumably "fsu"). - flags = all - set('crwn') - for mode in modes: - for flag in flags: - self.g = gdbm.open(filename, mode + flag) - self.g.close() - - def test_reorganize(self): - self.g = gdbm.open(filename, 'c') - size0 = os.path.getsize(filename) - - # bpo-33901: on macOS with gdbm 1.15, an empty database uses 16 MiB - # and adding an entry of 10,000 B has no effect on the file size. - # Add size0 bytes to make sure that the file size changes. - value_size = max(size0, 10000) - self.g['x'] = 'x' * value_size - size1 = os.path.getsize(filename) - self.assertGreater(size1, size0) - - del self.g['x'] - # 'size' is supposed to be the same even after deleting an entry. - self.assertEqual(os.path.getsize(filename), size1) - - self.g.reorganize() - size2 = os.path.getsize(filename) - self.assertLess(size2, size1) - self.assertGreaterEqual(size2, size0) - - def test_context_manager(self): - with gdbm.open(filename, 'c') as db: - db["gdbm context manager"] = "context manager" - - with gdbm.open(filename, 'r') as db: - self.assertEqual(list(db.keys()), [b"gdbm context manager"]) - - with self.assertRaises(gdbm.error) as cm: - db.keys() - self.assertEqual(str(cm.exception), - "GDBM object has already been closed") - - def test_bool_empty(self): - with gdbm.open(filename, 'c') as db: - self.assertFalse(bool(db)) - - def test_bool_not_empty(self): - with gdbm.open(filename, 'c') as db: - db['a'] = 'b' - self.assertTrue(bool(db)) - - def test_bool_on_closed_db_raises(self): - with gdbm.open(filename, 'c') as db: - db['a'] = 'b' - self.assertRaises(gdbm.error, bool, db) - - def test_bytes(self): - with gdbm.open(filename, 'c') as db: - db[b'bytes key \xbd'] = b'bytes value \xbd' - with gdbm.open(filename, 'r') as db: - self.assertEqual(list(db.keys()), [b'bytes key \xbd']) - self.assertTrue(b'bytes key \xbd' in db) - self.assertEqual(db[b'bytes key \xbd'], b'bytes value \xbd') - - def test_unicode(self): - with gdbm.open(filename, 'c') as db: - db['Unicode key \U0001f40d'] = 'Unicode value \U0001f40d' - with gdbm.open(filename, 'r') as db: - self.assertEqual(list(db.keys()), ['Unicode key \U0001f40d'.encode()]) - self.assertTrue('Unicode key \U0001f40d'.encode() in db) - self.assertTrue('Unicode key \U0001f40d' in db) - self.assertEqual(db['Unicode key \U0001f40d'.encode()], - 'Unicode value \U0001f40d'.encode()) - self.assertEqual(db['Unicode key \U0001f40d'], - 'Unicode value \U0001f40d'.encode()) - - def test_write_readonly_file(self): - with gdbm.open(filename, 'c') as db: - db[b'bytes key'] = b'bytes value' - with gdbm.open(filename, 'r') as db: - with self.assertRaises(gdbm.error): - del db[b'not exist key'] - with self.assertRaises(gdbm.error): - del db[b'bytes key'] - with self.assertRaises(gdbm.error): - db[b'not exist key'] = b'not exist value' - - @unittest.skipUnless(TESTFN_NONASCII, - 'requires OS support of non-ASCII encodings') - def test_nonascii_filename(self): - filename = TESTFN_NONASCII - self.addCleanup(unlink, filename) - with gdbm.open(filename, 'c') as db: - db[b'key'] = b'value' - self.assertTrue(os.path.exists(filename)) - with gdbm.open(filename, 'r') as db: - self.assertEqual(list(db.keys()), [b'key']) - self.assertTrue(b'key' in db) - self.assertEqual(db[b'key'], b'value') - - def test_nonexisting_file(self): - nonexisting_file = 'nonexisting-file' - with self.assertRaises(gdbm.error) as cm: - gdbm.open(nonexisting_file) - self.assertIn(nonexisting_file, str(cm.exception)) - self.assertEqual(cm.exception.filename, nonexisting_file) - - def test_open_with_pathlib_path(self): - gdbm.open(FakePath(filename), "c").close() - - def test_open_with_bytes_path(self): - gdbm.open(os.fsencode(filename), "c").close() - - def test_open_with_pathlib_bytes_path(self): - gdbm.open(FakePath(os.fsencode(filename)), "c").close() - - def test_clear(self): - kvs = [('foo', 'bar'), ('1234', '5678')] - with gdbm.open(filename, 'c') as db: - for k, v in kvs: - db[k] = v - self.assertIn(k, db) - self.assertEqual(len(db), len(kvs)) - - db.clear() - for k, v in kvs: - self.assertNotIn(k, db) - self.assertEqual(len(db), 0) - - @support.run_with_locale( - 'LC_ALL', - 'fr_FR.iso88591', 'ja_JP.sjis', 'zh_CN.gbk', - 'fr_FR.utf8', 'en_US.utf8', - '', - ) - def test_localized_error(self): - with temp_dir() as d: - create_empty_file(os.path.join(d, 'test')) - self.assertRaises(gdbm.error, gdbm.open, filename, 'r') - - -if __name__ == '__main__': - unittest.main() diff --git a/Lib/test/test_dbm_ndbm.py b/Lib/test/test_dbm_ndbm.py deleted file mode 100644 index e0f31c9a9a3..00000000000 --- a/Lib/test/test_dbm_ndbm.py +++ /dev/null @@ -1,165 +0,0 @@ -from test.support import import_helper -from test.support import os_helper -import_helper.import_module("dbm.ndbm") #skip if not supported -import os -import unittest -import dbm.ndbm -from dbm.ndbm import error - -class DbmTestCase(unittest.TestCase): - - def setUp(self): - self.filename = os_helper.TESTFN - self.d = dbm.ndbm.open(self.filename, 'c') - self.d.close() - - def tearDown(self): - for suffix in ['', '.pag', '.dir', '.db']: - os_helper.unlink(self.filename + suffix) - - def test_keys(self): - self.d = dbm.ndbm.open(self.filename, 'c') - self.assertEqual(self.d.keys(), []) - self.d['a'] = 'b' - self.d[b'bytes'] = b'data' - self.d['12345678910'] = '019237410982340912840198242' - self.d.keys() - self.assertIn('a', self.d) - self.assertIn(b'a', self.d) - self.assertEqual(self.d[b'bytes'], b'data') - # get() and setdefault() work as in the dict interface - self.assertEqual(self.d.get(b'a'), b'b') - self.assertIsNone(self.d.get(b'xxx')) - self.assertEqual(self.d.get(b'xxx', b'foo'), b'foo') - with self.assertRaises(KeyError): - self.d['xxx'] - self.assertEqual(self.d.setdefault(b'xxx', b'foo'), b'foo') - self.assertEqual(self.d[b'xxx'], b'foo') - self.d.close() - - def test_empty_value(self): - if dbm.ndbm.library == 'Berkeley DB': - self.skipTest("Berkeley DB doesn't distinguish the empty value " - "from the absent one") - self.d = dbm.ndbm.open(self.filename, 'c') - self.assertEqual(self.d.keys(), []) - self.d['empty'] = '' - self.assertEqual(self.d.keys(), [b'empty']) - self.assertIn(b'empty', self.d) - self.assertEqual(self.d[b'empty'], b'') - self.assertEqual(self.d.get(b'empty'), b'') - self.assertEqual(self.d.setdefault(b'empty'), b'') - self.d.close() - - def test_modes(self): - for mode in ['r', 'rw', 'w', 'n']: - try: - self.d = dbm.ndbm.open(self.filename, mode) - self.d.close() - except error: - self.fail() - - def test_context_manager(self): - with dbm.ndbm.open(self.filename, 'c') as db: - db["ndbm context manager"] = "context manager" - - with dbm.ndbm.open(self.filename, 'r') as db: - self.assertEqual(list(db.keys()), [b"ndbm context manager"]) - - with self.assertRaises(dbm.ndbm.error) as cm: - db.keys() - self.assertEqual(str(cm.exception), - "DBM object has already been closed") - - def test_bytes(self): - with dbm.ndbm.open(self.filename, 'c') as db: - db[b'bytes key \xbd'] = b'bytes value \xbd' - with dbm.ndbm.open(self.filename, 'r') as db: - self.assertEqual(list(db.keys()), [b'bytes key \xbd']) - self.assertTrue(b'bytes key \xbd' in db) - self.assertEqual(db[b'bytes key \xbd'], b'bytes value \xbd') - - def test_unicode(self): - with dbm.ndbm.open(self.filename, 'c') as db: - db['Unicode key \U0001f40d'] = 'Unicode value \U0001f40d' - with dbm.ndbm.open(self.filename, 'r') as db: - self.assertEqual(list(db.keys()), ['Unicode key \U0001f40d'.encode()]) - self.assertTrue('Unicode key \U0001f40d'.encode() in db) - self.assertTrue('Unicode key \U0001f40d' in db) - self.assertEqual(db['Unicode key \U0001f40d'.encode()], - 'Unicode value \U0001f40d'.encode()) - self.assertEqual(db['Unicode key \U0001f40d'], - 'Unicode value \U0001f40d'.encode()) - - def test_write_readonly_file(self): - with dbm.ndbm.open(self.filename, 'c') as db: - db[b'bytes key'] = b'bytes value' - with dbm.ndbm.open(self.filename, 'r') as db: - with self.assertRaises(error): - del db[b'not exist key'] - with self.assertRaises(error): - del db[b'bytes key'] - with self.assertRaises(error): - db[b'not exist key'] = b'not exist value' - - @unittest.skipUnless(os_helper.TESTFN_NONASCII, - 'requires OS support of non-ASCII encodings') - def test_nonascii_filename(self): - filename = os_helper.TESTFN_NONASCII - for suffix in ['', '.pag', '.dir', '.db']: - self.addCleanup(os_helper.unlink, filename + suffix) - with dbm.ndbm.open(filename, 'c') as db: - db[b'key'] = b'value' - self.assertTrue(any(os.path.exists(filename + suffix) - for suffix in ['', '.pag', '.dir', '.db'])) - with dbm.ndbm.open(filename, 'r') as db: - self.assertEqual(list(db.keys()), [b'key']) - self.assertTrue(b'key' in db) - self.assertEqual(db[b'key'], b'value') - - def test_nonexisting_file(self): - nonexisting_file = 'nonexisting-file' - with self.assertRaises(dbm.ndbm.error) as cm: - dbm.ndbm.open(nonexisting_file) - self.assertIn(nonexisting_file, str(cm.exception)) - self.assertEqual(cm.exception.filename, nonexisting_file) - - def test_open_with_pathlib_path(self): - dbm.ndbm.open(os_helper.FakePath(self.filename), "c").close() - - def test_open_with_bytes_path(self): - dbm.ndbm.open(os.fsencode(self.filename), "c").close() - - def test_open_with_pathlib_bytes_path(self): - dbm.ndbm.open(os_helper.FakePath(os.fsencode(self.filename)), "c").close() - - def test_bool_empty(self): - with dbm.ndbm.open(self.filename, 'c') as db: - self.assertFalse(bool(db)) - - def test_bool_not_empty(self): - with dbm.ndbm.open(self.filename, 'c') as db: - db['a'] = 'b' - self.assertTrue(bool(db)) - - def test_bool_on_closed_db_raises(self): - with dbm.ndbm.open(self.filename, 'c') as db: - db['a'] = 'b' - self.assertRaises(dbm.ndbm.error, bool, db) - - def test_clear(self): - kvs = [('foo', 'bar'), ('1234', '5678')] - with dbm.ndbm.open(self.filename, 'c') as db: - for k, v in kvs: - db[k] = v - self.assertIn(k, db) - self.assertEqual(len(db), len(kvs)) - - db.clear() - for k, v in kvs: - self.assertNotIn(k, db) - self.assertEqual(len(db), 0) - - -if __name__ == '__main__': - unittest.main() diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index f9ef3f7a8d2..a07d88dcd05 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -306,19 +306,14 @@ mod _sqlite { /// - Legacy (default): use isolation_level to control transactions /// - Enabled: autocommit mode (no automatic transactions) /// - Disabled: manual commit mode - #[derive(Clone, Copy, Debug, PartialEq, Eq)] + #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] enum AutocommitMode { + #[default] Legacy, Enabled, Disabled, } - impl Default for AutocommitMode { - fn default() -> Self { - Self::Legacy - } - } - impl TryFromBorrowedObject<'_> for AutocommitMode { fn try_from_borrowed_object(vm: &VirtualMachine, obj: &PyObject) -> PyResult { if obj.is(&vm.ctx.true_value) { From 8d70096d692ff00f56b91776d4fbb628a84770b0 Mon Sep 17 00:00:00 2001 From: "Jeong, YunWon" Date: Thu, 1 Jan 2026 10:05:25 +0900 Subject: [PATCH 3/3] fix --- crates/stdlib/src/sqlite.rs | 76 ++++++++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 17 deletions(-) diff --git a/crates/stdlib/src/sqlite.rs b/crates/stdlib/src/sqlite.rs index a07d88dcd05..7e0392b1f30 100644 --- a/crates/stdlib/src/sqlite.rs +++ b/crates/stdlib/src/sqlite.rs @@ -29,21 +29,21 @@ mod _sqlite { sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name, sqlite3_bind_text, sqlite3_blob, sqlite3_blob_bytes, sqlite3_blob_close, sqlite3_blob_open, sqlite3_blob_read, sqlite3_blob_write, sqlite3_busy_timeout, sqlite3_changes, - sqlite3_close, sqlite3_column_blob, sqlite3_column_bytes, sqlite3_column_count, - sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int64, sqlite3_column_name, - sqlite3_column_text, sqlite3_column_type, sqlite3_complete, sqlite3_context, - sqlite3_context_db_handle, sqlite3_create_collation_v2, sqlite3_create_function_v2, - sqlite3_create_window_function, sqlite3_data_count, sqlite3_db_handle, sqlite3_errcode, - sqlite3_errmsg, sqlite3_exec, sqlite3_expanded_sql, sqlite3_extended_errcode, - sqlite3_finalize, sqlite3_get_autocommit, sqlite3_interrupt, sqlite3_last_insert_rowid, - sqlite3_libversion, sqlite3_limit, sqlite3_open_v2, sqlite3_prepare_v2, - sqlite3_progress_handler, sqlite3_reset, sqlite3_result_blob, sqlite3_result_double, - sqlite3_result_error, sqlite3_result_error_nomem, sqlite3_result_error_toobig, - sqlite3_result_int64, sqlite3_result_null, sqlite3_result_text, sqlite3_set_authorizer, - sqlite3_sleep, sqlite3_step, sqlite3_stmt, sqlite3_stmt_busy, sqlite3_stmt_readonly, - sqlite3_threadsafe, sqlite3_total_changes, sqlite3_trace_v2, sqlite3_user_data, - sqlite3_value, sqlite3_value_blob, sqlite3_value_bytes, sqlite3_value_double, - sqlite3_value_int64, sqlite3_value_text, sqlite3_value_type, + sqlite3_column_blob, sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_decltype, + sqlite3_column_double, sqlite3_column_int64, sqlite3_column_name, sqlite3_column_text, + sqlite3_column_type, sqlite3_complete, sqlite3_context, sqlite3_context_db_handle, + sqlite3_create_collation_v2, sqlite3_create_function_v2, sqlite3_create_window_function, + sqlite3_data_count, sqlite3_db_handle, sqlite3_errcode, sqlite3_errmsg, sqlite3_exec, + sqlite3_expanded_sql, sqlite3_extended_errcode, sqlite3_finalize, sqlite3_get_autocommit, + sqlite3_interrupt, sqlite3_last_insert_rowid, sqlite3_libversion, sqlite3_limit, + sqlite3_open_v2, sqlite3_prepare_v2, sqlite3_progress_handler, sqlite3_reset, + sqlite3_result_blob, sqlite3_result_double, sqlite3_result_error, + sqlite3_result_error_nomem, sqlite3_result_error_toobig, sqlite3_result_int64, + sqlite3_result_null, sqlite3_result_text, sqlite3_set_authorizer, sqlite3_sleep, + sqlite3_step, sqlite3_stmt, sqlite3_stmt_busy, sqlite3_stmt_readonly, sqlite3_threadsafe, + sqlite3_total_changes, sqlite3_trace_v2, sqlite3_user_data, sqlite3_value, + sqlite3_value_blob, sqlite3_value_bytes, sqlite3_value_double, sqlite3_value_int64, + sqlite3_value_text, sqlite3_value_type, }; use malachite_bigint::Sign; use rustpython_common::{ @@ -161,7 +161,7 @@ mod _sqlite { #[pyattr] const PARSE_COLNAMES: c_int = 2; #[pyattr] - const LEGACY_TRANSACTION_CONTROL: c_int = 1; + const LEGACY_TRANSACTION_CONTROL: c_int = -1; #[pyattr] use libsqlite3_sys::{ @@ -1519,6 +1519,28 @@ mod _sqlite { #[pygetset(setter)] fn set_autocommit(&self, val: PyObjectRef, vm: &VirtualMachine) -> PyResult<()> { let mode = AutocommitMode::try_from_borrowed_object(vm, &val)?; + let db = self.db_lock(vm)?; + + // Handle transaction state based on mode change + match mode { + AutocommitMode::Enabled => { + // If there's a pending transaction, commit it + if !db.is_autocommit() { + db._exec(b"COMMIT", vm)?; + } + } + AutocommitMode::Disabled => { + // If not in a transaction, begin one + if db.is_autocommit() { + db._exec(b"BEGIN", vm)?; + } + } + AutocommitMode::Legacy => { + // Legacy mode doesn't change transaction state + } + } + + drop(db); *self.autocommit.lock() = mode; Ok(()) } @@ -2016,6 +2038,18 @@ mod _sqlite { impl SelfIter for Cursor {} impl IterNext for Cursor { fn next(zelf: &Py, vm: &VirtualMachine) -> PyResult { + // Check if connection is closed first, and if so, clear statement to release file lock + if zelf.connection.is_closed() { + let mut guard = zelf.inner.lock(); + if let Some(stmt) = guard.as_mut().and_then(|inner| inner.statement.take()) { + stmt.lock().reset(); + } + return Err(new_programming_error( + vm, + "Cannot operate on a closed database.".to_owned(), + )); + } + let mut inner = zelf.inner(vm)?; let Some(stmt) = &inner.statement else { return Ok(PyIterReturn::StopIteration(None)); @@ -2720,9 +2754,17 @@ mod _sqlite { } } + // sqlite3_close_v2 is not exported by libsqlite3-sys, so we declare it manually. + // It handles "zombie close" - if there are still unfinalized statements, + // the database will be closed when the last statement is finalized. + unsafe extern "C" { + fn sqlite3_close_v2(db: *mut sqlite3) -> c_int; + } + impl Drop for Sqlite { fn drop(&mut self) { - unsafe { sqlite3_close(self.raw.db) }; + // Use sqlite3_close_v2 for safe closing even with active statements + unsafe { sqlite3_close_v2(self.raw.db) }; } }