第十六章 操作数据库 本章核心:通过实现一个分层架构的 SQLite 操作框架 ,深入学习 Python 数据库交互的最佳实践。我们将涵盖从数据库连接管理、模型定义、数据访问到业务逻辑的完整流程。
16.1 SQLite 数据库:构建一个分层操作框架 核心特性: 内置、无服务器、单文件数据库。遵循 DB-API 2.0 规范。
常用 API 参考表:
API 描述 sqlite3.connect(database, ...)
连接数据库文件,返回 Connection
对象。 connection.cursor()
创建 Cursor
对象。 cursor.execute(sql, parameters)
执行单条 SQL (用 ?
占位符)。 cursor.executemany(sql, seq_of_params)
批量执行 SQL。 connection.commit()
提交事务。 connection.rollback()
回滚事务。 cursor.fetchone()
获取下一行结果 (元组或 Row 对象),无结果时 None
。 cursor.fetchall()
获取所有剩余行结果 (元组或 Row 对象的列表)。 cursor.lastrowid
(属性) 最后 INSERT 的行的 ROWID。 cursor.rowcount
(属性) 最后 DML 操作影响的行数 (-1 表示不确定或不适用)。 connection.close()
关闭连接。 connection.row_factory = sqlite3.Row
(设置) 让查询结果可以通过列名访问(类似字典)。 sqlite3.Binary(bytes)
用于封装要存入 BLOB 字段的二进制数据。 with sqlite3.connect(...) as conn:
(推荐) 使用上下文管理器,自动处理连接关闭和基本事务。
框架实现步骤
16.1.1 项目结构设置 我们创建以下目录结构,并在每个目录中放入 __init__.py
文件使其成为 Python 包:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 sqlite_practice/ ├── core/ # 核心功能模块 │ ├── __init__.py │ ├── db_manager.py # 数据库连接管理 │ └── table_manager.py # 表结构管理 ├── models/ # 数据模型 │ ├── __init__.py │ └── task.py # 任务数据模型 ├── repositories/ # 数据访问层 │ ├── __init__.py │ └── task_repository.py # 任务数据访问 ├── services/ # 业务逻辑层 │ ├── __init__.py │ └── task_service.py # 任务业务逻辑 ├── utils/ # 工具类 │ ├── __init__.py │ └── date_utils.py # 日期处理工具(被删减了) ├── examples/ # 示例代码 │ ├── __init__.py │ ├── basic_operations.py # 基本操作示例 │ ├── advanced_queries_oop.py # 高级查询示例 │ └── README.md # 项目说明文档
要点:
16.1.2 核心层实现 (core/
) 数据库管理器 (core/db_manager.py
) 此类负责管理数据库连接,提供连接、关闭及上下文管理 (with
语句) 支持。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 """ 任务数据模型模块,定义Task相关的数据结构。 # sqlite_practice/models/task.py """ from dataclasses import dataclass, field, asdictfrom typing import Optional from datetime import datetime@dataclass class Task : """任务数据模型类,使用dataclass简化代码""" title: str priority: int = 3 is_completed: bool = False task_id: Optional [int ] = None due_date: Optional [str ] = None attachment: Optional [str ] = None created_at: Optional [str ] = None description: Optional [str ] = None last_updated: Optional [str ] = None def __post_init__ (self ): """dataclass提供的初始化后自动执行的函数,用于设置默认值""" if self .created_at is None : self .created_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S' ) def to_dict (self ) -> dict : """将对象转换为字典,用于数据库操作 Returns: dict: 包含任务数据的字典 """ task_dict = asdict(self ) task_dict['is_completed' ] = 1 if self .is_completed else 0 if self .task_id is None : task_dict.pop('task_id' ) return task_dict @classmethod def from_row (cls, row ) -> 'Task' : """从数据库行创建Task对象 这个方法的作用是将数据库查询结果(SQLite行数据)转换为Task对象。 看起来复杂是因为它需要处理多种可能的输入格式: 1. sqlite3.Row对象(有keys方法的字典类对象) 2. 元组或列表形式的结果 虽然dataclass简化了类定义,但不能自动处理从外部数据源(如数据库) 创建对象的过程,尤其是当数据需要类型转换时(如整数到布尔值)。 这种复杂性是为了提高代码的健壮性,确保从不同来源的数据都能正确转换为Task对象。 如果确定数据库始终返回同一格式的结果,可以简化此方法。 Args: row: sqlite3.Row对象或类似字典/序列的对象 Returns: Task: 创建的Task对象 """ task_data = {} if hasattr (row, 'keys' ): for key in row.keys(): task_data[key] = row[key] else : try : task_data = { 'task_id' : row[0 ] if len (row) > 0 else None , 'title' : row[1 ] if len (row) > 1 else None , 'description' : row[2 ] if len (row) > 2 else None , 'priority' : row[3 ] if len (row) > 3 else 3 , 'due_date' : row[4 ] if len (row) > 4 else None , 'is_completed' : bool (row[5 ]) if len (row) > 5 else False , 'attachment' : row[6 ] if len (row) > 6 else None , 'created_at' : row[7 ] if len (row) > 7 else None , 'last_updated' : row[8 ] if len (row) > 8 else None } except (IndexError, TypeError): return cls(title="Unknown" ) if 'is_completed' in task_data: task_data['is_completed' ] = bool (task_data['is_completed' ]) return cls(**task_data)
表管理器 (core/table_manager.py
) `
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 """ 表管理模块,负责SQLite数据库表的创建和基础操作。 提供表结构定义、创建、修改等功能。 """ import sqlite3from typing import Dict , Any , List , Optional class TableManager : """表管理类,负责SQLite数据库表的创建和操作""" def __init__ (self, conn: sqlite3.Connection, cursor: sqlite3.Cursor ): """初始化表管理器 Args: conn: 数据库连接对象 cursor: 数据库游标对象 """ self .conn = conn self .cursor = cursor def create_table (self, table_name: str , columns: Dict [str , str ] ) -> bool : """创建数据库表 Args: table_name: 表名 columns: 表列定义字典,格式为 {列名: 列类型} Returns: bool: 表创建是否成功 """ try : if not self .conn or not self .cursor: print (f"[Error] 数据库连接或游标无效" ) return False columns_def = [] for col_name, col_type in columns.items(): columns_def.append(f"{col_name} {col_type} " ) create_table_sql = f'''CREATE TABLE IF NOT EXISTS {table_name} ( {', ' .join(columns_def)} )''' self .cursor.execute(create_table_sql) self .conn.commit() print (f"[Setup] 表 '{table_name} ' 已检查/创建。" ) return True except sqlite3.Error as e: print (f"[Error] 创建表 '{table_name} ' 时出错: {e} " ) return False def add_column (self, table_name: str , column_name: str , column_type: str ) -> bool : """向已有表添加新列 Args: table_name: 表名 column_name: 列名 column_type: 列类型和约束 Returns: bool: 添加列是否成功 """ try : alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type} " self .cursor.execute(alter_sql) self .conn.commit() print (f"[Setup] 已向表 '{table_name} ' 添加列 '{column_name} {column_type} '" ) return True except sqlite3.Error as e: print (f"[Error] 向表 '{table_name} ' 添加列时出错: {e} " ) return False def table_exists (self, table_name: str ) -> bool : """检查表是否存在 Args: table_name: 表名 Returns: bool: 表是否存在 """ try : self .cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name=?" , (table_name,)) return self .cursor.fetchone() is not None except sqlite3.Error as e: print (f"[Error] 检查表 '{table_name} ' 是否存在时出错: {e} " ) return False def get_table_info (self, table_name: str ) -> List [Dict [str , Any ]]: """获取表结构信息 Args: table_name: 表名 Returns: list: 包含列信息的字典列表 """ try : self .cursor.execute(f"PRAGMA table_info({table_name} )" ) columns = [] for row in self .cursor.fetchall(): column_info = { 'cid' : row['cid' ], 'name' : row['name' ], 'type' : row['type' ], 'notnull' : row['notnull' ], 'default_value' : row['dflt_value' ], 'pk' : row['pk' ] } columns.append(column_info) return columns except sqlite3.Error as e: print (f"[Error] 获取表 '{table_name} ' 信息时出错: {e} " ) return []
16.1.3 数据模型层实现 (models/
) 前置知识:使用 @dataclass
简化 Python 类定义 在定义 Task
模型之前,我们先了解一下 Python 3.7+ 引入的一个非常有用的工具:@dataclass
装饰器。它能极大地简化用于存储数据的类的编写。
dataclass
是从 Python 3.7 版本开始,作为标准库 dataclasses
中的模块被引入的。 随着 Python 版本的不断更新,dataclass
也逐步发展和完善,为 Python 开发者提供了更加便捷的数据类创建和管理方式。
dataclass
的主要功能在于帮助我们简化数据类的定义过程 。 本文总结了几个我们在此框架中会用到的 dataclass
技巧。
1. 传统的类定义方式
回顾一下,如果不用 dataclass
,定义一个简单的 CoinTrans
类(包含交易 ID, 交易对, 价格, 是否成功, 地址列表)大致如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 class CoinTransTraditional : def __init__ ( self, id : str , symbol: str , price: float , is_success: bool , addrs: list , ) -> None : self .id = id self .symbol = symbol self .price = price self .is_success = is_success self .addrs = addrs def __repr__ (self ) -> str : return (f"CoinTransTraditional(id='{self.id } ', symbol='{self.symbol} ', " f"price={self.price} , is_success={self.is_success} , addrs={self.addrs} )" ) if __name__ == "__main__" : coin_trans_trad = CoinTransTraditional("id01" , "BTC/USDT" , 71000.0 , True , ["0x1111" , "0x2222" ]) print (coin_trans_trad)
如你所见,我们需要编写 __init__
方法来初始化所有属性,还需要编写 __repr__
(或 __str__
) 方法才能方便地打印对象内容。
2. 使用 @dataclass
装饰器定义类
现在看看用 @dataclass
有多简单:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 from dataclasses import dataclass, field from typing import List @dataclass class CoinTransDataclass : id : str symbol: str price: float is_success: bool addrs: List [str ] if __name__ == "__main__" : coin_trans_dc = CoinTransDataclass("id01" , "BTC/USDT" , 71000.0 , True , ["0x1111" , "0x2222" ]) print (coin_trans_dc)
运行结果会直接打印出易于阅读的对象表示:
1 CoinTransDataclass(id='id01', symbol='BTC/USDT', price=71000.0, is_success=True, addrs=['0x1111', '0x2222'])
关键优势:
自动生成方法: @dataclass
会自动为你生成 __init__
、__repr__
、__eq__
(等值比较) 等常用方法,大大减少样板代码。类型提示: 它强制(或强烈推荐)使用类型提示,有助于代码清晰度和静态分析。2.1 默认值与 default_factory
设置默认值很简单,直接在属性后面赋值即可。但对于可变类型 (如 list
, dict
)作为默认值,直接赋值会引发 ValueError
,因为所有实例会共享同一个可变对象,这通常不是我们想要的。需要使用 field
函数和 default_factory
来指定一个工厂函数 (一个无参可调用对象,返回默认值)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 from dataclasses import dataclass, fieldfrom typing import List def default_addr_list () -> List [str ]: print (" (调用了 default_addr_list 工厂函数)" ) return ["0xdefault1" , "0xdefault2" ] @dataclass class CoinTransWithDefaults : id : str = "default_id" symbol: str = "BTC/USDT" price: float = 0.0 is_success: bool = False addrs: List [str ] = field(default_factory=default_addr_list) if __name__ == "__main__" : print ("\n--- Dataclass 默认值示例 ---" ) default_trans = CoinTransWithDefaults() print ("创建第一个实例:" ) print (default_trans) print ("创建第二个实例:" ) another_default = CoinTransWithDefaults() print (another_default) default_trans.addrs.append("0xadded" ) print ("修改第一个实例的 addrs 后:" ) print (f" 实例1: {default_trans} " ) print (f" 实例2: {another_default} " )
2.2 隐藏敏感信息 (repr=False
)
如果你不希望某些字段出现在 print()
(即 __repr__
) 的输出中(例如密码、密钥等),可以在 field
中设置 repr=False
。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 @dataclass class SensitiveData : username: str session_id: str token: str = field(repr =False ) if __name__ == "__main__" : print ("\n--- Dataclass 隐藏字段示例 ---" ) data = SensitiveData("user1" , "sess123" , token="secret_token_value" ) print (data) print (f" 访问隐藏的 token: {data.token} " )
2.3 只读对象 (frozen=True
)
如果你希望创建的对象是不可变 的(创建后其属性值不能被修改),可以在 @dataclass
装饰器中设置 frozen=True
。这对于表示常量数据或确保数据不被意外篡改很有用。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 @dataclass(frozen=True ) class ImmutablePoint : x: float y: float if __name__ == "__main__" : print ("\n--- Dataclass 只读对象示例 ---" ) p = ImmutablePoint(1.0 , 2.0 ) print (p) try : p.x = 5.0 except Exception as e: print (f" 尝试修改只读对象失败: {type (e).__name__} : {e} " )
2.4 转换为元组和字典 (astuple
, asdict
)
dataclasses
模块提供了 astuple
和 asdict
函数,可以方便地将 dataclass 实例转换为元组或字典,这在与其他库或模块交互时非常方便。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 from dataclasses import astuple, asdict, dataclass@dataclass class SimpleConfig : host: str = "localhost" port: int = 8080 debug_mode: bool = False if __name__ == "__main__" : print ("\n--- Dataclass 转换示例 ---" ) config = SimpleConfig(port=9000 ) config_tuple = astuple(config) print (f" 转换为元组: {config_tuple} " ) config_dict = asdict(config) print (f" 转换为字典: {config_dict} " )
总结 (@dataclass
)
在 Python 中,数据类主要用于存储数据。 定义数据类时,通常需要编写一些重复性的代码,如构造函数 (__init__
)、字符串表示 (__repr__
, __str__
) 等。 @dataclass
装饰器的出现,使得这些通用方法的生成变得自动化,从而极大地简化了数据类的定义过程,提高了开发效率,是我们在数据建模时非常有用的工具。
现在,我们将使用 @dataclass
来定义我们的 Task
模型。
任务模型 (models/task.py
- 使用 @dataclass) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 """ 任务数据模型模块,定义Task相关的数据结构。 # sqlite_practice/models/task.py """ from dataclasses import dataclass, fieldfrom typing import Optional from datetime import datetime@dataclass class Task : """任务数据模型类,使用dataclass简化代码""" title: str description: Optional [str ] = None priority: int = 3 due_date: Optional [str ] = None is_completed: bool = False attachment: Optional [bytes ] = None created_at: Optional [str ] = None last_updated: Optional [str ] = None task_id: Optional [int ] = None def __post_init__ (self ): """对象初始化后运行的方法,用于设置默认值""" if self .created_at is None : self .created_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S" ) def to_dict (self ) -> dict : """将对象转换为字典,用于数据库操作 Returns: dict: 包含任务数据的字典 """ task_dict = { 'title' : self .title, 'description' : self .description, 'priority' : self .priority, 'due_date' : self .due_date, 'is_completed' : 1 if self .is_completed else 0 , 'attachment' : self .attachment, 'created_at' : self .created_at, 'last_updated' : self .last_updated } if self .task_id is not None : task_dict['task_id' ] = self .task_id return task_dict @classmethod def from_row (cls, row ) -> 'Task' : """从数据库行创建Task对象 Args: row: sqlite3.Row对象或类似字典的对象 Returns: Task: 创建的Task对象 """ task_data = {} if hasattr (row, 'keys' ): for key in row.keys(): task_data[key] = row[key] else : try : task_data = { 'task_id' : row[0 ] if len (row) > 0 else None , 'title' : row[1 ] if len (row) > 1 else None , 'description' : row[2 ] if len (row) > 2 else None , 'priority' : row[3 ] if len (row) > 3 else 3 , 'due_date' : row[4 ] if len (row) > 4 else None , 'is_completed' : bool (row[5 ]) if len (row) > 5 else False , 'attachment' : row[6 ] if len (row) > 6 else None , 'created_at' : row[7 ] if len (row) > 7 else None , 'last_updated' : row[8 ] if len (row) > 8 else None } except (IndexError, TypeError): return cls(title="Unknown" ) if 'is_completed' in task_data: task_data['is_completed' ] = bool (task_data['is_completed' ]) return cls(**task_data)
知识点解释 (@dataclass
版 Task
):
@dataclass
装饰器: 应用在类定义之前,自动添加 __init__
, __repr__
, __eq__
等方法。属性定义与类型提示: 直接在类级别声明属性及其类型 (title: str
, priority: int = 3
)。Optional[T]
表示该属性可以是 T
类型或 None
。默认值: 直接在属性声明后赋值即可为非可变类型设置默认值 (priority: int = 3
)。field()
函数 (可选): 用于更精细地控制字段行为,如设置 default_factory
(用于可变默认值,本例未使用但已在知识点中介绍), repr=False
(不在打印输出中显示), init=False
(不由 __init__
初始化), compare=False
(不在等值比较中考虑) 等。to_dict()
方法: 方便将对象状态转换为字典,特别用于准备插入或更新到数据库的数据。注意这里手动处理了 is_completed
从 bool
到 int
的转换,以匹配 SQLite 的存储方式。@classmethod from_row(cls, row)
: 这是一个类方法 (第一个参数是类本身 cls
,而不是实例 self
),专门用于从数据库返回的行数据(这里是 sqlite3.Row
对象)创建 Task
类的实例。这是将数据库数据映射回对象的常用模式。它负责处理列名到属性的映射和必要的数据类型转换(如 bool(row['is_completed'])
)。增加了错误处理,确保输入有效。16.1.4 数据访问层实现 (repositories/
) 任务仓库 (repositories/task_repository.py
) 此类封装所有与 tasks
表相关的 SQL 操作,并使用 Task
dataclass 进行数据交互。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 """ 任务数据访问层模块,实现对Task数据的CRUD操作。 """ import sqlite3from typing import List , Optional , Dict , Any , Tuple from models.task import Taskfrom core.table_manager import TableManagerclass TaskRepository : """任务数据访问类,实现对Task表的增删改查操作""" def __init__ (self, conn: sqlite3.Connection, cursor: sqlite3.Cursor ): self .conn = conn self .cursor = cursor self .table_name = 'tasks' self .table_manager = TableManager(self .conn, self .cursor) def create_table (self ) -> bool : """创建任务表""" columns = { "task_id" : "INTEGER PRIMARY KEY AUTOINCREMENT" , "title" : "TEXT NOT NULL" , "description" : "TEXT" , "priority" : "INTEGER DEFAULT 3" , "due_date" : "DATE" , "is_completed" : "BOOLEAN DEFAULT 0" , "attachment" : "BLOB" , "created_at" : "TIMESTAMP DEFAULT CURRENT_TIMESTAMP" , "last_updated" : "TIMESTAMP" } created_flag = self .table_manager.create_table(self .table_name, columns) return created_flag def add_column (self, column_name: str , column_type: str ) -> bool : """向任务表添加新列 Args: column_name: 新列名 column_type: 新列类型 """ added_flag = self .table_manager.add_column(self .table_name, column_name, column_type) return added_flag def get_table_info (self ) -> List [Dict [str , Any ]]: """获取任务表结构信息 Returns: list: 包含列信息的字典列表 """ table_info = self .table_manager.get_table_info(self .table_name) return table_info def insert_task (self, task: Task ) -> Optional [int ]: """插入新任务 Args: task: Task对象 Returns: int: 新任务的ID,如果插入失败返回None """ try : task_dict = task.to_dict() if 'task_id' in task_dict: del task_dict['task_id' ] columns = ', ' .join(task_dict.keys()) """ placeholders的执行逻辑:类似于下方的代码 task_dict = {'task1': 'value1', 'task2': 'value2', 'task3': 'value3'} question_marks = ', '.join(['?' for _ in task_dict]) print(question_marks) # 输出: ?, ?, ? '""" placeholders = placeholders = ', ' .join(['?' for _ in task_dict]) values = tuple (task_dict.values()) insert_sql = f"INSERT INTO {self.table_name} ({columns} ) VALUES ({placeholders} )" self .cursor.execute(insert_sql, values) self .conn.commit() inserted_id = self .cursor.lastrowid print (f"[Info] 成功插入任务: '{task.title} ', ID: {inserted_id} " ) return inserted_id except sqlite3.Error as e: print (f"[Error] 任务插入失败: {e} " ) self .conn.rollback() return None def insert_many (self, tasks: List [Task] ) -> Optional [int ]: """批量插入任务记录 Args: tasks: 要插入的Task对象列表 Returns: int: 成功插入的记录数,如果插入失败则返回None """ try : if not tasks: print ("[Warning] 没有任务要插入" ) return 0 task_dicts = [task.to_dict() for task in tasks] for task_dict in task_dicts: if 'task_id' in task_dict: del task_dict['task_id' ] columns = list (task_dicts[0 ].keys()) placeholders = ', ' .join(['?' for _ in columns]) insert_sql = f"INSERT INTO {self.table_name} ({', ' .join(columns)} ) VALUES ({placeholders} )" values = [tuple (task_dict[col] for col in columns) for task_dict in task_dicts] print (f"[Info] 准备批量插入{len (values)} 条任务" ) self .cursor.executemany(insert_sql, values) self .conn.commit() inserted_count = self .cursor.rowcount print (f"[Info] 成功插入{inserted_count} 条任务" ) return inserted_count except sqlite3.Error as e: print (f"[Error] 任务插入失败: {e} " ) self .conn.rollback() return None def update_task (self, task: Task ) -> bool : """更新任务记录 Args: task: 要更新的Task对象,必须包含task_id Returns: bool: 是否更新成功 """ try : if task.task_id is None : print ("[Error] 无法更新任务:缺少task_id" ) return False task_dict = task.to_dict() task_id = task_dict.pop('task_id' ) """ 这将生成类似于下面的SQL语句: UPDATE tasks SET title = 'task1', description = 'task1 description', priority = 1, due_date = '2022-01-01', is_completed = 1, attachment = b'123456', last_updated = '2022-01-01 00:00:00' WHERE task_id = 1 """ set_clause = ', ' .join([f"{col} = ?" for col in task_dict.keys()]) if 'last_updated' not in task_dict: set_clause += ", last_updated = CURRENT_TIMESTAMP" update_sql = f"UPDATE {self.table_name} SET {set_clause} WHERE task_id = ?" values = list (task_dict.values()) values.append(task_id) self .cursor.execute(update_sql, values) self .conn.commit() updated_count = self .cursor.rowcount if updated_count > 0 : print (f"[Info] 成功更新任务 ID={task_id} " ) return True else : print (f"[Warning] 未找到要更新的任务 ID={task_id} " ) return False except sqlite3.Error as e: print (f"[Error] 任务更新失败: {e} " ) self .conn.rollback() return False def delete_task (self, task_id: int ) -> bool : """删除任务记录 Args: task_id: 要删除的任务ID Returns: bool: 是否删除成功 """ try : delete_sql = f"DELETE FROM {self.table_name} WHERE task_id = ?" self .cursor.execute(delete_sql, (task_id,)) deleted_count = self .cursor.rowcount if deleted_count > 0 : print (f"[Info] 成功删除任务 ID={task_id} " ) return True else : print (f"[Warning] 未找到要删除的任务 ID={task_id} " ) return False except sqlite3.Error as e: print (f"[Error] 任务删除失败: {e} " ) self .conn.rollback() return False def find_task_by_id (self, task_id: int ) -> Optional [Task]: """根据任务ID查找任务记录""" try : select_sql = f"SELECT * FROM {self.table_name} WHERE task_id = ?" self .cursor.execute(select_sql, (task_id,)) row = self .cursor.fetchone() if row is None : return None return Task.from_row(row) except sqlite3.Error as e: print (f"[Error] 任务查找失败: {e} " ) return None def find_all_tasks (self ) -> List [Task]: """查找所有任务记录""" try : select_sql = f"SELECT * FROM {self.table_name} " self .cursor.execute(select_sql) rows = self .cursor.fetchall() return [Task.from_row(row) for row in rows] except sqlite3.Error as e: print (f"[Error] 任务查找失败: {e} " ) return [] def find_by_criteria (self, criteria: Dict [str , Any ] ) -> List [Task]: """根据条件查找任务 Args: criteria: 查询条件字典,格式为 {列名: 值} Returns: list: 符合条件的Task对象列表 """ try : if not criteria: return self .find_all_tasks() values = [] where_clauses = [] for col, val in criteria.items(): where_clauses.append(f"{col} = ?" ) values.append(val) where_clause = ' AND ' .join(where_clauses) select_sql = f"SELECT * FROM {self.table_name} WHERE {where_clause} " self .cursor.execute(select_sql, values) rows = self .cursor.fetchall() return [Task.from_row(row) for row in rows] except sqlite3.Error as e: print (f"[Error] 任务查找失败: {e} " ) return [] def find_by_title_contains (self, title_part: str ) -> List [Task]: """根据标题部分查找任务""" try : select_sql = f"SELECT * FROM {self.table_name} WHERE title LIKE ?" self .cursor.execute(select_sql, (f'%{title_part} %' ,)) rows = self .cursor.fetchall() return [Task.from_row(row) for row in rows] except sqlite3.Error as e: print (f"[Error] 任务查找失败: {e} " ) return []
知识点解释 (Repository - 适配 Dataclass):
与 Dataclass 协作: insert
和 update
方法现在接收 Task
dataclass 实例。它们使用 task.to_dict()
来获取需要持久化的数据,而不是手动从对象属性中提取。from_row
的使用: find_all
, find_by_id
, find_by_criteria
等查询方法现在依赖 Task.from_row(row)
这个类方法来将数据库行数据(sqlite3.Row
对象)转换回 Task
dataclass 实例。这使得模型转换逻辑集中在模型类自身。布尔值处理: 在 insert
/update
时将 Python bool
转换为数据库的 INTEGER
(0/1),在 from_row
时将数据库的 INTEGER
转换回 bool
。动态 SQL 构建 (需谨慎): insert
和 update
方法中动态构建了 SQL 语句的列名和占位符部分。这样做可以使代码适应模型的变化,但必须极其小心 ,确保列名不是 来自用户输入,以防范 SQL 注入。这里列名来自 task.to_dict().keys()
,是安全的。错误处理: 继续保持了对 sqlite3.Error
和 sqlite3.IntegrityError
的捕获以及事务回滚。查询方法在出错时返回空列表 []
或 None
。create_table
职责: 将创建表的逻辑放在 Repository 初始化时执行,确保操作 Repository 前表结构一定存在,这是一种常见的实践模式。16.1.5 业务逻辑层实现 (services/
) 任务服务 (services/task_service.py
) 服务层的代码通常与底层数据模型是 dataclass 还是普通类关系不大,因为它主要通过 Repository 提供的接口(接收和返回 Task
对象)来工作。因此,之前的 TaskService
代码基本可以直接使用,或者做少量调整以利用 dataclass 的特性(如果需要的话)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 """ 任务服务模块,提供业务逻辑层功能。 负责处理任务相关的业务规则和操作。 # sqlite_practice/services/task_service.py """ from typing import List , Optional , Dict , Any from datetime import datetime, timedeltafrom models.task import Taskfrom repositories.task_repository import TaskRepositoryclass TaskService : """任务服务类,处理任务相关的业务逻辑""" def __init__ (self, task_repository: TaskRepository ): """初始化任务服务 Args: task_repository: 任务数据访问对象 """ self .task_repository = task_repository def create_task (self, title: str , description: str = None , priority: int = 3 , due_date: str = None , is_completed: bool = False , attachment: bytes = None ) -> Optional [int ]: """创建新任务 Args: title: 任务标题 description: 任务描述 priority: 优先级 (1-5),默认为3 due_date: 截止日期,格式为'YYYY-MM-DD' is_completed: 是否已完成 attachment: 附件数据 Returns: int: 新创建任务的ID,失败则返回None """ if priority < 1 or priority > 5 : print (f"[Warning] 优先级 {priority} 超出范围 (1-5),将使用默认值 3" ) priority = 3 if due_date: try : datetime.strptime(due_date, '%Y-%m-%d' ) except ValueError: print (f"[Warning] 日期格式 '{due_date} ' 无效,应为 'YYYY-MM-DD',将设为空" ) due_date = None if not title: print ("[Error] 任务标题不能为空" ) return None task = Task( title=title, description=description, priority=priority, due_date=due_date, is_completed=is_completed, attachment=attachment ) return self .task_repository.insert_task(task) def create_tasks_batch (self, task_data_list: List [Dict [str , Any ]] ) -> Optional [int ]: """批量创建任务 Args: task_data_list: 包含任务数据的字典列表 Returns: int: 成功创建的任务数量,失败则返回None """ if not task_data_list: print ("[Warning] 没有任务数据需要创建" ) return 0 tasks = [] for task_data in task_data_list: if 'title' not in task_data or not task_data['title' ]: print ("[Warning] 跳过缺少标题的任务" ) continue task = Task( title=task_data['title' ], description=task_data.get('description' ), priority=task_data.get('priority' , 3 ), due_date=task_data.get('due_date' ), is_completed=task_data.get('is_completed' , False ), attachment=task_data.get('attachment' ) ) tasks.append(task) if not tasks: print ("[Warning] 没有有效的任务需要创建" ) return 0 return self .task_repository.insert_many(tasks) def update_task (self, task_id: int , **kwargs ) -> bool : """更新任务 Args: task_id: 要更新的任务ID **kwargs: 要更新的字段和值 Returns: bool: 更新是否成功 """ task = self .task_repository.find_task_by_id(task_id) if not task: print (f"[Error] 未找到要更新的任务 ID={task_id} " ) return False for key, value in kwargs.items(): if hasattr (task, key): setattr (task, key, value) else : print (f"[Warning] 忽略未知字段 '{key} '" ) return self .task_repository.update_task(task) def complete_task (self, task_id: int ) -> bool : """将任务标记为已完成 Args: task_id: 要标记的任务ID Returns: bool: 操作是否成功 """ return self .update_task(task_id, is_completed=True ) def delete_task (self, task_id: int ) -> bool : """删除任务 Args: task_id: 要删除的任务ID Returns: bool: 删除是否成功 """ return self .task_repository.delete_task(task_id) def get_task (self, task_id: int ) -> Optional [Task]: """获取指定ID的任务 Args: task_id: 任务ID Returns: Task: 任务对象,未找到则返回None """ return self .task_repository.find_task_by_id(task_id) def get_all_tasks (self ) -> List [Task]: """获取所有任务 Returns: list: 任务对象列表 """ return self .task_repository.find_all_tasks() def get_tasks_by_priority (self, priority: int ) -> List [Task]: """获取指定优先级的任务 Args: priority: 优先级 Returns: list: 任务对象列表 """ return self .task_repository.find_by_criteria({'priority' : priority}) def get_incomplete_tasks (self ) -> List [Task]: """获取未完成的任务 Returns: list: 任务对象列表 """ return self .task_repository.find_by_criteria({'is_completed' : 0 }) def get_overdue_tasks (self ) -> List [Task]: """获取已逾期的任务 Returns: list: 任务对象列表 """ today = datetime.now().strftime('%Y-%m-%d' ) all_tasks = self .task_repository.find_all_tasks() overdue_tasks = [] for task in all_tasks: if (task.due_date and task.due_date < today and not task.is_completed): overdue_tasks.append(task) return overdue_tasks def search_tasks_by_title (self, title_part: str ) -> List [Task]: """搜索标题包含指定字符串的任务 Args: title_part: 标题中要搜索的文本 Returns: list: 符合条件的任务对象列表 """ return self .task_repository.find_by_title_contains(title_part) def get_tasks_due_within_days (self, days: int ) -> List [Task]: """获取指定天数内到期的任务 Args: days: 天数 Returns: list: 任务对象列表 """ today = datetime.now() end_date = (today + timedelta(days=days)).strftime('%Y-%m-%d' ) today_str = today.strftime('%Y-%m-%d' ) all_tasks = self .task_repository.find_all_tasks() due_tasks = [] for task in all_tasks: if (task.due_date and today_str <= task.due_date <= end_date): due_tasks.append(task) return due_tasks
知识点解释 (Service Layer):
适配 Dataclass: 由于 Repository 返回的是 Task
dataclass 实例,Service 层的方法(如 complete_task
, update_task_details
)可以直接操作这些对象的属性,代码更简洁。业务规则体现: create_task
中增加了对 priority
和 due_date
格式的校验。complete_task
增加了对任务是否已完成的检查,避免重复操作。update_task_details
增加了对不可修改字段 (task_id
, created_at
) 的保护,并对传入的字段和值进行基本检查。get_tasks_due_within_days
中日期的解析和比较逻辑清晰地放在服务层。调用 Repository: 所有与数据库的交互都委托给 self.task_repository
。16.1.7 使用示例 (examples/
) 现在我们已经构建了框架的各个组件(DatabaseManager
, Task
模型, TaskRepository
, TaskService
),接下来将通过几个具体的示例脚本来演示如何使用这个框架完成不同的数据库操作任务。
注意: 运行这些示例脚本时,请确保:
Python 环境中已经安装了必要的库(虽然 sqlite3
是内置的,但未来章节可能需要 pymysql
, sqlalchemy
等)。 脚本能够正确导入我们之前创建的模块(core
, models
, repositories
, services
, utils
)。这通常意味着你需要从 sqlite_practice
这个项目的根目录 来运行这些示例脚本,或者确保 sqlite_practice
目录位于 Python 的模块搜索路径 (sys.path
) 中。示例代码中会包含尝试处理路径的代码。 每个示例脚本可能会创建自己的 SQLite 数据库文件(如 example_sqlite.db
, advanced_queries.db
, transaction_example.db
),这些文件会出现在脚本运行的目录下或指定的 data
子目录中。 16.1.7.1 基本操作示例 (examples/basic_operations.py
) 这个脚本演示了最核心的 CRUD (创建、读取、更新、删除) 操作,展示了如何通过 TaskRepository
来与数据库交互。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 """ 基本数据库操作示例模块,展示SQLite数据库的基本CRUD操作。 包括连接、创建表、增删改查等基本功能。 """ import osimport sqlite3from typing import List , Optional from core.db_manager import DatabaseManagerfrom core.table_manager import TableManagerfrom models.task import Taskfrom repositories.task_repository import TaskRepositoryclass SQLiteExample : """SQLite基本操作示例类,实现单例模式""" _instance = None def __new__ (cls, *args, **kwargs ): """确保类只有一个实例""" if cls._instance is None : cls._instance = super (SQLiteExample, cls).__new__(cls) cls._instance._initialized = False return cls._instance def __init__ (self, db_file: str = "sqlite_practice.db" ): """初始化SQLite示例类""" if self ._initialized: return self .db_file = db_file self .db_manager = DatabaseManager(self .db_file) self .conn, self .cursor = self .db_manager.connect() self .task_repository = TaskRepository(self .conn, self .cursor) self .init_db() self ._initialized = True def init_db (self ) -> bool : """初始化数据库和表结构 Returns: bool: 初始化是否成功 """ try : if os.path.exists(self .db_file): self .clean_db() return self .task_repository.create_table() except Exception as e: print (f"[Error] 初始化数据库失败: {e} " ) return False def clean_db (self ) -> bool : """清理数据库文件 Returns: bool: 清理是否成功 """ return self .db_manager.clean_database() def insert_sample_tasks (self ) -> int : """插入示例任务数据 Returns: int: 成功插入的任务数量 """ print ("\n--- 插入示例任务数据 ---" ) tasks = [ Task(title="任务1" , description="这是任务1的描述" , priority=1 ), Task(title="任务2" , description="这是任务2的描述" , priority=2 ), Task(title="任务3" , description="这是任务3的描述" , priority=3 ), Task(title="任务4" , description="这是任务4的描述" , priority=4 ), Task(title="任务5" , description="这是任务5的描述" , priority=5 ), ] return self .task_repository.insert_many(tasks) def insert_single_task (self, task: Task ) -> Optional [int ]: """插入单个任务 Args: task: 要插入的任务对象 Returns: Optional[int]: 插入的任务ID,失败则返回None """ return self .task_repository.insert_task(task) def find_all_tasks (self ) -> List [Task]: """查询所有任务 Returns: List[Task]: 任务列表 """ return self .task_repository.find_all_tasks() def find_task_by_id (self, task_id: int ) -> Optional [Task]: """根据ID查询任务 Args: task_id: 任务ID Returns: Optional[Task]: 任务对象,未找到则返回None """ return self .task_repository.find_task_by_id(task_id) def find_tasks_by_criteria (self, criteria: dict ) -> List [Task]: """根据条件查询任务 Args: criteria: 查询条件字典 Returns: List[Task]: 符合条件的任务列表 """ return self .task_repository.find_by_criteria(criteria) def find_tasks_by_title (self, title_part: str ) -> List [Task]: """根据标题模糊查询任务 Args: title_part: 标题包含的文本 Returns: List[Task]: 符合条件的任务列表 """ return self .task_repository.find_by_title_contains(title_part) def update_task (self, task: Task ) -> bool : """更新任务 Args: task: 要更新的任务对象 Returns: bool: 更新是否成功 """ return self .task_repository.update_task(task) def delete_task (self, task_id: int ) -> bool : """删除任务 Args: task_id: 要删除的任务ID Returns: bool: 删除是否成功 """ return self .task_repository.delete_task(task_id) def close (self ) -> None : """关闭数据库连接""" self .db_manager.close() if __name__ == '__main__' : print ('\033[1;36m===== 初始化 SQLite 示例 =====\033[0m' ) sqlite_example = SQLiteExample() print ('\033[1;32m===== 插入示例任务数据 =====\033[0m' ) sqlite_example.insert_sample_tasks() print ('\033[1;34m===== 查询所有任务 =====\033[0m' ) tasks = sqlite_example.find_all_tasks() for task in tasks: print ('\033[0;37m' + str (task.to_dict()) + '\033[0m' ) print ('\033[1;34m===== 查询单个任务 (ID=1) =====\033[0m' ) task = sqlite_example.find_task_by_id(1 ) print ('\033[0;33m' + str (task.to_dict()) + '\033[0m' ) print ('\033[1;34m===== 查询高优先级任务 (priority=3) =====\033[0m' ) tasks = sqlite_example.find_tasks_by_criteria({'priority' : 3 }) for task in tasks: print ('\033[0;35m' + str (task.to_dict()) + '\033[0m' ) print ('\033[1;34m===== 查询标题包含"任务"的任务 =====\033[0m' ) tasks = sqlite_example.find_tasks_by_title('任务' ) for task in tasks: print ('\033[0;36m' + str (task.to_dict()) + '\033[0m' ) print ('\033[1;33m===== 更新任务 (ID=1) =====\033[0m' ) task = Task(task_id=1 , title="任务1-更新" , description="这是任务1的更新描述" , priority=1 , is_completed=True ) sqlite_example.update_task(task) print ('\033[1;34m===== 查询更新后的任务 (ID=1) =====\033[0m' ) task = sqlite_example.find_task_by_id(1 ) print ('\033[0;32m' + str (task.to_dict()) + '\033[0m' ) print ('\033[1;31m===== 删除任务 (ID=1) =====\033[0m' ) sqlite_example.delete_task(1 ) print ('\033[1;34m===== 查询删除后的任务 (ID=1) =====\033[0m' ) task = sqlite_example.find_task_by_id(1 ) print ('\033[0;31m' + str (task) + '\033[0m' ) print ('\033[1;36m===== 关闭数据库连接 =====\033[0m' ) sqlite_example.close()
代码解释 (basic_operations.py):
目的: 专注于演示框架的基本 CRUD 功能。流程: 按照“连接 -> 初始化 Repo -> 创建表 -> 插入 -> 查询 -> 更新 -> 查询 -> 删除 -> 查询 -> 关闭”的逻辑顺序进行。交互: 主要通过 TaskRepository
实例 (task_repo
) 的方法 (insert
, insert_many
, find_all
, find_by_id
, update
, find_by_criteria
, find_by_title_contains
, delete
) 来完成数据库操作。模型使用: 创建 Task
dataclass 对象用于插入和更新,接收 Task
对象列表作为查询结果。连接管理: 示例中使用了手动的 db_manager.connect()
和 db_manager.close()
来展示 Repository 如何接收 conn
和 cursor
。在实际应用或更复杂的示例(如下面的事务示例)中,使用 with
语句通常是更好的选择。错误处理: 在主流程外层添加了 try...except...finally
来捕获可能的数据库错误或其他异常,并确保连接最终被关闭。16.1.7.2 高级查询示例 (examples/advanced_queries.py
) 此脚本展示了更复杂的查询场景,包括直接执行 SQL 和通过框架的 Service/Repository 层进行查询。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 """ 高级查询示例模块 (面向对象重构版)。 使用类封装演示流程,遵循面向对象原则,使用彩色输出。 展示SQLite数据库的高级查询功能,通过调用分层框架实现。 """ import sysimport osimport sqlite3from datetime import datetime, timedelta from typing import Optional , Any , List from core.db_manager import DatabaseManagerfrom models.task import Task from repositories.task_repository import TaskRepositoryfrom services.task_service import TaskServiceclass Colors : """定义 ANSI 转义码常量用于彩色输出""" HEADER = '\033[95m' BLUE = '\033[94m' CYAN = '\033[96m' GREEN = '\033[92m' WARNING = '\033[93m' FAIL = '\033[91m' BOLD = '\033[1m' UNDERLINE = '\033[4m' END = '\033[0m' def print_header (text: str ): """打印带样式的标题""" print (f"\n{Colors.HEADER} {Colors.BOLD} --- {text} ---{Colors.END} " ) def print_subheader (text: str ): """打印带样式的子标题""" print (f"\n{Colors.CYAN} {Colors.UNDERLINE} {text} {Colors.END} " ) def print_info (text: str ): """打印普通信息""" print (f" {text} " ) def print_success (text: str ): """打印成功信息""" print (f"{Colors.GREEN} ✔ {text} {Colors.END} " ) def print_warning (text: str ): """打印警告信息""" print (f"{Colors.WARNING} ⚠️ [Warning] {text} {Colors.END} " ) def print_error (text: str ): """打印错误信息""" print (f"{Colors.FAIL} ❌ [Error] {text} {Colors.END} " ) def print_sql (sql: str ): """打印格式化的 SQL 语句""" print (f"{Colors.BLUE} SQL: {sql.strip()} {Colors.END} " ) def print_result_item (item: Any , indent: int = 4 ): """打印格式化的查询结果项""" prefix = " " * indent if isinstance (item, Task): print (f"{prefix} {item} " ) elif hasattr (item, 'keys' ): details = ", " .join([f"{Colors.BOLD} {key} {Colors.END} : {item[key]} " for key in item.keys()]) print (f"{prefix} Row({details} )" ) else : print (f"{prefix} {item} " ) class AdvancedQueryDemoOO : """ 面向对象的 SQLite 高级查询演示类。 封装了演示的设置、执行和清理逻辑。 """ DB_FILE = "advanced_queries_oop.db" def __init__ (self ): """初始化演示类实例""" print_header("初始化高级查询演示 (OOP)" ) self .db_manager = DatabaseManager(self .DB_FILE) self .conn: Optional [sqlite3.Connection] = None self .cursor: Optional [sqlite3.Cursor] = None self .task_repo: Optional [TaskRepository] = None self .task_service: Optional [TaskService] = None self ._cleanup_db() def _cleanup_db (self ): """(私有方法) 清理旧的数据库文件""" if os.path.exists(self .DB_FILE): try : os.remove(self .DB_FILE) print_info(f"旧数据库文件 '{self.DB_FILE} ' 已清理。" ) except OSError as e: print_error(f"清理数据库文件 '{self.DB_FILE} ' 失败: {e} " ) def _setup_database_and_framework (self ): """(私有方法) 建立数据库连接,初始化框架组件,并填充测试数据""" print_subheader("1. 设置数据库和框架组件" ) self .conn, self .cursor = self .db_manager.connect() if not self .conn or not self .cursor: raise ConnectionError("未能获取数据库连接或游标" ) self .task_repo = TaskRepository(self .conn, self .cursor) self .task_service = TaskService(self .task_repo) print_success("Repository 和 Service 初始化成功。" ) self .task_repo.create_table() if not self ._populate_test_data(): raise RuntimeError("填充测试数据失败。" ) def get_future_date (self, days: int , format_str: str = '%Y-%m-%d' ) -> str : """获取未来日期字符串 Args: days: 向后推的天数 format_str: 日期格式,默认为 'YYYY-MM-DD' Returns: str: 未来日期字符串 """ future_date = datetime.now() + timedelta(days=days) return future_date.strftime(format_str) def _populate_test_data (self ) -> bool : """(私有方法) 向数据库填充测试数据""" if not self .task_repo: return False print_subheader("2. 填充测试数据" ) today = datetime.now().strftime('%Y-%m-%d' ) tomorrow = self .get_future_date(1 ) next_week = self .get_future_date(7 ) next_month = self .get_future_date(30 ) yesterday = self .get_future_date(-1 ) test_tasks = [ Task(title="完成项目设计文档" , description="项目需求和架构设计" , priority=1 , due_date=tomorrow), Task(title="修复关键Bug" , description="修复用户报告的登录问题" , priority=1 , due_date=today), Task(title="制定项目计划" , description="项目时间线和里程碑规划" , priority=1 , due_date=next_week), Task(title="紧急会议准备" , priority=1 , due_date=today, is_completed=False ), Task(title="代码重构" , description="重构认证模块的代码" , priority=2 , due_date=next_week), Task(title="单元测试编写" , description="为核心功能编写单元测试" , priority=2 , due_date=next_week), Task(title="API文档更新" , description="更新REST API文档" , priority=2 , due_date=next_month), Task(title="中期报告撰写" , priority=2 , due_date=self .get_future_date(10 )), Task(title="性能优化" , description="优化数据库查询性能" , priority=3 , due_date=next_month), Task(title="学习新技术" , description="学习 GraphQL 技术" , priority=3 , due_date=None ), Task(title="代码审查" , description="审查团队成员的代码" , priority=3 , due_date=next_week), Task(title="环境搭建" , description="搭建开发环境" , priority=1 , is_completed=True , due_date=yesterday), Task(title="需求分析" , description="分析用户需求" , priority=2 , is_completed=True , due_date=today), Task(title="旧项目归档" , priority=3 , is_completed=True , due_date=None ) ] inserted_count = self .task_repo.insert_many(test_tasks) if inserted_count is not None : print_success( f"已尝试插入 {len (test_tasks)} 条测试数据,成功 {inserted_count if inserted_count != -1 else len (test_tasks)} 条。" ) return True else : print_error("填充测试数据过程中发生错误。" ) return False def _demonstrate_raw_sql (self ): """(私有方法) 演示直接执行原生 SQL 查询""" print_header("演示: 直接执行原生 SQL 查询" ) if not self .cursor: print_error("数据库游标无效。" ) return try : print_subheader("1. 按优先级统计任务数量 (GROUP BY, COUNT)" ) self .cursor.execute( f"SELECT priority, COUNT(*) as task_count FROM {self.task_repo.table_name} GROUP BY priority ORDER BY priority" ) for row in self .cursor.fetchall(): print_result_item(row) print_subheader("2. 连接查询示例 (LEFT JOIN 语法演示)" ) join_sql = f"SELECT t.task_id, t.title, u.username FROM {self.task_repo.table_name} t LEFT JOIN users u ON t.user_id = u.user_id WHERE t.priority = 1 LIMIT 5" print_sql(join_sql + " (仅演示语法, users 表不存在)" ) print_subheader("3. 使用 CASE 表达式统计任务状态" ) self .cursor.execute( f"SELECT SUM(CASE WHEN is_completed = 1 THEN 1 ELSE 0 END) as completed_count, SUM(CASE WHEN is_completed = 0 THEN 1 ELSE 0 END) as pending_count, COUNT(*) as total_count FROM {self.task_repo.table_name} " ) stats = self .cursor.fetchone() if stats: print_result_item(stats) print_subheader("4. 复杂条件查询: 一周内到期的高优先级(<=2)未完成任务" ) self .cursor.execute( f"SELECT task_id, title, priority, due_date FROM {self.task_repo.table_name} WHERE is_completed = 0 AND priority <= ? AND (due_date IS NOT NULL AND due_date <= date('now', '+7 days')) ORDER BY priority ASC, due_date ASC" , (2 ,)) urgent_tasks = self .cursor.fetchall() print_info(f"找到 {len (urgent_tasks)} 条:" ) for task in urgent_tasks: print_result_item(task) print_subheader("5. 分页查询: 获取第 2 页数据 (每页 3 条)" ) page_size, page_number = 3 , 2 offset = (page_number - 1 ) * page_size self .cursor.execute( f"SELECT task_id, title FROM {self.task_repo.table_name} ORDER BY task_id LIMIT ? OFFSET ?" , (page_size, offset)) page_tasks = self .cursor.fetchall() print_info(f"第 {page_number} 页 (每页 {page_size} 条):" ) for task in page_tasks: print_result_item(task) print_subheader("6. 子查询: 查找所有今天到期的任务 (使用 EXISTS)" ) self .cursor.execute( f"SELECT task_id, title, due_date FROM {self.task_repo.table_name} t1 WHERE EXISTS (SELECT 1 FROM {self.task_repo.table_name} t2 WHERE t2.task_id = t1.task_id AND t2.due_date = date('now')) ORDER BY priority" ) today_tasks = self .cursor.fetchall() print_info(f"今天 ({datetime.now().strftime('%Y-%m-%d' )} ) 到期的任务 ({len (today_tasks)} 条):" ) for task in today_tasks: print_result_item(task) except sqlite3.Error as e: print_error(f"执行原生 SQL 查询时出错: {e} " ) def _demonstrate_repository_queries (self ): """(私有方法) 演示使用 TaskRepository 方法进行查询""" print_header("演示: 使用 Repository 方法进行查询" ) if not self .task_repo: print_error("TaskRepository 未初始化。" ) return try : print_subheader("1. 查找优先级为 1 的任务" ) priority_1_tasks = self .task_repo.find_by_criteria({"priority" : 1 }) print_info(f"找到 {len (priority_1_tasks)} 条:" ) for task in priority_1_tasks: print_result_item(task) print_subheader("2. 查找已完成的任务" ) completed_tasks = self .task_repo.find_by_criteria({"is_completed" : True }) print_info(f"找到 {len (completed_tasks)} 条:" ) for task in completed_tasks: print_result_item(task) print_subheader("3. 搜索标题中包含 '代码' 的任务" ) code_tasks = self .task_repo.find_by_title_contains("代码" ) print_info(f"找到 {len (code_tasks)} 条:" ) for task in code_tasks: print_result_item(task) print_subheader("4. 组合条件查询 (优先级=2 且 未完成)" ) prio2_incomplete = self .task_repo.find_by_criteria({"priority" : 2 , "is_completed" : False }) print_info(f"找到 {len (prio2_incomplete)} 条:" ) for task in prio2_incomplete: print_result_item(task) except Exception as e: print_error(f"使用 Repository 查询时出错: {e} " ) def _demonstrate_service_queries (self ): """(私有方法) 演示使用 TaskService 方法进行业务查询""" print_header("演示: 使用 Service 层进行业务查询" ) if not self .task_service: print_error("TaskService 未初始化。" ) return try : print_subheader("1. 获取未完成的任务" ) incomplete_tasks = self .task_service.get_incomplete_tasks() print_info(f"找到 {len (incomplete_tasks)} 条:" ) for task in incomplete_tasks[:3 ]: print_result_item(task) if len (incomplete_tasks) > 3 : print_info(f" ... 及其他 {len (incomplete_tasks) - 3 } 条" ) print_subheader("2. 获取未来 7 天内到期的任务" ) due_soon_tasks = self .task_service.get_tasks_due_within_days(7 ) print_info(f"找到 {len (due_soon_tasks)} 条:" ) for task in due_soon_tasks: print_result_item(task) print_subheader("3. 获取已逾期的未完成任务" ) overdue_tasks = self .task_service.get_overdue_tasks() print_info(f"找到 {len (overdue_tasks)} 条:" ) for task in overdue_tasks: print_result_item(task) except Exception as e: print_error(f"使用 Service 查询时出错: {e} " ) def run (self ): """ 执行完整的演示流程: 设置 -> 填充数据 -> 各种查询演示 -> 清理。 使用 try...finally 确保数据库连接总是被关闭。 """ try : self ._setup_database_and_framework() if self .cursor: self ._demonstrate_raw_sql() if self .task_repo: self ._demonstrate_repository_queries() if self .task_service: self ._demonstrate_service_queries() print_success("\n所有查询演示执行完毕。" ) except ConnectionError as ce: print_error(f"数据库连接错误: {ce} " ) except RuntimeError as rte: print_error(f"运行时错误 (可能在设置或填充数据时): {rte} " ) except sqlite3.Error as db_err: print_error(f"数据库操作错误: {db_err} " ) import traceback traceback.print_exc() except Exception as e: print_error(f"发生意外错误: {e} " ) import traceback traceback.print_exc() finally : if self .db_manager: self .db_manager.close() print_header("高级查询演示流程结束" ) if __name__ == "__main__" : demo = AdvancedQueryDemoOO() demo.run()
代码解释 (advanced_queries.py):
目的: 展示更复杂的 SQL 查询以及如何在不同层次(原生 SQL, Repository, Service)执行它们。populate_test_data
: 创建了多样化的任务数据,包括不同的优先级、截止日期和完成状态,为后续复杂查询提供基础。run_raw_sql_queries
: 直接使用 cursor
执行 SQL。这展示了:聚合函数 (COUNT
, SUM
) 和 GROUP BY
: 用于统计分析。CASE
表达式: 在 SQL 中实现条件逻辑。复杂 WHERE
子句: 结合 AND
, OR
, 比较运算符, IS NOT NULL
和 SQLite 内建日期函数 (date('now', '+7 days')
)。ORDER BY
多列排序。LIMIT
和 OFFSET
: 实现分页。EXISTS
子查询: 进行基于关联的条件判断。优点: 可以利用数据库特定的高级功能,或编写 ORM 难以表达的优化查询。缺点: 绕过了框架的封装,代码与数据库耦合,容易出错,且有 SQL 注入风险(如果不是硬编码 SQL 的话)。run_repository_queries
: 使用 TaskRepository
提供的 find_by_criteria
和 find_by_title_contains
方法。展示了通过抽象接口进行条件查询。run_service_queries
: 使用 TaskService
提供的面向业务的方法,如 get_incomplete_tasks
, get_tasks_due_within_days
, get_overdue_tasks
。这体现了 Service 层封装业务逻辑查询的优势。run_advanced_queries_main
: 主函数负责初始化框架组件(使用 with DatabaseManager
),调用数据填充函数,然后依次调用三种查询方式的演示函数。16.2 MySQL 数据库 (使用 PyMySQL
驱动) 16.2.1 简介
MySQL: 是一个非常流行的开源关系型数据库管理系统 (RDBMS) ,广泛用于各种规模的应用程序,尤其是 Web 应用。它采用客户端-服务器架构,需要独立的服务器进程。PyMySQL: 是一个让 Python 程序能够连接并操作 MySQL 数据库的纯 Python 驱动库 。它实现了 Python 的 DB-API 2.0 (PEP 249) 规范,提供了一套标准的数据库操作接口。前提: 在使用 PyMySQL
前,你需要确保:本地或远程有一个正在运行的 MySQL 服务器 (版本 5.5 或更高)。 你拥有连接该服务器所需的主机地址、端口号、用户名、密码 。 服务器上已经创建了你要操作的数据库 。 安装: 与 sqlite3
主要不同: 需要网络连接到服务器。 连接时参数更多。 SQL 参数占位符是 %s
。 需要特别注意字符集 (utf8mb4
) 和游标类型 (DictCursor
) 的设置。 连接管理(关闭/连接池)更为重要。 16.2.2 建立连接
连接是与 MySQL 服务器交互的第一步。
我们把上一章的打印工具给拿过来,放置在根目录print_utils
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 """ 打印工具模块,提供彩色和结构化的打印函数。 """ class Colors : HEADER = '\033[95m' BLUE = '\033[94m' CYAN = '\033[96m' GREEN = '\033[92m' WARNING = '\033[93m' FAIL = '\033[91m' BOLD = '\033[1m' UNDERLINE = '\033[4m' END = '\033[0m' def print_header (text: str ): print (f"\n{Colors.HEADER} {Colors.BOLD} --- {text} ---{Colors.END} " ) def print_subheader (text: str ): print (f"\n{Colors.CYAN} {Colors.UNDERLINE} {text} {Colors.END} " ) def print_info (text: str ): print (f" {text} " ) def print_success (text: str ): print (f"{Colors.GREEN} ✔ {text} {Colors.END} " ) def print_warning (text: str ): print (f"{Colors.WARNING} ⚠️ [Warning] {text} {Colors.END} " ) def print_error (text: str ): print (f"{Colors.FAIL} ❌ [Error] {text} {Colors.END} " ) def print_sql (sql: str ): print (f"{Colors.BLUE} SQL: {sql.strip()} {Colors.END} " ) def print_result_item (item, indent: int = 4 ): prefix = " " * indent if isinstance (item, dict ): details = ", " .join([ f"{Colors.BOLD} {key} {Colors.END} : {repr (value)} " for key, value in item.items() ]) print (f"{prefix} Row({details} )" ) else : print (f"{prefix} {repr (item)} " )
当我们需要使用时,直接from导入
并像调用函数一样调用即可
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 import pymysqlfrom pymysql.cursors import DictCursor from typing import Optional , Tuple , Dict , List , Any from sqlite_practice.utils.print_utils import ( Colors, print_header, print_subheader, print_info, print_success, print_warning, print_error, print_sql, print_result_item ) DB_CONFIG: Dict [str , Any ] = { 'host' : 'localhost' , 'port' : 3306 , 'user' : 'root' , 'password' : 'root' , 'database' : 'pymysql_demo_db' , 'charset' : 'utf8mb4' , 'cursorclass' : DictCursor, 'autocommit' : False , 'connect_timeout' : 10 } def get_mysql_connection () -> Optional [pymysql.connections.Connection]: """尝试建立到 MySQL 数据库的连接。""" print_info(f"尝试连接到 MySQL: {DB_CONFIG['host' ]} :{DB_CONFIG['port' ]} , Database: {DB_CONFIG['database' ]} " ) try : connection = pymysql.connect(**DB_CONFIG) print_success(f"成功连接到数据库 '{DB_CONFIG['database' ]} '" ) return connection except pymysql.Error as e: print_error(f"数据库连接失败: {e} " ) print_warning(" 请检查 DB_CONFIG 配置、MySQL 服务器状态及网络连接。" ) return None
连接参数关键点:
charset='utf8mb4'
: 确保可以正确处理各种字符,包括表情符号。数据库、表、列也应使用此字符集。cursorclass=DictCursor
: 使 cursor.fetchone()
返回字典,cursor.fetchall()
返回字典列表。可通过 row['column_name']
访问数据,比元组索引更易读。autocommit=False
: 禁用自动提交。所有 INSERT
, UPDATE
, DELETE
操作都需要显式调用 conn.commit()
才生效,便于事务管理。安全: 绝不在代码中硬编码生产环境的密码。使用环境变量、配置文件、Secrets Manager 等安全机制。16.2.3 核心 API 概览
PyMySQL
遵循 DB-API 2.0 规范,核心 API 与 sqlite3
类似,但注意占位符和游标类型。
API 描述 注意事项 pymysql.connect(**config)
连接 MySQL,返回 Connection
对象。 参数多,需网络 connection.cursor([cursorclass])
创建 Cursor
对象 (推荐 DictCursor
)。 DictCursor
返回字典cursor.execute(sql, [args])
执行单条 SQL (用 %s
占位符)。返回受影响行数。 %s
占位符cursor.executemany(sql, seq_of_args)
批量执行 SQL (用 %s
)。返回受影响总行数 (可能不精确)。 %s
占位符connection.commit()
提交当前事务。 autocommit=False
时必需connection.rollback()
回滚当前事务。 connection.begin()
显式开始事务 (若 autocommit=False
)。 可选,但更清晰 cursor.fetchone()
获取下一行结果 (字典或元组),无结果时 None
。 依赖 cursorclass
cursor.fetchall()
获取所有剩余行结果 (字典或元组的列表)。 依赖 cursorclass
cursor.lastrowid
(属性) 最后 INSERT 操作的自增 ID (AUTO_INCREMENT
列)。 cursor.rowcount
(属性) 最后操作影响的行数 (对 SELECT 可能不同数据库行为有差异)。 connection.close()
必须调用 关闭网络连接。极其重要 with connection.cursor() as cursor:
(推荐) 游标上下文管理器,自动关闭游标。 connection.ping(reconnect=True)
检查连接活性,可选重连。 用于长连接场景
16.2.4 准备工作:创建表与填充数据
为了演示查询,我们创建 categories
和 products
表。
categories 表
列名 数据类型 是否允许 NULL 键 默认值 额外 注释 category_id INT UNSIGNED NO PRI — AUTO_INCREMENT 产品类别主键 name VARCHAR(50) NO UNI — 类别名称 description TEXT YES NULL 类别描述
products 表
列名 数据类型 是否允许 NULL 键 默认值 额外 注释 id INT UNSIGNED NO PRI — AUTO_INCREMENT 产品ID name VARCHAR(100) NO — 产品名称 price DECIMAL(10, 2) NO — 价格 stock INT UNSIGNED NO 0 库存 category_id INT UNSIGNED YES MUL NULL 外键,关联 categories.category_id added_date DATE YES NULL 上架日期 created_at TIMESTAMP YES CURRENT_TIMESTAMP 创建时间 updated_at TIMESTAMP YES CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP 更新时间
这两个函数仅是一次性的,用于创建表与插入数据,可以看看语法,也可以不看,简单的SQL而已
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 import pymysqlfrom pymysql.cursors import DictCursor from typing import Optional , Tuple , Dict , List , Any from print_utils import ( Colors, print_header, print_subheader, print_info, print_success, print_warning, print_error, print_sql, print_result_item ) DB_CONFIG: Dict [str , Any ] = { 'host' : 'localhost' , 'port' : 3306 , 'user' : 'root' , 'password' : 'root' , 'database' : 'pymysql_demo_db' , 'charset' : 'utf8mb4' , 'cursorclass' : DictCursor, 'autocommit' : False , 'connect_timeout' : 10 } def get_mysql_connection () -> Optional [pymysql.Connection]: """创建 MySQL 连接""" print_info(f"尝试连接到 MySQL: {DB_CONFIG['host' ]} :{DB_CONFIG['port' ]} , Database: {DB_CONFIG['database' ]} " ) try : connection = pymysql.connect(**DB_CONFIG) print_success("连接成功" ) return connection except pymysql.Error as e: print_error(f"数据库连接失败: {e} " ) print_warning(" 请检查 DB_CONFIG 配置、MySQL 服务器状态及网络连接。" ) return None def setup_mysql_tables (conn: pymysql.connections.Connection ) -> bool : """创建演示所需的 categories 和 products 表。""" try : with conn.cursor() as cursor: print_info("开始创建表..." ) cursor.execute("DROP TABLE IF EXISTS products;" ) cursor.execute("DROP TABLE IF EXISTS categories;" ) print_warning(" (旧表 'products' 和 'categories' 已删除,如果存在)" ) cursor.execute(""" CREATE TABLE categories ( category_id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY, name VARCHAR(50) NOT NULL UNIQUE COMMENT '类别名称', description TEXT COMMENT '类别描述' ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='产品类别表'; """ ) print_success("表 'categories' 创建成功。" ) cursor.execute(""" CREATE TABLE products ( id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY COMMENT '产品ID', name VARCHAR(100) NOT NULL COMMENT '产品名称', price DECIMAL(10, 2) NOT NULL COMMENT '价格', stock INT UNSIGNED NOT NULL DEFAULT 0 COMMENT '库存', category_id INT UNSIGNED COMMENT '外键, 关联 categories 表', added_date DATE COMMENT '上架日期', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间', updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间', INDEX idx_prod_name (name(20)), -- 名称前缀索引 FOREIGN KEY fk_prod_cat (category_id) REFERENCES categories(category_id) ON DELETE SET NULL -- 删除类别时,产品类别设为 NULL ON UPDATE CASCADE -- 更新类别 ID 时,产品自动更新 ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='产品信息表'; """ ) print_success("表 'products' 创建成功。" ) conn.commit() return True except pymysql.Error as e: print_error(f"创建表失败: {e} " ) conn.rollback() return False def populate_mysql_data (conn: pymysql.connections.Connection ) -> bool : """向 categories 和 products 表填充测试数据。""" try : with conn.cursor() as cursor: print_info("开始填充测试数据..." ) categories_data = [ ('电子产品' , '手机、电脑、配件等' ), ('图书音像' , '各类实体书籍和数字媒体' ), ('服装鞋包' , '男女服装、鞋子和箱包' ), ('家居生活' , '家具、厨具、家纺等' ), ] cat_sql = "INSERT INTO categories (name, description) VALUES (%s, %s)" cursor.executemany(cat_sql, categories_data) print_success(f" 插入了 {cursor.rowcount} 个类别。" ) cursor.execute("SELECT category_id, name FROM categories" ) cat_map = {row['name' ]: row['category_id' ] for row in cursor.fetchall()} products_data = [ ('智能手机 V12' , 4999.00 , 50 , cat_map['电子产品' ], '2024-10-01' ), ('蓝牙耳机 AirSound' , 799.00 , 100 , cat_map['电子产品' ], '2024-11-15' ), ('PyMySQL 深度指南' , 88.50 , 200 , cat_map['图书音像' ], '2023-05-20' ), ('算法导论 (原版)' , 128.00 , 80 , cat_map['图书音像' ], '2024-01-10' ), ('纯棉印花 T 恤' , 129.00 , 300 , cat_map['服装鞋包' ], '2024-08-01' ), ('透气运动跑鞋' , 499.00 , 60 , cat_map['服装鞋包' ], '2024-09-01' ), ('北欧风实木餐桌' , 2899.00 , 10 , cat_map['家居生活' ], '2023-12-01' ), ('乳胶记忆棉枕头' , 159.00 , 150 , cat_map['家居生活' ], '2024-03-01' ), ('游戏本 RTX 9000' , 12999.00 , 20 , cat_map['电子产品' ], None ), ('科幻短篇小说集' , 65.00 , 120 , cat_map['图书音像' ], '2024-06-15' ), ('速干运动短裤' , 189.00 , 100 , None , '2024-07-01' ), ('咖啡机' , 899.00 , 30 , cat_map['家居生活' ], '2024-05-01' ), ('SQL 查询的艺术' , 75.00 , 90 , cat_map['图书音像' ], '2024-04-15' ) ] prod_sql = "INSERT INTO products (name, price, stock, category_id, added_date) VALUES (%s, %s, %s, %s, %s)" cursor.executemany(prod_sql, products_data) print_success(f" 插入了 {cursor.rowcount} 个产品。" ) conn.commit() print_success("测试数据填充完成。" ) return True except pymysql.Error as e: print_error(f"填充数据失败: {e} " ) conn.rollback() return False if __name__ == '__main__' : conn = get_mysql_connection() if conn is None : exit(1 ) if not setup_mysql_tables(conn): exit(1 ) if not populate_mysql_data(conn): exit(1 ) with conn.cursor() as cursor: print_header("查询数据" )
16.2.5 数据库操作:CRUD 与查询详解 16.2.5.1 插入数据 (INSERT) 插入单行记录 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 def insert_single_product (conn: pymysql.Connection, product_name: str , price: float , stock: int , category_id: Optional [int ] = None , added_date: Optional [str ] = None ) -> Optional [int ]: """ 插入单个产品记录 :param conn: 数据库连接 :param product_name: 产品名称 :param price: 产品价格 :param stock: 产品库存 :param category_id: 产品分类 ID (可选) :param added_date: 产品添加日期 (可选) :return: 新插入的产品 ID (若插入失败则返回 None) """ sql = """ INSERT INTO products (name, price, stock, category_id, added_date) VALUES (%s, %s, %s, %s, %s) """ params = (product_name, price, stock, category_id, added_date) try : with conn.cursor() as cursor: cursor.execute(sql, params) new_id = cursor.lastrowid conn.commit() print_success(f"成功插入产品:{product_name} (ID: {new_id} )" ) return new_id except pymysql.Error as e: print_error(f"插入产品失败: {e} " ) conn.rollback() return None if __name__ == '__main__' : conn = get_mysql_connection() if conn is None : exit(1 ) try : new_id = insert_single_product(conn, product_name='iPhone X' , price=9999.99 , stock=100 , category_id=1 , added_date='2021-01-01' ) if new_id is not None : print_success(f"新插入的产品 ID: {new_id} " ) except Exception as e: print_error(f"插入产品失败: {e} " )
批量插入记录 (executemany
) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 def insert_multiple_products (conn: pymysql.Connection, products_list: List [Tuple ] ) -> Optional [List [int ]]: """ 批量插入产品记录 :param conn: 数据库连接 :param products_list: 产品列表 (列表项为元组,包含 (name, price, stock, category_id, added_date) 五个字段) :return: 新插入的产品 ID 列表 (若插入失败则返回 None) """ if not products_list: return 0 sql = "INSERT INTO products (name,price,stock,category_id,added_date,created_at) VALUES (%s,%s,%s,%s,%s,NOW())" try : with conn.cursor() as cursor: cursor.executemany(sql, products_list) new_ids = [cursor.lastrowid for _ in range (len (products_list))] conn.commit() print_success(f"成功插入 {len (products_list)} 个产品" ) return new_ids except pymysql.Error as e: print_error(f"插入产品失败: {e} " ) conn.rollback() return None if __name__ == '__main__' : conn = get_mysql_connection() if not conn: exit(1 ) products_list = [ ('iPhone X' , 9999 , 100 , 1 , '2021-01-01' ), ('华为 P30 Pro' , 8888 , 50 , 2 , '2021-01-02' ), ('小米 10' , 7777 , 20 , 3 , '2021-01-03' ), ('OPPO Find X3' , 6666 , 10 , 4 , '2021-01-04' ), ('vivo NEX' , 5555 , 5 , 4 , '2021-01-05' ), ] new_ids = insert_multiple_products(conn, products_list) if new_ids: print_info(f"新插入的产品 ID 列表: {new_ids} " )
注意,在使用的时候category_id 、 时不能存储4以上的,由于我们上面插入的insert外键最高为4,若需要更新为4类别的,则需要增加category_id的数量
16.2.5.2 更新数据 (UPDATE) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 def update_product_stock (conn: pymysql.Connection, product_id: int , stock: int ) -> bool : """更新指定产品的库存 (增加或减少)。""" sql_get_stock = "SELECT stock FROM products WHERE id = %s" sql_update = "UPDATE products SET stock = %s WHERE id = %s" try : with conn.cursor() as cursor: cursor.execute(sql_get_stock, (product_id,)) result = cursor.fetchone() if not result: print_warning(f"更新库存失败: 未找到产品 ID 为 {product_id} 的记录" ) return False current_stock = result["stock" ] new_stock = current_stock + stock if new_stock < 0 : print_warning(f"更新库存失败: 库存不足 (当前库存: {current_stock} , 增加数量: {stock} )" ) return False affected_rows = cursor.execute(sql_update, (new_stock, product_id)) conn.commit() if affected_rows > 0 : print_success(f"成功更新产品 ID 为 {product_id} 的库存为 {new_stock} " ) return True else : print_warning(f"更新库存失败: 未找到影响行数 (affected_rows: {affected_rows} )" ) return False except pymysql.Error as e: print_error(f"更新库存失败: {e} " ) conn.rollback() return False if __name__ == '__main__' : conn = get_mysql_connection() if not conn: exit(1 ) product_id = 1 stock = 10 if update_product_stock(conn, product_id, stock): print_info(f"库存更新成功" ) else : print_error(f"库存更新失败" )
16.2.5.3 删除数据 (DELETE) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 def delete_product_by_id (conn: pymysql.connections.Connection, product_id: int ) -> bool : """根据 ID 删除产品。""" sql = "DELETE FROM products WHERE id = %s" try : with conn.cursor() as cursor: affected_rows = cursor.execute(sql, (product_id,)) conn.commit() if affected_rows > 0 : print_success(f"成功删除产品 ID={product_id} " ) return True else : print_warning(f"删除产品 ID={product_id} 时未找到匹配记录。" ) return False except pymysql.Error as e: print_error(f"删除产品 ID={product_id} 失败: {e} " ) conn.rollback() return False if __name__ == '__main__' : conn = get_mysql_connection() if not conn: exit(1 ) affected_rows = delete_product_by_id(conn, 1 ) if affected_rows: print_success("删除产品成功" ) else : print_warning("删除产品失败" )
16.2.5.4 查询数据 (SELECT) - 基础 查询所有列 (*
) 1 2 3 4 5 6 7 8 9 10 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_header("演示 SELECT 查询" ) try : with conn.cursor() as cursor: cursor.execute("SELECT * FROM products" ) result = cursor.fetchall() print_info(f"查询结果: {len (result)} 条结果" ) for row in result:print_result_item(row) except pymysql.Error as e:print_error(f"查询失败: {e} " )
查询指定列 1 2 3 4 5 6 7 8 9 10 11 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("查询所有产品的名称和价格 (前 5 条)" ) try : with conn.cursor() as cursor: cursor.execute("SELECT name, price FROM products LIMIT 5" ) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.5 查询数据 (SELECT) - 条件过滤 (WHERE
) 等于 (=) 1 2 3 4 5 6 7 8 9 10 11 12 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("查询商品分类 ID 为 2的产品" ) try : category_id_to_find = 2 with conn.cursor() as cursor: cursor.execute("SELECT id,name,price FROM products Where category_id = %s" , (category_id_to_find,)) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
比较运算符 (>, <=) 1 2 3 4 5 6 7 8 9 10 11 12 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("查询商品价钱大于100的所商品名和价格" ) try : max_price = 100.0 with conn.cursor() as cursor: cursor.execute("SELECT name,price FROM productS WHERE price > %s ORDER BY price DESC" , (max_price,)) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
BETWEEN … AND …区间查询 1 2 3 4 5 6 7 8 9 10 11 12 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("查询商品价格在100~500之间的商品,且通过商品价格降序排序" ) try : min_p, max_p = 100.0 , 500.0 with conn.cursor() as cursor: cursor.execute("SELECT name, price FROM products WHERE price BETWEEN %s AND %s ORDER BY price" , (min_p, max_p)) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
IN (…) 1 2 3 4 5 6 7 8 9 10 11 12 13 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("查询类别 ID 为 1 或 4 (电子产品或家居生活) 的产品" ) try : target_cat_ids = (1 , 4 ) placeholders = ', ' .join(['%s' ] * len (target_cat_ids)) with conn.cursor() as cursor: cursor.execute(f"SELECT id, name, category_id FROM products WHERE category_id IN ({placeholders} )" ,target_cat_ids) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
LIKE (模糊匹配) 1 2 3 4 5 6 7 8 9 10 11 12 13 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("查询名称包含 '运动' 的产品" ) try : search_term = "运动" with conn.cursor() as cursor: cursor.execute("SELECT name, price FROM products WHERE name LIKE %s" , (f"%{search_term} %" ,)) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
IS NULL / IS NOT NULL 1 2 3 4 5 6 7 8 9 10 11 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("查询没有上架日期的产品" ) try : with conn.cursor() as cursor: cursor.execute("SELECT id,name FROM products WHERE added_date IS NULL" ) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
组合条件 (AND, OR, NOT) 1 2 3 4 5 6 7 8 9 10 11 12 13 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("查询图书音像类(ID=2) 或 价格低于100元 且 库存大于0 的产品" ) try : cat_id = 2 max_p = 100.0 with conn.cursor() as cursor: cursor.execute("SELECT * FROM products WHERE (category_id = %s OR price < %s) AND stock > 0" , (cat_id, max_p)) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.6 查询数据 (SELECT) - 排序 (ORDER BY
) 单列排序 (ASC/DESC) 1 2 3 4 5 6 7 8 9 10 11 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("按库存数量降序排列产品 (前 5)" ) try : with conn.cursor() as cursor: cursor.execute("SELECT name,stock FROM products ORDER BY stock DESC LIMIT 5" ) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
多列排序 1 2 3 4 5 6 7 8 9 10 11 12 13 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("按类别 ID 升序、价格降序排列产品 (前 5)" ) try : with conn.cursor() as cursor: cursor.execute("SELECT * FROM products ORDER BY category_id ASC,price DESC LIMIT 5" ) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.7 查询数据 (SELECT) - 去重 (DISTINCT
) 1 2 3 4 5 6 7 8 9 10 11 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("查询现有的所有产品分类的 ID" ) try : with conn.cursor() as cursor: cursor.execute("SELECT DISTINCT category_id FROM products WHERE category_id IS NOT NULL" ) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.8 查询数据 (SELECT) - 聚合函数 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("计算产品统计信息 (总数, 总库存, 平均价, 最高价, 最低价)" ) sql = """ SELECT COUNT(*) as total_products, SUM(stock) as total_stock, AVG(price) as average_price, MAX(price) as max_price, MIN(price) as min_price FROM products """ try : with conn.cursor() as cursor: cursor.execute(sql) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.9 查询数据 (SELECT) - 分组 (GROUP BY
) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("按类别名称分组,统计各类别的产品数量和平均价格" ) sql = """ SELECT c.name as category_name, -- 从 categories 表获取分类名称 COUNT(p.id) as num_products, -- 统计每个分类下产品数量 AVG(p.price) as avg_price -- 统计每个分类下产品平均价格 FROM products as p LEFT JOIN categories as c ON p.category_id = c.category_id -- 左连接 categories 表 GROUP BY c.name -- 按照类别名称分组 ORDER BY num_products DESC -- 按产品数量倒序排列 """ try : with conn.cursor() as cursor: cursor.execute(sql) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.10 查询数据 (SELECT) - 分组过滤 (HAVING
) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("按类别分组,并找出平均价格高于 3000 元的类别" ) sql = """ SELECT c.name as category_name, COUNT(p.id) as num_products, AVG(p.price) as avg_price FROM products p JOIN categories c ON p.category_id = c.category_id -- 使用外键关联分类 GROUP BY c.name HAVING AVG(p.price) > %s -- 使用HAVING 过滤分组结果 ORDER BY avg_price DESC -- 按平均价格降序排序 """ try : MIN_AVG_PRICE = 3000.0 with conn.cursor() as cursor: cursor.execute(sql, (MIN_AVG_PRICE,)) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.11 查询数据 (SELECT) - 分页 (LIMIT
/OFFSET
) 1 2 3 4 5 6 7 8 9 10 11 12 13 print_subheader("分页查询: 获取第 3 页数据 (每页 4 条)" ) page = 3 page_size = 4 offset = (page - 1 ) * page_size sql = "SELECT id, name, price FROM products ORDER BY id LIMIT %s OFFSET %s" try : with conn.cursor() as cursor: cursor.execute(sql, (page_size, offset)) results = cursor.fetchall() print_info(f"第 {page} 页 (每页 {page_size} 条),获取 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.12 查询数据 (SELECT) - 连接查询 (JOIN
) 内连接 (INNER JOIN
) INNER JOIN: 必须要两边表都有匹配的数据才会出现在结果中。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("内连接查询: 获取产品及其对应的类别名称 (只显示有类别的产品)" ) sql = """ SELECT p.id,p.name AS product_name , p.price,c.name AS category_name FROM products p INNER JOIN categories c ON p.category_id = c.category_id -- 只返回两个表都能匹配上的行 ORDER BY p.name,c.name LIMIT 5 -- 限制输出 """ try : with conn.cursor() as cursor: cursor.execute(sql) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
左连接 (LEFT JOIN
) 左连接 保留左表所有行(右表无匹配则补空)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("左连接查询: 获取所有产品及其类别名称 (没有类别的产品也会显示)" ) sql = """ SELECT p.id, p.name AS product_name, p.price, c.name AS category_name FROM products p -- 这里换成 RIGHT JOIN 就会少一条记录 (没有类别的产品) LEFT JOIN categories c ON p.category_id = c.category_id -- 返回左表(products)的所有行 ORDER BY p.id """ try : with conn.cursor() as cursor: cursor.execute(sql) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.13 查询数据 (SELECT) - CASE
表达式 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("使用 CASE 表达式根据价格给产品分类" ) sql = """ SELECT name, price, CASE WHEN price < 100 THEN '入门级' WHEN price BETWEEN 100 AND 999.99 THEN '标准级' WHEN price BETWEEN 1000 AND 4999.99 THEN '进阶级' ELSE '旗舰级' END AS price_level -- 给 CASE 结果起别名 FROM products ORDER BY price LIMIT 6 -- 限制输出 """ try : with conn.cursor() as cursor: cursor.execute(sql) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
16.2.5.14 查询数据 (SELECT) - 子查询 (Subquery) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 def select_demonstration (conn: pymysql.Connection ) -> None : """演示 SELECT 查询""" print_subheader("子查询: 查找价格高于平均价格的产品" ) sql = """ SELECT ID,name,price FROM products WHERE price > (SELECT AVG(price) FROM products) ORDER BY price DESC """ try : with conn.cursor() as cursor: cursor.execute(sql) results = cursor.fetchall() print_info(f"查询到 {len (results)} 条记录:" ) for row in results: print_result_item(row) except pymysql.Error as e: print_error(f"查询失败: {e} " )
更详细的SQL语法,详见MYSQL篇章
16.3 SQLAlchemy Core: SQL 表达式语言 SQLAlchemy Core 提供了一种使用 Python 对象构建 SQL 语句的方式,避免直接拼接字符串,从而提高代码的安全性和可维护性。本节将核心概念和常用 API 通过表格和简洁示例进行介绍。
首先,你需要安装 SQLAlchemy 库:
SQLAlchemy 本身不包含数据库驱动程序 (DBAPI),它依赖于第三方驱动来与具体的数据库进行通信。因此,你还需要根据你使用的数据库安装相应的驱动。SQLAlchemy 通过驱动实现与数据库的连接和交互。
常用数据库及其推荐驱动:
数据库 SQLAlchemy 连接 URL 方言 推荐驱动 (DBAPI) 安装命令 (示例) PostgreSQL postgresql
psycopg2
(binary)pip install psycopg2-binary
MySQL mysql
mysqlclient
pip install mysqlclient
mysql
PyMySQL
pip install PyMySQL
SQLite sqlite
sqlite3
(内置)(无需额外安装) Microsoft SQL mssql
pyodbc
pip install pyodbc
Oracle oracle
cx_Oracle
pip install cx_Oracle
选择驱动 :
对于 PostgreSQL,psycopg2
是最常用且功能最全的选择。psycopg2-binary
包含了预编译的 C 扩展,安装更方便。 对于 MySQL,mysqlclient
是一个性能较好的 C 扩展驱动,但可能需要编译环境。PyMySQL
是纯 Python 实现,安装简单,兼容性好,在许多场景下性能也足够。SQLAlchemy 连接 URL 中需要指定驱动,如 mysql+pymysql://...
或 mysql+mysqlclient://...
。 SQLite 的 sqlite3
驱动是 Python 内置的。 本章我们还是使用PyMysql作为核心驱动!
16.3.1 引擎 (Engine): 数据库连接的入口 Engine
是 SQLAlchemy 应用与数据库交互的起点,负责管理连接池和数据库方言。
创建引擎 (create_engine
)
使用 sqlalchemy.create_engine()
函数创建。
参数 类型 描述 示例 (SQLite 文件) url
str
必需 . 数据库连接 URL,格式为 dialect+driver://user:password@host:port/database
。具体格式见下方说明。"sqlite:///myapp.db"
echo
bool
可选,默认为 False
。设为 True
时,打印 SQLAlchemy 执行的所有 SQL 语句。调试时非常有用 。 echo=True
pool_size
int
可选,连接池保持的最小连接数。 pool_size=5
max_overflow
int
可选,超出 pool_size
后允许额外创建的最大连接数。 max_overflow=10
pool_recycle
int
可选,连接在连接池中保持多少秒后被回收 (防止数据库服务器因超时关闭连接)。 pool_recycle=3600
(1 小时)connect_args
dict
可选,传递给底层 DBAPI connect()
方法的额外参数。 connect_args={'timeout': 10}
(某些驱动)execution_options
dict
可选,设置执行选项,如 isolation_level
(事务隔离级别)。 json_serializer
callable
可选,用于序列化 JSON 数据。 json_deserializer
callable
可选,用于反序列化 JSON 数据。
数据库连接 URL ( url
) 详解:
格式 : dialect[+driver]://[user[:password]@][host][:port]/[database][?key=value&key=value...]
dialect
: 数据库类型 (sqlite
, mysql
, postgresql
, mssql
, oracle
等)。driver
(可选) : 使用的 DBAPI 库 (psycopg2
, pymysql
, mysqlclient
, pyodbc
等)。如果省略,SQLAlchemy 会尝试默认驱动。示例 :SQLite (文件): sqlite:///path/to/your/database.db
(注意三个 /
) SQLite (内存): sqlite:///:memory:
(或者仅 sqlite://
) MySQL (PyMySQL): mysql+pymysql://user:pass@host:3306/dbname?charset=utf8mb4
PostgreSQL (psycopg2): postgresql+psycopg2://user:pass@host:5432/dbname
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 from sqlalchemy import create_enginefrom print_utils import print_header, print_info, print_success, print_errorprint_header("创建 SQLAlchemy Engine 示例" ) try : sqlite_file_engine = create_engine("sqlite:///sqlalchemy_core_example.db" , echo=True ) print_success("创建 SQLite 文件数据库引擎成功 (带 SQL 日志)。" ) sqlite_memory_engine = create_engine("sqlite:///:memory:" , echo=False ) print_success("创建 SQLite 内存数据库引擎成功。" ) mysql_url = "mysql+pymysql://root:root@localhost:3306/pymysql_demo?charset=utf8mb4" mysql_engine = create_engine(mysql_url, echo=True ) print_success("创建 MySQL 引擎配置示例 (pymysql_demo)。" ) except ImportError as ie: print_error(f"创建引擎失败:缺少必要的数据库驱动。请安装相应的库。错误: {ie} " ) except Exception as e: print_error(f"创建引擎时发生意外错误: {e} " )
16.3.2. 元数据与表定义 使用 MetaData
, Table
, Column
对象在 Python 代码中定义数据库模式。
核心类与概念
类/概念 描述 MetaData
表和其他模式对象的容器。通常一个数据库一个实例。 Table
代表数据库中的一张表。关联到 MetaData
。 Column
代表表中的一列。包含名称、类型和约束。 SQLAlchemy Types 独立于数据库的类型,如 Integer
, String
, DateTime
, Boolean
等。 Constraints 约束条件,如主键、外键、唯一、非空、检查约束、索引。
1 2 3 4 5 6 7 8 9 10 11 from sqlalchemy import MetaData, Table, Column, Integer, Stringmetadata = MetaData() user_table = Table( "user" , metadata, Column("id" , Integer, primary_key=True ), Column("name" , String(50 ), nullable=False ), Column("email" , String(120 ), unique=True ), )
简单来说,SQLAlchemy 中的 MetaData
对象就是一个数据库结构蓝图的“登记簿” 。
它主要干两件事:
记录信息 :它收集并保存了你在 Python 代码里定义的所有数据库表(Table
对象)、列、索引、外键等这些“蓝图”信息。操作数据库结构 :当你需要根据这些“蓝图”在实际数据库中创建表时,你会用到它,比如调用 metadata.create_all(engine)
就能把所有登记在册的表都建出来Table
构造函数关键参数
参数 类型 描述 name
str
必需 . 表名。metadata
MetaData
必需 . 关联的 MetaData
对象。*columns
Column
, Constraint
, Index
等对象必需 . 定义表的列和表级约束。**kwargs
特定数据库方言的选项 (如 mysql_engine='InnoDB'
)。
Column
构造函数关键参数
参数 类型 描述 name
str
列名 (通常省略,SQLAlchemy 会使用属性名)。 type_
SQLAlchemy Type 必需 . 列的数据类型 (如 Integer
, String(50)
, DateTime
)。primary_key
bool
是否为主键。 nullable
bool
是否允许为空 (默认为 True
)。 default
Any
Python 级别的默认值。 server_default
str
或 SQL Expression数据库级别的默认值 (如 func.now()
或 "0"
)。 unique
bool
是否唯一。 index
bool
是否为此列创建索引。 ForeignKey(...)
ForeignKey
对象定义外键约束,指向 目标表名.目标列名
。 onupdate
Any
或 SQL Expression更新行时自动设置的值 (常用于时间戳)。 server_onupdate
WorkspaceedValue
或 str
等数据库级别的更新触发器。 comment
str
列注释 (部分数据库支持)。
主键、唯一、索引、默认值、非空等示例
1 2 3 4 5 6 7 8 9 10 11 from sqlalchemy import Table, Column, Integer, String, Boolean, DateTime, funcuser_table = Table( "user" , metadata, Column("id" , Integer, primary_key=True ), Column("username" , String(50 ), nullable=False , unique=True , index=True ), Column("email" , String(120 ), nullable=False ), Column("created_at" , DateTime, server_default=func.now()), Column("is_active" , Boolean, default=True ), )
外键约束示例
1 2 3 4 5 6 7 8 9 from sqlalchemy import ForeignKeyaddress_table = Table( "address" , metadata, Column("id" , Integer, primary_key=True ), Column("user_id" , Integer, ForeignKey("user.id" )), Column("address" , String(255 )), )
表级约束与索引示例
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 from sqlalchemy import Table, Column, Integer, String, UniqueConstraint, Indexmetadata = MetaData() product_table = Table( "product" , metadata, Column("id" , Integer, primary_key=True ), Column("name" , String(100 )), Column("sku" , String(50 )), UniqueConstraint("sku" , name="uix_1" ), Index("ix_product_name" , "name" ), )
常用 SQLAlchemy 类型
类型 映射的常见 SQL 类型 (示例) Integer
/ SmallInteger
/ BigInteger
INTEGER
, SMALLINT
, BIGINT
String(length)
VARCHAR(length)
Text
TEXT
, CLOB
Numeric(prec, scale)
NUMERIC(prec, scale)
, DECIMAL
Float(precision)
FLOAT
, REAL
Boolean
BOOLEAN
, SMALLINT
(0/1)Date
DATE
Time
TIME
DateTime
TIMESTAMP
, DATETIME
LargeBinary
BLOB
, BYTEA
JSON
JSON
, JSONB
(需要数据库支持)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 from sqlalchemy import (MetaData, Table, Column, Integer, String, DateTime, Boolean, Numeric, ForeignKey, Index, PrimaryKeyConstraint, UniqueConstraint, CheckConstraint, func, TIMESTAMP, Text) from print_utils import print_header, print_info, print_success, print_errorfrom sqlalchemy import create_engineimport pymysqlDB_CONFIG = { "host" : "localhost" , "port" : 3306 , "user" : "root" , "password" : "root" , "db_name" : "pymysql_demo" } def create_database (): """创建数据库(如果不存在)""" try : conn = pymysql.connect( host=DB_CONFIG["host" ], port=DB_CONFIG["port" ], user=DB_CONFIG["user" ], password=DB_CONFIG["password" ] ) cursor = conn.cursor() cursor.execute(f"SHOW DATABASES LIKE '{DB_CONFIG['db_name' ]} '" ) result = cursor.fetchone() if not result: print_info(f"数据库 '{DB_CONFIG['db_name' ]} ' 不存在,正在创建..." ) cursor.execute(f"CREATE DATABASE {DB_CONFIG['db_name' ]} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci" ) print_success(f"数据库 '{DB_CONFIG['db_name' ]} ' 创建成功" ) else : print_info(f"数据库 '{DB_CONFIG['db_name' ]} ' 已存在" ) conn.commit() cursor.close() conn.close() except Exception as e: print_error(f"创建数据库时出错: {str (e)} " ) raise e def create_table (): print_header("定义 SQLAlchemy Core 表结构示例 (MySQL)" ) metadata_obj = MetaData() print_info("Meta 对象已创建" ) categories_table = Table( "categories" , metadata_obj, Column("category_id" , Integer, primary_key=True , autoincrement=True , comment="类别ID" ), Column("name" , String(50 ), nullable=False , unique=True , comment="类别名称" ), mysql_engine="InnoDB" , mysql_charset="utf8mb4" , mysql_collate="utf8mb4_unicode_ci" , mysql_row_format="DYNAMIC" ) products_table = Table( "products" , metadata_obj, Column("id" , Integer, primary_key=True , autoincrement=True , comment="产品ID" ), Column("name" , String(100 ), nullable=False , comment="产品名称" ), Column("price" , Numeric(10 , 2 ), nullable=False , comment="价格" ), Column("category_id" , Integer, ForeignKey("categories.category_id" ), nullable=False , comment="类别ID" ), Column("created_at" , TIMESTAMP, server_default=func.now(), comment="创建时间" ), Column("updated_at" , TIMESTAMP, server_default=func.now(), onupdate=func.now(), comment="更新时间" ), Column("is_active" , Boolean, default=True , comment="是否激活" ), mysql_engine="InnoDB" , mysql_charset="utf8mb4" , mysql_collate="utf8mb4_unicode_ci" , mysql_row_format="DYNAMIC" ) return metadata_obj, categories_table, products_table def create_tables (engine, metadata_obj ): metadata_obj.create_all(bind=engine) print_success("表创建完成" ) def check_table (metadata_obj, products_table ): print_info("\nMetaData 中包含的表:" ) for table_name in metadata_obj.tables: print (f" - {table_name} " ) print_info("\n'products' 表的列信息:" ) for column in products_table.columns: print (f" - 列名: {column.name} , 类型: {column.type } , 主键: {column.primary_key} , 外键: {column.foreign_keys} " ) if __name__ == '__main__' : create_database() engine = create_engine(f"mysql+pymysql://{DB_CONFIG['user' ]} :{DB_CONFIG['password' ]} @{DB_CONFIG['host' ]} :{DB_CONFIG['port' ]} /{DB_CONFIG['db_name' ]} " , echo=True ) metadata_obj, categories_table, products_table = create_table() create_tables(engine, metadata_obj) check_table(metadata_obj, products_table)
16.3.3. 执行 SQL 表达式语句 定义好表之后,就可以使用 SQLAlchemy Core 构建并执行 INSERT, SELECT, UPDATE, DELETE 语句了。
核心执行流程
获取连接 : 使用 with engine.connect() as connection:
或 with engine.begin() as connection:
(推荐,自动事务管理)。构建语句 : 使用 insert(table)
, select(columns)
, update(table)
, delete(table)
。添加子句 : 链式调用 .values(...)
, .where(...)
, .order_by(...)
, .limit(...)
, .offset(...)
, .join(...)
等方法。执行 : result = connection.execute(statement, [parameters])
。处理结果 (对于 SELECT)。事务管理 : 如果使用 engine.connect()
, DML 操作后需 connection.commit()
或 connection.rollback()
。如果使用 engine.begin()
, 事务自动处理。常用语句构建函数/方法
函数/方法 用途 示例 insert(table)
构建 INSERT 语句。 stmt = insert(products_table).values(name="P1", price=10)
select(...)
构建 SELECT 语句。参数是列对象或表对象。 stmt = select(products_table.c.name, products_table.c.price)
update(table)
构建 UPDATE 语句。 stmt = update(products_table).where(products_table.c.id == 1)
delete(table)
构建 DELETE 语句。 stmt = delete(products_table).where(products_table.c.stock == 0)
.values(...)
(INSERT/UPDATE) 指定要插入或更新的列和值。 .values(name="P2", price=20)
.where(...)
(SELECT/UPDATE/DELETE) 指定过滤条件。 .where(products_table.c.price > 100)
.order_by(...)
(SELECT) 指定排序方式 (.asc()
, .desc()
)。 .order_by(products_table.c.price.desc())
.limit(n)
(SELECT) 限制返回结果数量。 .limit(10)
.offset(n)
(SELECT) 跳过指定数量的结果 (用于分页)。 .offset(20)
.join(...)
(SELECT) 连接其他表。参数通常是目标表和连接条件。 .join(categories_table, products_table.c.category_id == ...)
.label(name)
给列或表达式设置别名。 select(products_table.c.name.label("product_name"))
text(sql)
用于执行原生 SQL 字符串 (配合参数绑定使用)。 stmt = text("SELECT * FROM products WHERE id = :id")
常用条件操作符/函数
操作符/函数 用途 示例 ==
, !=
, <
, >
基本比较。 table.c.price == 100
&
/ and_(...)
逻辑与 (AND)。 (table.c.price > 10) & (table.c.stock > 0)
` /
or_(…)`逻辑或 (OR)。 ~
/ not_(...)
逻辑非 (NOT)。 ~table.c.name.like('A%')
.like(pattern)
SQL LIKE 操作 ( %
通配符)。 table.c.name.like('%Core%')
.ilike(pattern)
不区分大小写的 LIKE (某些数据库)。 table.c.name.ilike('%core%')
.in_([...])
SQL IN 操作。 table.c.category_id.in_([1, 3, 5])
.is_(None)
SQL IS NULL。 table.c.added_date.is_(None)
.isnot(None)
SQL IS NOT NULL。 table.c.description.isnot(None)
.between(a, b)
SQL BETWEEN 操作。 table.c.price.between(100, 500)
func.<name>
SQL 函数 (如 func.now()
, func.count()
) select(func.count(products_table.c.id))
结果处理 (ResultProxy
/ CursorResult
)
connection.execute()
返回一个结果对象,用于获取数据。
方法 描述 Workspaceall()
获取所有行,返回 Row
对象列表。 Workspaceone()
获取下一行,返回 Row
对象或 None
。 Workspacemany(size=None)
获取指定数量的行,返回 Row
对象列表。 scalar()
获取结果的第一行第一列的值,无结果或多列时行为可能变化或报错。 scalars()
(2.0+)返回 ScalarResult
,迭代产生每行第一列的值。 first()
获取第一行 (Row
对象),无结果返回 None
。常用。 one()
获取唯一一行,无结果或多于一行时报错。 one_or_none()
获取唯一一行,无结果返回 None
,多于一行时报错。 mappings()
(2.0+)返回 MappingResult
,迭代产生字典形式的行。 (直接迭代结果对象) 每次迭代返回一个 Row
对象。 rowcount
(属性)DML 操作影响的行数。 inserted_primary_key
(属性)INSERT 操作的自增主键 (元组)。可能只包含第一个插入行的主键。 keys()
(属性)返回结果的列名列表。
Row
对象访问 :
按索引: row[0]
按列名: row['column_name']
或 row.column_name
按列对象: row[table.c.column_name]
转字典: dict(row._mapping)
(SQLAlchemy 2.0+) 或 dict(row)
(旧版) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 from sqlalchemy import (MetaData, Table, Column, Integer, String, DateTime, Boolean, Numeric, ForeignKey, Index, PrimaryKeyConstraint, UniqueConstraint, CheckConstraint, func, TIMESTAMP, Text) from print_utils import print_header, print_info, print_success, print_error, print_subheader, print_sql,print_warningfrom sqlalchemy import create_engineDB_CONFIG = { 'DB_URI' : 'mysql+pymysql://root:root@localhost:3306/pymysql_demo' } engine = create_engine(DB_CONFIG['DB_URI' ]) products_table: Table = Table('products' , MetaData(), Column('id' , Integer, primary_key=True ), Column('name' , String(50 ), nullable=False ), Column('price' , Numeric(10 , 2 ), nullable=False ), Column('category_id' , Integer, ForeignKey('categories.category_id' ), nullable=False ), Column('is_active' , Boolean, default=True ) ) categories_table = Table('categories' , MetaData(), Column('category_id' , Integer, primary_key=True ), Column('name' , String(50 ), nullable=False ) )
插入单行记录 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 def insert_category_data (): """ 插入类别表数据 """ print_header("SQLAlchemy 类别表 插入示例" ) print_subheader("1. 插入单个产品记录" ) stmt_insert_single = insert(categories_table).values( name="书籍类" ) print_sql(str (stmt_insert_single)) inserted_pk = None try : with engine.begin() as connection: result = connection.execute(stmt_insert_single) if result.inserted_primary_key: inserted_pk = result.inserted_primary_key[0 ] print_success(f"成功插入类别 '书籍类' (ID: {inserted_pk} )" ) elif result.rowcount == 1 : print_success("成功插入类别 '书籍类' (无法获取主键)" ) else : print_warning("插入类别 '书籍类' 时 rowcount 不为 1" ) except Exception as e: print_error(f"插入类别 '书籍类' 时出错: {e} " ) def insert_product_data (): """ 插入产品表数据 """ print_header("SQLAlchemy 产品表 插入示例" ) with engine.connect() as conn: result = conn.execute(categories_table.select().where(categories_table.c.name == "书籍类" )) category = result.fetchone() if not category: print_error("未找到'书籍类'类别" ) return category_id = category[0 ] stmt_insert_single = insert(products_table).values( name="SQLAlchemy书籍" , price=123.45 , category_id=category_id, is_active=True ) print_sql(str (stmt_insert_single)) inserted_pk = None try : with engine.begin() as connection: result = connection.execute(stmt_insert_single) if result.inserted_primary_key: inserted_pk = result.inserted_primary_key[0 ] print_success(f"成功插入产品 'SQLAlchemy书籍' (ID: {inserted_pk} )" ) elif result.rowcount == 1 : print_success("成功插入产品 'SQLAlchemy书籍' (无法获取主键)" ) else : print_warning("插入产品 'SQLAlchemy书籍' 时 rowcount 不为 1" ) except Exception as e: print_error(f"插入单个产品时出错: {e} " ) if __name__ == '__main__' : insert_product_data()
批量插入记录 (INSERT Multiple) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 def insert_many_products (): print_header("SQLAlchemy 批量插入示例" ) print_subheader("1. 批量插入多个产品记录" ) stmt_insert_many = insert(products_table).values([ {'name' : 'Python编程' , 'price' : 123.45 , 'category_id' : 1 }, {'name' : 'SQLAlchemy教程' , 'price' : 99.99 , 'category_id' : 1 }, {'name' : '大话西游之月光宝盒' , 'price' : 150.00 , 'category_id' : 1 } ]) print_sql(str (stmt_insert_many)) try : with engine.begin() as connection: result = connection.execute(stmt_insert_many) print_success(f"成功插入 {result.rowcount} 条记录" ) except Exception as e: print_error(f"插入失败: {e} " )
更新数据 (UPDATE) 1 2 3 4 5 6 7 8 9 10 11 12 13 def update_product (): print_header("SQLAlchemy 更新示例" ) stmt_update = update(products_table).where(products_table.c.id == 1 ).values( price = 999999 ) print_sql(str (stmt_update)) try : with engine.begin() as connection: result = connection.execute(stmt_update) print_success(f"成功更新 {result.rowcount} 条记录" ) except Exception as e: print_error(f"更新失败: {e} " )
批量更新记录 (UPDATE Multiple) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def update_products (): print_header("SQLAlchemy 批量更新示例" ) category_id_to_update = 1 price_increase_factor = 1.10 stmt_update = update(products_table).where( products_table.c.category_id == category_id_to_update ).values( price = products_table.c.price * price_increase_factor ) print_sql(str (stmt_update)) try : with engine.begin() as connection: result = connection.execute(stmt_update) print_success(f"成功更新 {result.rowcount} 条记录" ) except Exception as e: print_error(f"更新失败: {e} " )
删除数据 (DELETE) 1 2 3 4 5 6 7 8 9 10 11 def delete_product (): print_header("SQLAlchemy 删除示例" ) product_id_to_delete = 1 stmt_delete = delete(products_table).where(products_table.c.id == product_id_to_delete) print_sql(str (stmt_delete)) try : with engine.begin() as connection: result = connection.execute(stmt_delete) print_success(f"成功删除 {result.rowcount} 条记录" ) except Exception as e: print_error(f"删除失败: {e} " )
批量删除记录 (DELETE Multiple) 1 2 3 4 5 6 7 8 9 10 11 12 def delete_products (): print_header("SQLAlchemy 批量删除示例" ) stmt_delete = delete(products_table).where(products_table.c.is_active == False ) print_sql(str (stmt_delete)) try : with engine.begin() as connection: result = connection.execute(stmt_delete) print_success(f"成功删除 {result.rowcount} 条记录" ) except Exception as e: print_error(f"删除失败: {e} " )
在进入最重要的查询的基础前,我们可以看到代码有很多是重复的,每一次都要进行begin,捕获…这个繁杂的过程会导致代码冗余,我们可以采用AOP的思想,去实现一个事务的装饰器,如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 from functools import wrapsdef transactional (func=None , *, engine_obj=None ): """ 使用AOP思想实现的事务装饰器 可以直接应用于函数上,无需手动管理事务 用法: @transactional # 使用默认引擎 def my_function(): # 执行SQL操作,自动事务管理 @transactional(engine_obj=custom_engine) # 使用自定义引擎 def another_function(): # 执行SQL操作,自动事务管理 """ _engine = engine_obj or engine def decorator (fn ): @wraps(fn ) def wrapper (*args, **kwargs ): if 'connection' in kwargs: return fn(*args, **kwargs) try : with _engine.begin() as connection: kwargs['connection' ] = connection result = fn(*args, **kwargs) return result except Exception as e: print_error(f"事务执行失败: {e} " ) raise return wrapper if func is None : return decorator return decorator(func)
通过这个装饰器,我们就可以实现代码的优雅性
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 @transactional def insert_many_products_aop (connection=None ): """使用AOP事务装饰器实现的批量插入""" print_header("SQLAlchemy 批量插入示例 (使用事务装饰器)" ) stmt_insert_many = insert(products_table).values([ {'name' : 'Python编程AOP版' , 'price' : 123.45 , 'category_id' : 1 }, {'name' : 'SQLAlchemy教程AOP版' , 'price' : 99.99 , 'category_id' : 1 }, {'name' : '大话西游之月光宝盒AOP版' , 'price' : 150.00 , 'category_id' : 1 } ]) print_sql(str (stmt_insert_many)) result = connection.execute(stmt_insert_many) print_success(f"成功插入 {result.rowcount} 条记录" )
查询数据 (SELECT - 基础) 1 2 3 4 5 6 7 8 9 @transactional def select_all_products (connection=None ): print_info("查询所有产品的 '*' (前 5 条):" ) stmt_select_all = select(products_table).limit(5 ) print_sql(str (stmt_select_all)) result = connection.execute(stmt_select_all) for row in result.fetchmany(5 ): print (row)
查询数据 (SELECT - 条件过滤 WHERE) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 from sqlalchemy import and_, or_, not_ from sqlalchemy import * from print_utils import print_header, print_info, print_success, print_error, print_subheader, print_sql, print_warning@transactional def select_all_products (connection=None ): print_info("查询所有产品的 '*' (前 5 条):" ) stmt_select_all = select(products_table).limit(5 ) print_sql(str (stmt_select_all)) result = connection.execute(stmt_select_all) for row in result.fetchmany(5 ): print (row) @transactional def select_by_category_id (category_id_to_find, connection=None ): print_info(f"\n查询类别 ID 为 {category_id_to_find} 的产品:" ) stmt_eq = select(products_table.c.id , products_table.c.name, products_table.c.price).where( products_table.c.category_id == category_id_to_find ) print_sql(str (stmt_eq)) result = connection.execute(stmt_eq) for row in result: print (row) @transactional def select_by_price_gt (max_price, connection=None ): print_info(f"\n查询价格大于 {max_price} 的产品名称和价格 (降序):" ) stmt_gt = select(products_table.c.name, products_table.c.price).where( products_table.c.price > max_price ).order_by(products_table.c.price.desc()) print_sql(str (stmt_gt)) result = connection.execute(stmt_gt) for row in result: print (row) @transactional def select_by_price_between (min_p, max_p, connection=None ): print_info(f"\n查询价格在 {min_p} 到 {max_p} 之间的产品名称和价格 (升序):" ) stmt_between = select(products_table.c.name, products_table.c.price).where( products_table.c.price.between(min_p, max_p) ).order_by(products_table.c.price) print_sql(str (stmt_between)) result = connection.execute(stmt_between) for row in result: print (row) @transactional def select_by_category_ids (target_cat_ids, connection=None ): print_info(f"\n查询类别 ID 在 {target_cat_ids} 中的产品:" ) stmt_in = select(products_table.c.id , products_table.c.name, products_table.c.category_id).where( products_table.c.category_id.in_(target_cat_ids) ) print_sql(str (stmt_in)) result = connection.execute(stmt_in) for row in result: print (row) @transactional def select_by_name_like (search_term, connection=None ): print_info(f"\n查询名称包含 '{search_term} ' 的产品:" ) stmt_like = select(products_table.c.name, products_table.c.price).where( products_table.c.name.like(f"%{search_term} %" ) ) print_sql(str (stmt_like)) result = connection.execute(stmt_like) for row in result: print (row) @transactional def alter_add_stock_column (connection=None ): print_info("增加 stock 字段到 products 表:" ) stmt_alter = text("ALTER TABLE products ADD COLUMN stock INTEGER NOT NULL" ) print_sql(str (stmt_alter)) connection.execute(stmt_alter) @transactional def select_by_complex_condition (cat_id, min_stock, connection=None ): print_info(f"\n查询类别 ID 为 {cat_id} 且 (价格小于 50 或 库存大于 {min_stock} ) 的产品:" ) stmt_complex = select(products_table).where( and_( products_table.c.category_id == cat_id, or_( products_table.c.price < 50.0 , products_table.c.stock > min_stock ) ) ) print_sql(str (stmt_complex)) result = connection.execute(stmt_complex) for row in result: print (row) if __name__ == '__main__' : select_all_products() select_by_category_id(1 ) select_by_price_gt(100.0 ) select_by_price_between(100.0 , 500.0 ) select_by_category_ids([1 , 3 ]) select_by_name_like("Core" ) select_by_complex_condition(1 , 10 )
查询数据 (SELECT - 排序 ORDER BY) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 print_subheader("7. 查询数据 (排序 ORDER BY)" ) @transactional def select_products_order_by_price_desc (connection=None ): print_info("\n按价格降序排列产品 (前 5 条):" ) stmt_order_desc = select( products_table.c.name, products_table.c.price ).order_by( products_table.c.price.desc() ).limit(5 ) print_sql(str (stmt_order_desc)) result = connection.execute(stmt_order_desc) for row in result: print (row) @transactional def select_products_order_by_multi_columns (connection=None ): print_info("\n按类别 ID 升序、价格降序排列产品 (前 5 条):" ) stmt_order_multi = select( products_table.c.name, products_table.c.category_id, products_table.c.price ).order_by( products_table.c.category_id.asc(), products_table.c.price.desc() ).limit(5 ) print_sql(str (stmt_order_multi)) result = connection.execute(stmt_order_multi) for row in result: print (row)
查询数据 (SELECT - 去重 DISTINCT) 1 2 3 4 5 6 7 8 @transactional def select_distinct_products (): print_info("\n查询所有类别 ID 并去重:" ) stmt_distinct = select(products_table.c.category_id,products_table.c.name).distinct() print_sql(str (stmt_distinct)) result = connection.execute(stmt_distinct) for row in result: print (row)
查询数据 (SELECT - 聚合函数) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 @transactional def select_aggregation_products (connection=None ): print_info("\n计算产品统计信息 (总数, 总库存, 平均价, 最高价, 最低价):" ) stmt_aggregation = select( func.count(products_table.c.id ).label("total_products" ), func.sum (products_table.c.stock).label("total_stock" ), func.avg(products_table.c.price).label("average_price" ), func.max (products_table.c.price).label("max_price" ), func.min (products_table.c.price).label("min_price" ) ) print_sql(str (stmt_aggregation)) result = connection.execute(stmt_aggregation).first() if result: print_result_item(dict (result._mapping)) else : print_info("未能获取产品统计信息。" )
查询数据 (SELECT - 分组 GROUP BY) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 @transactional def select_group_by_category (connection=None ): print_info("\n按类别分组计算产品统计信息:" ) stmt_group_by = select( products_table.c.category_id, func.count(products_table.c.id ).label("product_count" ), func.sum (products_table.c.price).label("price" ), ).group_by(products_table.c.category_id).order_by(products_table.c.category_id.asc()) print_sql(str (stmt_group_by)) result = connection.execute(stmt_group_by).first() if result: print_result_item(dict (result._mapping)) else : print_info("未能获取产品统计信息。" ) @transactional def select_group_by_category (connection=None ): print_info("\n按类别分组计算产品统计信息 (产品数量大于1的类别):" ) stmt_group_by = select( products_table.c.category_id, func.count(products_table.c.id ).label("product_count" ), func.sum (products_table.c.price).label("price" ), ).group_by(products_table.c.category_id)\ .having( and_( func.count(products_table.c.id ) > 1 , func.sum (products_table.c.price) > 0 ) )\ .order_by(products_table.c.category_id.asc()) print_sql(str (stmt_group_by)) result = connection.execute(stmt_group_by).first() if result: print_result_item(dict (result._mapping)) else : print_info("未能获取产品统计信息。" )
查询数据 (SELECT - 分页 LIMIT/OFFSET) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 @transactional def select_limit_pages (connection=None , page=1 , page_size=3 ): print_info(f"\n分页查询产品数据 (每页 {page_size} 条, 第 {page} 页):" ) offset_val = (page - 1 ) * page_size print_info(f"当前页码: {page} , 每页显示 {page_size} 条, 偏移量: {offset_val} " ) count_stmt = select(func.count()).select_from(products_table) total_count = connection.execute(count_stmt).scalar() total_pages = (total_count + page_size - 1 ) // page_size stmt_paging = select(products_table.c.name, products_table.c.price)\ .offset(offset_val)\ .limit(page_size) print_sql(str (stmt_paging)) result = connection.execute(stmt_paging) print_info(f"总记录数: {total_count} , 总页数: {total_pages} " ) print_info("当前页数据:" ) for row in result: print_result_item(dict (row._mapping)) if __name__ == '__main__' : for page in range (1 ,10 ): select_limit_pages(page=page) print_info("-" * 50 )
查询数据 (SELECT - 连接查询 JOIN) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 @transactional def select_products_with_join (connection=None ): print_info("\n内连接查询: 获取产品及其对应的类别名称 (只显示有类别的产品):" ) stmt_inner_join = select( products_table.c.name, products_table.c.price, categories_table.c.name.label("category_name" ) ).join( categories_table, products_table.c.category_id == categories_table.c.category_id ).where( products_table.c.category_id == 1 ) print_sql(str (stmt_inner_join)) result = connection.execute(stmt_inner_join) for row in result: print_result_item(dict (row._mapping))
查询数据 (SELECT - CASE 表达式) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 @transactional def select_products_with_level (connection=None ): print_info("\n使用 CASE 表达式根据价格给产品分类:" ) stmt_case = select( products_table.c.name, products_table.c.price, case ( (products_table.c.price > 500 , "白金级" ), (products_table.c.price > 200 , "进阶级" ), (products_table.c.price > 100 , "标准级" ), (products_table.c.price > 10 , "入门级" ), else_="未知" ).label("price_level" ) ) print_sql(str (stmt_case)) result = connection.execute(stmt_case) for row in result: print_result_item(dict (row._mapping))
查询数据 (SELECT - 子查询 Subquery) 子查询方法快速参考
为了方便快速选择合适的子查询构建方式,这里根据核心使用场景进行了归纳:
方法/概念 核心场景 (一句话概括) 关键点/提示 .scalar_subquery()
WHERE 中与单个 预期值比较 子查询应返回单行单列 。用于 >、=、<
等比较。 .subquery()
FROM 或 JOIN 中像表一样使用 子查询结果 像操作普通 Table
一样操作它;用 .c
访问列;聚合列需用 .label()
命名。 (直接用 select()
) WHERE 中与列表 比较 (如 IN
) 子查询应返回单列多行(如id) ;常与 .in_()
配合。 .exists()
WHERE 中检查是否存在 满足条件的关联行 只关心“有没有”,不关心“是什么”或“有多少”。 .alias()
需要在同一查询 中区分同一个表 的多次引用时 相关子查询 和自连接 的必备工具。.correlate()
(较少用) 显式声明 子查询依赖的外部表 主要用于非 WHERE
/FROM
子句的子查询,或自动关联失效时。 .lateral()
(需数据库支持) 子查询引用同 FROM 中之前 的表 用于行级计算或复杂的 Top-N 查询。
选择思路小结 :
需要子查询返回一个值 来比较? -> .scalar_subquery()
需要把子查询的结果当作一张表 来连接 (JOIN) 或从中选择 (FROM)? -> .subquery()
需要判断某个东西是否在子查询返回的列表 里? -> 直接用 select()
配合 .in_()
只需要判断有没有符合条件的关联数据 ,不在乎具体内容? -> .exists()
查询中涉及自己和自己比 (相关子查询、自连接)? -> .alias()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 @transactional def select_products_with_subquery (connection=None ): """ --- 子查询用在 WHERE 子句中 ---""" print_info("\n子查询: 查找价格高于平均价格的产品:" ) subquery_avg_price = select(func.avg(products_table.c.price)).scalar_subquery() stmt_subquery_where = select(products_table.c.name, products_table.c.price).where( products_table.c.price > subquery_avg_price ) print_sql(str (stmt_subquery_where)) result = connection.execute(stmt_subquery_where) for row in result: print_result_item(dict (row._mapping)) @transactional def select_products_with_subquery_in_select (connection=None ): """ --- 子查询用在 SELECT 子句中 ---""" print_info("\n子查询: 查找每个类别中价格最高的产品 (使用派生表):" ) subquery_max_price = select( products_table.c.category_id, func.max (products_table.c.price).label("max_price" ) ).group_by(products_table.c.category_id).subquery() stmt_subquery_from = select( products_table.c.name, products_table.c.price, products_table.c.category_id, ).join( subquery_max_price, and_( products_table.c.category_id == subquery_max_price.c.category_id, products_table.c.price == subquery_max_price.c.max_price ) ).order_by(products_table.c.category_id) print_sql(str (stmt_subquery_from)) result = connection.execute(stmt_subquery_from) for row in result: print_result_item(dict (row._mapping)) @transactional def select_products_with_in_subquery (connection=None , min_price=100 , max_price=200 ): """ --- 子查询用在 IN 子句中 ---""" print_info(f"\n子查询: 查找价格在 {min_price} 到 {max_price} 之间的产品,以及这些产品所属的类别:" ) subquery_product_ids = select(products_table.c.id ).where( products_table.c.price.between(min_price, max_price) ).scalar_subquery() stmt = select( products_table.c.name, products_table.c.price, categories_table.c.name.label("category_name" ) ).join( categories_table, products_table.c.category_id == categories_table.c.category_id ).where( products_table.c.id .in_(subquery_product_ids) ) print_sql(str (stmt)) result = connection.execute(stmt) for row in result: print_result_item(dict (row._mapping)) @transactional def select_products_with_alias_self_comparison (connection=None ): """ --- 子查询用于自连接比较 ---""" print_info("\n子查询: 查找比同类别平均价格高的产品:" ) avg_price_by_category = select( products_table.c.category_id, func.avg(products_table.c.price).label("avg_price" ) ).group_by(products_table.c.category_id).subquery() stmt = select( products_table.c.name, products_table.c.category_id, products_table.c.price, avg_price_by_category.c.avg_price ).join( avg_price_by_category, products_table.c.category_id == avg_price_by_category.c.category_id, ).where( products_table.c.price > avg_price_by_category.c.avg_price ) print_sql(str (stmt)) result = connection.execute(stmt) for row in result: print_result_item(dict (row._mapping))
SQLAlchemy 数据库函数 (func) SQLAlchemy 通过 sqlalchemy.func
这个特殊对象,提供了一种与数据库无关 的方式来调用 SQL 内置函数。这意味着你可以用统一的 Python 语法(如 func.count()
, func.now()
, func.lower()
)来生成对应数据库(如 MySQL, PostgreSQL, SQLite)的特定函数调用(如 COUNT()
, NOW()
, LOWER()
)。SQLAlchemy 的方言 (Dialect) 会负责进行正确的语法转换。
常用 func
函数调用参考表
下表列出了一些常用的 SQL 函数及其通过 func
调用的方式:
func
调用示例对应 SQL 函数 (常见) 用途说明 分类 func.count(col)
/ func.count()
COUNT(col)
/ COUNT(*)
计数 (指定列或所有行) 聚合 func.sum(col)
SUM(col)
求和 聚合 func.avg(col)
AVG(col)
平均值 聚合 func.max(col)
MAX(col)
最大值 聚合 func.min(col)
MIN(col)
最小值 聚合 func.count(distinct(col))
COUNT(DISTINCT col)
计算非重复值的数量 聚合 func.now()
NOW()
, CURRENT_TIMESTAMP
获取当前日期时间 (时间戳) 日期/时间 func.current_date()
CURRENT_DATE
获取当前日期 日期/时间 func.current_time()
CURRENT_TIME
获取当前时间 日期/时间 func.extract(field, date_col)
EXTRACT(field FROM date_col)
提取日期/时间部分 (‘year’, ‘month’) 日期/时间 func.lower(str_col)
LOWER(str_col)
字符串转小写 字符串 func.upper(str_col)
UPPER(str_col)
字符串转大写 字符串 func.length(str_col)
LENGTH(str_col)
, LEN()
获取字符串长度 字符串 func.concat(*args)
CONCAT(arg1, arg2, ...)
字符串拼接 字符串 func.substring(str, start, len)
SUBSTRING(str, start, len)
, SUBSTR()
提取子字符串 字符串 func.abs(num_col)
ABS(num_col)
绝对值 数学 func.round(num_col, digits)
ROUND(num_col, digits)
四舍五入 数学 func.random()
RAND()
, RANDOM()
生成随机数 (具体行为看数据库) 其他
注意: 某些函数(特别是日期/时间函数)的具体名称和行为可能因数据库方言而异,但通过 func
调用通常能提供较好的兼容性。
代码示例 (完整版)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 def demonstrate_sqlalchemy_functions (): """ 展示 SQLAlchemy 中各种数据库函数的使用方法 包括聚合函数、字符串函数、日期时间函数、数学函数和条件函数 """ print_subheader("16. 使用 SQLAlchemy func 调用数据库函数 - 完整示例" ) demonstrate_conditional_functions() @transactional def demonstrate_aggregate_functions (connection=None ): """展示 SQLAlchemy 聚合函数的使用""" print_info("\n聚合函数示例:产品统计" ) stmt_agg_complete = select( func.count(products_table.c.id ).label("total_products" ), func.count(distinct(products_table.c.category_id)).label("distinct_categories" ), func.sum (products_table.c.stock).label("total_stock" ), func.avg(products_table.c.price).label("average_price" ), func.max (products_table.c.price).label("max_price" ), func.min (products_table.c.price).label("min_price" ) ) print_sql(str (stmt_agg_complete)) try : result = connection.execute(stmt_agg_complete).first() if result: print_info("聚合统计结果:" ) print_result_item(dict (result._mapping)) else : print_info("未能获取产品统计信息。" ) except Exception as e: print_error(f"执行聚合查询时出错: {e} " ) @transactional def demonstrate_datetime_functions (connection=None ): """展示 SQLAlchemy 日期/时间函数的使用""" print_info("\n日期/时间函数示例:数据库当前时间和日期部分提取" ) stmt_func_datetime_complete = select( func.now().label("now" ), func.current_date().label("current_date" ), func.current_time().label("current_time" ), func.current_timestamp().label("current_timestamp" ), func.sysdate().label("sysdate" ), func.localtime().label("localtime" ), func.localtimestamp().label("localtimestamp" ), func.extract('year' , func.current_timestamp()).label("year" ), func.extract('month' , func.current_timestamp()).label("month" ), func.extract('day' , func.current_timestamp()).label("day" ), func.extract('hour' , func.current_timestamp()).label("hour" ), func.extract('minute' , func.current_timestamp()).label("minute" ), func.extract('second' , func.current_timestamp()).label("second" ), ) print_sql(str (stmt_func_datetime_complete)) result = connection.execute(stmt_func_datetime_complete) print_info("\n日期时间函数结果:" ) print_info("-" * 80 ) print_info(f"{'函数名' :<20 } {'值' :<30 } " ) print_info("-" * 80 ) for row in result: for key, value in row._mapping.items(): print_info(f"{key:<20 } {str (value):<30 } " ) print_info("-" * 80 ) @transactional def demonstrate_math_functions (connection=None ): """展示 SQLAlchemy 数学函数的使用""" print_info("\n数学函数示例:数学运算" ) stmt_func_math = select( func.abs (-10 ).label("绝对值" ), func.ceil(10.5 ).label("向上取整" ), func.floor(10.5 ).label("向下取整" ), func.sign(-10 ).label("符号" ), func.sqrt(16 ).label("平方根" ), func.power(2 , 3 ).label("幂运算" ), func.mod(10 , 3 ).label("取模" ), func.round (10.5 ).label("四舍五入" ), ) print_sql(str (stmt_func_math)) try : result = connection.execute(stmt_func_math) print_info("\n数学函数结果:" ) print_info("-" * 80 ) print_info(f"{'函数名' :<20 } {'值' :<30 } " ) print_info("-" * 80 ) for row in result: for key, value in row._mapping.items(): print_info(f"{key:<20 } {str (value):<30 } " ) print_info("-" * 80 ) except Exception as e: print_error(f"执行数学函数查询时出错: {e} " ) if __name__ == '__main__' : demonstrate_sqlalchemy_functions()
16.3.4 SQLAlchemy ORM: 对象关系映射 ORM (Object Relational Mapper) 将数据库表映射为 Python 类,允许通过对象进行数据库操作。它是 SQLAlchemy 的核心功能,构建于 Core 之上。
核心概念: 类 (Model) <-> 表 (Table), 对象 (Instance) <-> 行 (Row), 属性 (Attribute) <-> 列 (Column)。
建议项目结构 (用于 ORM 示例):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 sqlalchemy_orm_practice/ ├── __init__.py ├── core_config.py ├── models/ │ ├── __init__.py │ ├── base.py │ └── orm_models.py ├── crud/ │ ├── __init__.py │ └── orm_crud_ops.py ├── utils/ │ ├── __init__.py │ └── print_utils.py └── main_orm_runner.py
16.3.3.1. Session: ORM 的数据库交互句柄 Session
是与数据库交互的接口,管理对象状态 (Unit of Work) 和事务。
文件 : sqlalchemy_orm_practice/core_config.py
知识点: 创建 : sessionmaker
(工厂) -> SessionLocal
(配置好的 Session 类) -> session = SessionLocal()
(实例)。配置 (sessionmaker
) : bind=engine
(必需), autocommit=False
(默认), autoflush=False
(推荐), expire_on_commit=True
(默认)。生命周期/作用域 : 短期存在,非线程安全。Web 应用模式:每请求一 Session 。严禁全局共享 Session 实例 。上下文管理 (with
) : with SessionLocal() as session:
自动管理 session.close()
。核心 Session 方法 (快速参考): 方法 用途说明 session.add(obj)
添加新对象实例 (标记为待 INSERT)。 session.add_all(list)
添加多个新对象实例。 session.delete(obj)
标记持久化对象为待 DELETE。 session.commit()
Flush 挂起更改并提交事务。 session.rollback()
回滚事务,撤销更改。 session.flush()
将更改同步到 DB (不提交事务),获取自增 ID 等。 session.get(Model, pk)
通过主键高效获取单个对象。 session.execute(stmt)
(2.0+) 执行 Core 语句 (ORM 查询主要方式)。
代码示例: Engine 和 Session 设置 (使用 PyMySQL)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 """ SQLAlchemy ORM 核心配置模块 此模块包含数据库连接引擎和会话的配置。 提供了创建数据库引擎和会话工厂的功能,用于整个应用程序的数据库交互。 主要组件: - engine: 数据库连接引擎 - SessionLocal: 本地会话工厂,用于创建数据库会话 """ from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from utils.print_utils import print_info, print_success, print_error DB_USER = "root" DB_PASSWORD = "root" DB_HOST = "localhost" DB_PORT = 3306 DB_NAME = "sqlalchemy_orm_db" DB_CHARSET = "utf8mb4" ORM_DATABASE_URL = f"mysql+pymysql://{DB_USER} :{DB_PASSWORD} @{DB_HOST} :{DB_PORT} /{DB_NAME} ?charset={DB_CHARSET} " print_info(f"ORM 数据库连接 URL: {ORM_DATABASE_URL.replace(DB_PASSWORD, '******' )} " ) orm_engine = create_engine( ORM_DATABASE_URL, echo=True , future=True , pool_size=10 , max_overflow=20 , pool_recycle=3600 , ) print_success("ORM 数据库引擎创建成功" ) SessionLocal = sessionmaker( autocommit=False , autoflush=False , bind=orm_engine, expire_on_commit=True ) print_success("ORM 会话工厂创建成功" )
16.3.3.2. 声明式映射 (Declarative Mapping): 定义模型 使用带类型注解的 Python 类映射数据库表 (SQLAlchemy 2.0 风格)。
代码示例: Base 和 Models 定义
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 """ SQLAlchemy ORM 基础模型模块 此模块定义了应用程序中所有 ORM 模型的基类和 Mixin 类。 主要组件: - DeclarativeBase: 所有 ORM 模型的基类 - 各种可复用的 Mixin 类,如 TimestampMixin(提供时间戳功能) """ from sqlalchemy import Integer, DateTime, Boolean, TIMESTAMP, func, MetaDatafrom sqlalchemy.orm import DeclarativeBase, Mapped, mapped_columnfrom datetime import datetimefrom typing import Optional from utils.print_utils import print_successclass Base (DeclarativeBase ): """所有ORM 模型的基类""" pass class TimestampMixin : """提供时间戳功能的 Mixin 类""" create_time: Mapped[Optional [datetime]] = mapped_column( TIMESTAMP, server_default=func.now(), comment="创建时间" ) update_time: Mapped[Optional [datetime]] = mapped_column( TIMESTAMP, server_default=func.now(), onupdate=func.now(), comment="更新时间" ) class AbstractBaseModel (Base ): """包含 通用 ID 和 逻辑删除标记的抽象基类""" __abstract__ = True id : Mapped[int ] = mapped_column( Integer, primary_key=True , autoincrement=True , comment="主键ID" ) is_deleted: Mapped[bool ] = mapped_column(Boolean, server_default="0" , index=True , comment="逻辑删除标记" ) print_success("ORM Base, Mixin, AbstractBaseModel 已定义 (models/base.py)。" )
16.3.3.3. 定义关系 (relationship) 在模型间建立关联,映射数据库外键。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 """ SQLAlchemy ORM 模型定义模块 此模块包含具体的 ORM 模型类定义,用于映射到数据库表。 主要模型: - CategoryORM: 映射到 categories 表的分类模型 - ProductORM: 映射到 products 表的产品模型,与 CategoryORM 有关联关系 """ from sqlalchemy import String, Integer, Numeric, ForeignKey, Booleanfrom sqlalchemy.orm import Mapped, mapped_column, relationship from typing import Optional , List from .base import AbstractBaseModel, TimestampMixin, Base from utils.print_utils import print_success class CategoryORM (AbstractBaseModel, TimestampMixin): """产品类型 ORM 模型""" __tablename__ = "categories_orm" name: Mapped[str ] = mapped_column( String(100 ), nullable=False , index=True , unique=True , comment="类别名称" ) description: Mapped[Optional [str ]] = mapped_column( String(255 ), nullable=True , comment="类别描述" ) products: Mapped[List ["ProductORM" ]] = relationship( "ProductORM" , back_populates="category" , cascade="all, delete-orphan" , lazy="selectin" ) def __repr__ (self ): """打印对象信息""" return f"<CategoryORM(id={self.id } , name='{self.name} ')>" class ProductORM (AbstractBaseModel, TimestampMixin): """产品 ORM 模型""" __tablename__ = "products_orm" name: Mapped[str ] = mapped_column( String(100 ), index=True , comment="产品名称" , ) price: Mapped[float ] = mapped_column( Numeric(10 , 2 ), nullable=False , comment="产品价格" ) stock: Mapped[int ] = mapped_column( Integer, default=0 , server_default="0" , nullable=False , comment="库存数量" ) is_available: Mapped[bool ] = mapped_column( Boolean, default=True , server_default="1" , nullable=False , comment="是否上架" ) category_id: Mapped[Optional [int ]] = mapped_column( Integer, ForeignKey("categories_orm.id" , ondelete="SET NULL" ), nullable=True , index=True , comment="类别所属ID" ) category: Mapped[Optional ["CategoryORM" ]] = relationship( "CategoryORM" , back_populates="products" , lazy="joined" ) def __repr__ (self ): return f"<ProductORM(id={self.id } , name='{self.name} ', price={self.price} )>" print_success("具体模型 CategoryORM, ProductORM (含关系占位) 已定义 (models/orm_models.py)。" )
16.3.3.4. ORM CRUD 操作 (使用 Session 和 select) (2.0+) 主要通过 session.execute()
结合 select()
, update()
, delete()
语句,或直接操作 Session 管理的对象。
文件 : sqlalchemy_orm_practice/crud/orm_crud_ops.py
代码示例 : 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 """ SQLAlchemy ORM CRUD 操作模块 此模块演示了 SQLAlchemy ORM 的基本 CRUD(创建、读取、更新、删除)操作。 主要功能: - 创建: 创建新的类别和产品记录 - 读取: 按 ID、条件筛选和排序查询记录 - 更新: 更新现有记录的属性 - 删除: 从数据库中删除记录 """ from sqlalchemy import select, update, delete, func from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session from core_config import SessionLocal from models.orm_models import CategoryORM, ProductORM from utils.print_utils import print_header, print_subheader, print_info, print_success, print_error, print_sql, \ print_warning, \ print_result_item from typing import Optional class OrmCrudOps : """封装 ORM CRUD 示例的类""" def __init__ (self, session: Session ): """依赖注入 Session 对象""" self .session = session self .last_category_id: Optional [int ] = None self .last_product_id: Optional [int ] = None def create_objects (self ) -> tuple [Optional [int ], Optional [int ]]: """创建Category和Product对象并插入到数据库中""" print_subheader("1. 创建 ORM 对象 (INSERT)" ) try : with self .session.begin_nested(): category = CategoryORM(name="书籍类" , description="书籍类产品...包含各类图书" ) self .session.add(category) self .session.flush() self .last_category_id = category.id product = ProductORM(name="Python 编程指南" , price=39.99 , category_id=category.id ) self .session.add(product) self .session.flush() self .last_product_id = product.id print_success("创建 ORM 对象成功。" ) return self .last_category_id, self .last_product_id except SQLAlchemyError as e: print_error(f"创建 ORM 对象失败: {e} " ) return None , None def query_object (self ): print_subheader("2. 查询 ORM 对象 (SELECT)" ) try : print_info(f"\n使用 session.get() 查询 Category ID={self.last_category_id} :" ) category = self .session.get(CategoryORM, self .last_category_id) if category: print_success(f"查询结果: {category} " ) else : print_warning(f"查询结果: 未找到 ID={self.last_category_id} 的 Category 对象。" ) print_info("\n条件查询 Product (name like '%Python%'):" ) stmt_find = select(ProductORM).where(ProductORM.name.like("%Python%" )).limit(1 ) product = self .session.execute(stmt_find).scalars().first() if product: print_success(f"查询结果: {product} " ) else : print_warning("查询结果: 未找到符合条件的 Product 对象。" ) except SQLAlchemyError as e: print_error(f"查询失败: {e} " ) def update_object (self ): print_subheader("3. 更新 ORM 对象 (UPDATE)" ) try : with self .session.begin_nested(): print_info(f"\n修改 Product ID={self.last_product_id} 的价格" ) product_to_update = self .session.get(ProductORM,self .last_product_id) if product_to_update: product_to_update.price = 88.88 else :print_warning(f"查询结果: 未找到 ID={self.last_product_id} 的 Product 对象。" ) print_info("\n批量降低 '书籍类' 的库存" ) stmt_bulk_update = update(ProductORM).where( ProductORM.category.has(CategoryORM.name == "书籍类" ) ).values( stock = ProductORM.stock - 5 ).execution_options(synchronize_session="fetch" ) result = self .session.execute(stmt_bulk_update) print_success(f"受影响的行数: {result.rowcount} " ) print_success("更新操作已暂存 (待外层 Commit)。" ) except SQLAlchemyError as e:print_error(f"更新失败: {e} " ) def delete_objects (self ): print_subheader("4. 删除 ORM 对象 (DELETE)" ) try : with self .session.begin_nested(): print_info(f"\n删除 Product ID={self.last_product_id} :" ) product_to_del = self .session.get(ProductORM, self .last_product_id) if product_to_del: self .session.delete(product_to_del) print_success(f" ID={self.last_product_id} 已标记删除。" ) else : print_warning(f" 未找到 ID={self.last_product_id} 。" ) print_info("\n批量删除类别为 NULL 的产品:" ) stmt_bulk_del = delete(ProductORM).where(ProductORM.category_id == None ) result = self .session.execute(stmt_bulk_del) print_success(f" 批量删除影响了 {result.rowcount} 行。" ) print_success(" 删除操作已暂存 (待外层 Commit)。" ) except Exception as e: print_error(f"删除对象时出错: {e} " ) raise
16.3.3.5 运行示例测试 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 """ SQLAlchemy ORM 主运行脚本 此脚本作为项目的入口点,创建数据库表并运行各个示例模块的功能演示。 主要功能: - 初始化数据库: 创建所有所需的表 - 运行示例: 按顺序调用各个模块的示例功能 - 演示工作流: 展示完整的 ORM 使用流程和最佳实践 """ from sqlalchemy import func, inspectfrom utils.print_utils import print_header, print_info, print_success, print_error, print_warningfrom core_config import SessionLocal, orm_enginefrom models.base import Basefrom models.orm_models import CategoryORM, ProductORMfrom crud.orm_crud_ops import OrmCrudOpsdef create_tables (): """创建数据库表""" print_header("创建数据库表" ) try : Base.metadata.create_all(bind=orm_engine) print_success("所有 ORM 表已创建 (如果尚不存在)。" ) inspector = inspect(orm_engine) tables = inspector.get_table_names() print_info(f"数据库中的表: {tables} " ) except Exception as e: print_error(f"创建表时出错: {e} " ) raise def run_crud_examples (): """运行 CRUD 操作示例""" print_header("运行 ORM CRUD 操作示例" ) with SessionLocal() as session: crud_executor = OrmCrudOps(session) try : category_id, product_id = crud_executor.create_objects() if not category_id or not product_id: print_warning("创建对象失败,跳过后续操作。" ) return crud_executor.query_object() crud_executor.update_object() session.commit() print_success("\n所有操作已成功提交!" ) except Exception as e: print_error(f"\n执行 CRUD 操作时出错: {e} " ) session.rollback() print_info("所有操作已回滚。" ) if __name__ == "__main__" : run_crud_examples() print_header("程序执行完毕" )
16.4 mybatis-py: 轻量级 Python SQL 映射器 mybatis-py
是一个为 Python 开发者设计的轻量级 SQL 映射框架,其核心理念与 Java 领域知名的 MyBatis 框架相近。它旨在提供一种方式,让开发者能够更直接地控制 SQL 语句的编写和执行,同时又通过映射机制简化 Python 代码与数据库之间的交互。对于既想利用 SQL 的全部能力进行性能调优或处理复杂逻辑,又希望避免直接操作原生数据库驱动(如 pymysql
)时繁琐的模板代码的场景,mybatis-py
是一个值得考虑的工具。
16.4.1 核心知识 核心定位 : 半自动化 ORM / SQL 映射器 (强调 SQL 控制权)。
主要功能特性概览:
特性编号 功能描述 详细说明与开发者价值 1 半自动化的 ORM 提供 Python 方法到 SQL 语句的映射,以及结果集到 Python 字典或简单对象的转换,简化数据访问代码,但开发者仍需编写 SQL。 2 支持动态 SQL 核心特性之一。允许在 XML Mapper 文件中使用 <if>
, <foreach>
, <where>
, <set>
等标签,根据传入参数动态构建和调整 SQL 语句,实现复杂查询逻辑。 3 装饰器 API 提供了类似 MyBatis 注解的 Python 装饰器 (@mb.SelectOne
, @mb.Insert
等),可直接在 Python 方法上绑定 SQL 语句,适用于简单、固定的 SQL 操作。 4 LRU 缓存及过期机制 内置基于 LRU (Least Recently Used) 算法的查询缓存,可配置缓存池大小和条目过期时间,对不经常变动但查询频繁的数据能有效提升性能。 5 Prepared Statement 支持 当使用 #{placeholder}
语法时,优先使用预编译语句,将 SQL 结构与数据分离,这是防御 SQL 注入攻击的关键手段。 6 预防大对象机制 (OOM) 内置机制(通过 max_result_bytes
参数)限制从数据库拉取并处理的数据总量,旨在避免因查询结果集过大导致的应用程序内存溢出问题。 7 多数据库支持 目前明确支持 MySQL 和 PostgreSQL,可以通过 ConnectionFactory
指定不同的 dbms_name
。
**安装 mybatis-py
**
数据库及表示例准备 (MySQL): (确保 MySQL 服务已启动,并已创建相应数据库、用户及授权。以下 SQL 用于创建本节演示所需的库和表。)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 CREATE DATABASE IF NOT EXISTS mybatis_py_complete_demo CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;CREATE USER IF NOT EXISTS 'mb_user_complete' @'localhost' IDENTIFIED BY 'your_strong_password' ;GRANT ALL PRIVILEGES ON mybatis_py_complete_demo.* TO 'mb_user_complete' @'localhost' ;FLUSH PRIVILEGES; USE mybatis_py_complete_demo; CREATE TABLE IF NOT EXISTS fruit_categories ( id INT AUTO_INCREMENT PRIMARY KEY COMMENT '类别ID (主键)' , name VARCHAR (100 ) NOT NULL UNIQUE COMMENT '类别名称 (唯一)' , description TEXT COMMENT '类别描述' ) ENGINE= InnoDB DEFAULT CHARSET= utf8mb4 COMMENT= '水果类别表' ; CREATE TABLE IF NOT EXISTS fruits ( id INT AUTO_INCREMENT PRIMARY KEY COMMENT '水果ID (主键)' , name VARCHAR (100 ) NOT NULL COMMENT '水果名称' , category_id INT COMMENT '类别ID (外键), 允许为NULL表示未分类' , price INT COMMENT '价格 (单位:分,使用整数存储以避免浮点精度问题)' , description TEXT COMMENT '水果描述 (可选)' , FOREIGN KEY (category_id) REFERENCES fruit_categories(id) ON DELETE SET NULL ON UPDATE CASCADE ) ENGINE= InnoDB DEFAULT CHARSET= utf8mb4 COMMENT= '水果信息表' ; SET FOREIGN_KEY_CHECKS= 0 ; TRUNCATE TABLE fruits;TRUNCATE TABLE fruit_categories;SET FOREIGN_KEY_CHECKS= 1 ; INSERT INTO fruit_categories (name, description) VALUES ('温带水果' , '如苹果、梨等' ), ('热带水果' , '如香蕉、芒果、菠萝等' ), ('浆果类' , '如草莓、蓝莓、树莓等' ); INSERT INTO fruits (name, category_id, price, description) VALUES ('红富士苹果' , (SELECT id FROM fruit_categories WHERE name= '温带水果' ), 750 , '脆甜多汁的红富士苹果' ), ('香芽蕉' , (SELECT id FROM fruit_categories WHERE name= '热带水果' ), 420 , '口感软糯的香芽蕉' ), ('奶油草莓' , (SELECT id FROM fruit_categories WHERE name= '浆果类' ), 1200 , '大颗香甜的奶油草莓' );
建议项目结构 (用于 mybatis-py
示例):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 db_framework_practice/ ├── examples/ │ ├── mybatis_py_v_final/ # 为当前完整重写版本创建新目录 │ │ ├── __init__.py │ │ ├── config.py # 数据库连接参数 │ │ ├── ex01_decorators_main.py # 装饰器使用示例主程序 │ │ ├── fruit_repo_deco.py # 使用装饰器的水果仓库类 │ │ ├── mappers/ # 存放 XML Mapper 文件 │ │ │ ├── __init__.py # 使 mappers 成为一个包 │ │ │ └── fruits_mapper.xml # 使用动态sql的核心模块 │ │ └── ex02_xml_mapper_main.py # XML Mapper 使用示例主程序 │ │ └── ex03_flask_app_main.py # Flask 集成示例 (在后续提供) ├── utils/ │ ├── __init__.py │ └── print_utils.py # 打印工具模块 └── ...
数据库配置文件: examples/mybatis_py_v_final/config.py
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 """ 数据库连接配置模块 此模块包含数据库连接的配置参数,用于建立与数据库的连接。 """ DB_CONFIG = { 'dbms_name' : 'mysql' , 'host' : 'localhost' , 'port' : 3306 , 'user' : 'root' , 'password' : 'root' , 'database' : 'mybatis_demo' , 'charset' : 'utf8mb4' } MYBATIS_CONFIG = { "cache_memory_limit" : 10 * 1024 * 1024 , "cache_max_live_ms" : 5 * 60 * 1000 , "max_result_bytes" : 50 * 1024 * 1024 }
Mybatis
类核心 API (基于提供的源码分析)
以下是对 mybatis-py
源码中 Mybatis
类的核心构造函数、主要方法及其参数的总结。
Mybatis.__init__(...)
构造函数参数
参数 类型 描述 源码默认/说明 conn
AbstractConnection
必需 . 已建立的数据库连接对象 (通常通过 ConnectionFactory.get_connection()
获取)。- mapper_path
str
必需 . XML Mapper 文件所在的目录路径或 Python 包路径 (例如 "mappers"
或 "your_package.mappers"
)。库会在此路径下查找并加载所有 .xml
后缀的 Mapper 文件。 cache_memory_limit
Optional[int]
可选。缓存的内存限制 (字节)。如果为 None
,源码中 Cache
对象以 0
初始化,可能表示不启用或使用不同逻辑。 默认: None
(源码中 Cache(0,…)) cache_max_live_ms
int
可选。缓存条目的最大存活时间 (毫秒)。 默认: 5 * 1000
(即 5 秒) max_result_bytes
int
可选。select_many
方法返回的结果列表允许占用的最大总字节数,用于防止 OOM。 默认: 100 * 1024 * 1024
(即 100MB)
XML Mapper 调用方法 (Mybatis 实例方法)
方法 参数 返回类型 用途说明 select_one(id: str, params: dict)
id
(XML中语句的 namespace.id
或全局唯一 id
) <br> params
(传递给SQL的参数字典)Optional[Dict]
执行 <select>
语句,预期返回单条记录 (字典) 或 None
。会使用缓存。 select_many(id: str, params: dict)
id
(XML中语句的 namespace.id
或全局唯一 id
) <br> params
(传递给SQL的参数字典)Optional[List[Dict]]
执行 <select>
语句,返回多条记录 (字典列表) 或 None
(若无结果)。会使用缓存及 max_result_bytes
限制。 update(id: str, params: dict)
id
(XML中语句的 namespace.id
或全局唯一 id
) <br> params
(传递给SQL的参数字典)int
执行 <update>
语句,返回受影响的行数 (cursor.rowcount()
)。会清空整个缓存 。 delete(id: str, params: dict)
id
(XML中语句的 namespace.id
或全局唯一 id
) <br> params
(传递给SQL的参数字典)int
执行 <delete>
语句,返回受影响的行数 (cursor.rowcount()
)。会清空整个缓存 。 insert(id: str, params: dict, primary_key: str = None)
id
(XML中语句的 namespace.id
或全局唯一 id
) <br> params
(参数字典) <br> primary_key
(可选,用于 PostgreSQL 的 RETURNING
子句,指定主键列名)int
执行 <insert>
语句。返回值是 cursor.lastrowid()
(通常用于获取 MySQL 自增ID)。会清空整个缓存 。params
字典可能被 XML 中的 keyProperty
修改。
装饰器方法 (@mb.DecoratorName
)
装饰器方法 (@mb.<Name>
) 装饰器参数 装饰的函数签名 (示例) 用途及内部行为说明 SelectOne(unparsed_sql)
unparsed_sql
(原始SQL字符串)def func(**kwargs) -> Optional[Dict]
将函数映射到给定的 SELECT
SQL,预期返回单条记录。kwargs
作为参数传递给 SQL (#{key}
占位符)。内部处理 SQL 解析、参数绑定、缓存、结果转换。 SelectMany(unparsed_sql)
unparsed_sql
(原始SQL字符串)def func(**kwargs) -> Optional[List[Dict]]
将函数映射到给定的 SELECT
SQL,返回多条记录。处理同上,并应用 max_result_bytes
。 Insert(unparsed_sql, primary_key=None)
unparsed_sql
(SQL), primary_key
(主键列名)def func(**kwargs) -> int
(通常返回 lastrowid
)将函数映射到 INSERT
SQL。primary_key
主要用于 PostgreSQL 的 RETURNING
。MySQL 中通常返回 lastrowid
。清空缓存 。 Update(unparsed_sql)
unparsed_sql
(原始SQL字符串)def func(**kwargs) -> int
(返回受影响行数)将函数映射到 UPDATE
SQL,返回受影响行数。清空缓存 。 Delete(unparsed_sql)
unparsed_sql
(原始SQL字符串)def func(**kwargs) -> int
(返回受影响行数)将函数映射到 DELETE
SQL,返回受影响行数。清空缓存 。
Workspace_rows(cursor, batch_size=1000)
辅助函数 (内部) :
在 Mybatis.select_many
和装饰器 SelectMany
内部使用,用于分批从数据库游标获取数据。 将每批获取的行数据 (通常是元组) 转换为字典列表 (键为列名)。 使用 yield
以生成器方式逐条返回字典,这使得调用方 (如 select_many
) 可以在迭代过程中检查 max_result_bytes
限制。 缓存键 (CacheKey(sql, param_list)
) :
用于在缓存中唯一标识一个查询。 它基于经过内部处理和参数绑定后的 SQL 语句 (sql
) 和实际绑定的参数列表 (param_list
) 生成。 确保 SQL 相同但参数不同的查询会有不同的缓存条目。 方法一: 使用装饰器 API 适用于 SQL 语句相对固定且逻辑简单的场景。
文件 : db_framework_practice/examples/mybatis_py_v2/ex01_decorators_usage.py
(封装操作)1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 """ 装饰器使用示例模块 # db_framework_practice/examples/mybatis_py_v2/ex01_decorators_usage.py 此模块演示了如何使用MyBatis-Py V2的装饰器功能进行数据库操作。 """ from mybatis import Mybatis, ConnectionFactoryfrom config import DB_CONFIGfrom typing import List , Dict , Optional , Any from utils.print_utils import *class FruitRepository : """水果数据仓库类,封装对水果表的所有数据库操作""" def __init__ (self, mappers_dir="mappers" , cache_size=50 *1024 *1024 ): """ 初始化水果仓库 Args: mappers_dir: XML映射文件目录 cache_size: 缓存大小,默认50MB """ self .connection_factory = ConnectionFactory.get_connection(**DB_CONFIG) self .mybatis = Mybatis(self .connection_factory, mappers_dir, cache_memory_limit=cache_size) self ._init_operations() def _init_operations (self ): """初始化数据库操作方法""" @self.mybatis.SelectOne("SELECT * FROM fruits WHERE id=#{id}" ) def get_one (id : int ) -> Optional [Dict [str , Any ]]: """获取一个水果记录,返回一个字典""" pass @self.mybatis.SelectMany("SELECT * FROM fruits WHERE category_id = #{category_id}" ) def get_many (category_id: int ) -> List [Dict [str , Any ]]: """获取多个水果记录,返回一个列表""" pass @self.mybatis.Insert( "INSERT INTO fruits (name, category_id, price, description) VALUES (#{name}, #{category_id}, #{price}, #{description})" , primary_key="id" ) def insert (name: str , category_id: int , price: float , description: str ) -> int : """插入水果记录,返回插入的ID""" pass @self.mybatis.Delete("DELETE FROM fruits WHERE id = #{id}" ) def delete (id : int ) -> int : """删除水果记录,返回删除的行数""" pass @self.mybatis.Update("UPDATE fruits SET name=#{name}, category_id=#{category_id}, price=#{price}, description=#{description} WHERE id=#{id}" ) def update (id : int , name: str , category_id: int , price: float , description: str ) -> int : """ 更新水果记录,返回更新的行数 """ pass self .get_one = get_one self .get_many = get_many self .insert = insert self .delete = delete self .update = update def close (self ): """关闭数据库连接""" if hasattr (self , 'connection_factory' ) and self .connection_factory: self .connection_factory.close() def main (): """主函数,用于测试FruitRepository类的功能""" print_header("Mybatis-Py: 装饰器 API 使用示例 (面向对象版)" ) repo = FruitRepository() try : print_info("1. 测试查询单个水果" ) fruit = repo.get_one(id =1 ) if fruit: print_success(f"查询到水果: {fruit['name' ]} , 价格: {fruit['price' ]} " ) else : print_warning("未找到ID为1的水果" ) print_info("\n2. 测试查询某分类下的所有水果" ) category_id = 1 fruits = repo.get_many(category_id=category_id) print_success(f"分类 {category_id} 下有 {len (fruits)} 个水果:" ) for fruit in fruits: print_info(f" - {fruit['name' ]} : ¥{fruit['price' ]} " ) print_info("\n3. 测试插入新水果" ) new_fruit = { "name" : "蓝莓" , "category_id" : 2 , "price" : 25.5 , "description" : "新鲜蓝莓,富含抗氧化物质" } new_id = repo.insert(**new_fruit) print_success(f"插入成功,新水果ID: {new_id} " ) print_info("\n4. 测试更新水果" ) update_fruit = { "id" : new_id, "name" : "有机蓝莓" , "category_id" : 2 , "price" : 28.5 , "description" : "有机认证蓝莓,无农药" } affected_rows = repo.update(**update_fruit) print_success(f"更新了 {affected_rows} 行记录" ) updated_fruit = repo.get_one(id =new_id) if updated_fruit: print_success(f"更新后的水果: {updated_fruit['name' ]} , 价格: {updated_fruit['price' ]} " ) print_info("\n5. 测试删除水果" ) deleted_rows = repo.delete(id =new_id) print_success(f"删除了 {deleted_rows} 行记录" ) except Exception as e: print_error(f"操作失败: {str (e)} " ) finally : repo.close() print_info("\n测试完成,已关闭数据库连接" ) if __name__ == "__main__" : main()
方法二: 使用 XML Mapper 文件 对于包含动态逻辑或较为复杂的 SQL 语句,XML Mapper 文件提供了更强大的表达能力。
1. XML Mapper 文件 (ddb_framework_practice/examples/mybatis_py_v2/mappers/fruits_mapper.xml
)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 <?xml version="1.0" encoding="UTF-8" ?> <!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd" > <mapper namespace ="fruits" > <insert id ="insertFruit" useGeneratedKeys ="true" keyProperty ="generated_id_xml" keyColumn ="id" > INSERT INTO fruits (name, category_id, price, description) VALUES ( #{name}, #{category_id}, #{price}, #{description} ) </insert > <delete id ="deleteFruitById" > DELETE FROM fruits WHERE id = #{id} </delete > <update id ="updateFruit" > UPDATE fruits SET name = #{name}, category_id = #{category_id}, price = #{price}, description = #{description} WHERE id = #{id} </update > <select id ="findFruitDetailsById" resultType ="dict" > SELECT f.id, f.name, f.category_id, fc.name as category_name, f.price, f.description FROM fruits as f LEFT JOIN FRUIT_CATEGORIES FC ON f.category_id = fc.id WHERE f.id = #{id} </select > <select id ="findFruitsByCriteria" resultType ="dict" > SELECT f.id, f.name, f.category_id, fc.name as category_name, f.price, f.description FROM fruits as f LEFT JOIN fruit_categories FC ON f.category_id = fc.id <where > <if test ="'name' in params" > f.name LIKE CONCAT('%', #{name}, '%') </if > </where > </select > </mapper >
2. Python 代码 (db_framework_practice/examples/mybatis_py_v2/ex02_xml_mapper_usage.py
)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 from mybatis import Mybatis, ConnectionFactoryfrom typing import List , Dict , Optional , Any from utils.print_utils import *from config import DB_CONFIGclass FruitRepoXmlOperational : """使用 XML Mapper 操作水果数据的数据仓库类""" def __init__ (self, mappers_dir="mappers" , cache_size=0 ): """ 初始化水果仓库 Args: mappers_dir: XML映射文件目录 cache_size: 缓存大小,默认50MB """ self .connection_factory = ConnectionFactory.get_connection(**DB_CONFIG) self .mybatis = Mybatis(self .connection_factory, mappers_dir, cache_memory_limit=cache_size) def insert_fruit (self, fruit: Dict [str , Any ] ) -> int : """插入水果数据(使用xml mapper)""" return self .mybatis.insert("fruits.insertFruit" , fruit) def delete_fruit (self, fruit_id: int ) -> int : """删除水果数据(使用xml mapper)""" delete_params = {"id" : fruit_id} return self .mybatis.delete("fruits.deleteFruitById" , delete_params) def update_fruit (self, fruit: Dict [str , Any ] ) -> int : """更新水果数据(使用xml mapper)""" result = self .mybatis.update("fruits.updateFruit" , fruit) if result > 0 : print_success(f"更新水果数据成功,影响行数: {result} " ) else : print_error(f"更新水果数据失败,影响行数: {result} " ) return result def find_fruit_details_by_id (self, fruit_id: int ) -> Dict [str , Any ]: """根据ID查询水果详细信息(使用xml mapper)""" return self .mybatis.select_one("fruits.findFruitDetailsById" , {"id" : fruit_id}) def find_fruits_by_criteria (self, params: Dict [str , Any ] ) -> List [Dict [str , Any ]]: """根据条件查询水果列表(使用xml mapper) Args: params: 查询参数字典,可包含以下键: - category_id: 分类ID - min_price: 最低价格 - max_price: 最高价格 - name: 水果名称(支持模糊查询) - sort_by: 排序字段,如'price' - sort_order: 排序方向,如'DESC' - limit: 返回记录数量限制 - offset: 分页偏移量 Returns: 符合条件的水果列表 """ return self .mybatis.select_many("fruits.findFruitsByCriteria" , params) if __name__ == '__main__' : repo = FruitRepoXmlOperational() print_header("测试查询水果列表" ) params = { "name" : "苹" , } fruits = repo.find_fruits_by_criteria(params) for fruit in fruits: print_info(f"水果ID: {fruit['id' ]} , 名称: {fruit['name' ]} , 价格: {fruit['price' ]} " )
16.4.2 实战 (Flask Integration) 将 mybatis-py
集成到 Flask Web 应用中,提供 API 接口。