From 73bd45f3dab1975194b4b25252971f0558a8a8af Mon Sep 17 00:00:00 2001 From: David Levy Date: Sat, 7 Feb 2026 13:39:40 -0600 Subject: [PATCH] FIX: Extract SqlTypeCode into own module, eliminate circular import workaround --- CHANGELOG.md | 11 +- mssql_python/__init__.py | 1 + mssql_python/connection.py | 53 +++-- mssql_python/constants.py | 5 + mssql_python/cursor.py | 96 +++----- mssql_python/mssql_python.pyi | 44 +++- mssql_python/pybind/ddbc_bindings.cpp | 9 + mssql_python/type_code.py | 115 ++++++++++ tests/test_002_types.py | 164 +++++++++++++ tests/test_004_cursor.py | 318 ++++++++++++++++++++++++-- tests/test_018_polars_integration.py | 240 +++++++++++++++++++ 11 files changed, 956 insertions(+), 100 deletions(-) create mode 100644 mssql_python/type_code.py create mode 100644 tests/test_018_polars_integration.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 517a60bfc..7eda42a95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,22 +7,31 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased] ### Added + - New feature: Support for macOS and Linux. - Documentation: Added API documentation in the Wiki. +- Added `SqlTypeCode` class for dual-compatible type codes in `cursor.description`. ### Changed + - Improved error handling in the connection module. +- Enhanced `cursor.description[i][1]` to return `SqlTypeCode` objects that compare equal to both SQL type integers and Python types, improving backwards compatibility while aligning with DB-API 2.0. Note that `SqlTypeCode` instances are intentionally unhashable; code that previously used `cursor.description[i][1]` as a dict or set key should use `int(type_code)` or `type_code.type_code` instead. ### Fixed + - Bug fix: Resolved issue with connection timeout. +- Fixed `cursor.description` type handling for better DB-API 2.0 compliance (Issue #352). ## [1.0.0-alpha] - 2025-02-24 ### Added + - Initial release of the mssql-python driver for SQL Server. ### Changed + - N/A ### Fixed -- N/A \ No newline at end of file + +- N/A diff --git a/mssql_python/__init__.py b/mssql_python/__init__.py index 2bcac47bb..ff02a4fbb 100644 --- a/mssql_python/__init__.py +++ b/mssql_python/__init__.py @@ -60,6 +60,7 @@ # Cursor Objects from .cursor import Cursor +from .type_code import SqlTypeCode # Logging Configuration (Simplified single-level DEBUG system) from .logging import logger, setup_logging, driver_logger diff --git a/mssql_python/connection.py b/mssql_python/connection.py index ba79e2a3f..ad6c9457f 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -45,6 +45,8 @@ from mssql_python.connection_string_builder import _ConnectionStringBuilder from mssql_python.constants import _RESERVED_PARAMETERS +from mssql_python.type_code import SqlTypeCode + if TYPE_CHECKING: from mssql_python.row import Row @@ -923,7 +925,9 @@ def cursor(self) -> Cursor: logger.debug("cursor: Cursor created successfully - total_cursors=%d", len(self._cursors)) return cursor - def add_output_converter(self, sqltype: int, func: Callable[[Any], Any]) -> None: + def add_output_converter( + self, sqltype: Union[int, SqlTypeCode, type], func: Callable[[Any], Any] + ) -> None: """ Register an output converter function that will be called whenever a value with the given SQL type is read from the database. @@ -936,32 +940,39 @@ def add_output_converter(self, sqltype: int, func: Callable[[Any], Any]) -> None vulnerabilities. This API should never be exposed to untrusted or external input. 
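+        Example (a minimal sketch): registering a converter for SQL_WVARCHAR
+        (-9), whose values reach converters as UTF-16LE encoded bytes:
+
+            conn.add_output_converter(
+                -9, lambda v: None if v is None else v.decode("utf-16-le")
+            )
+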
        Args:
-            sqltype (int): The integer SQL type value to convert, which can be one of the
-                defined standard constants (e.g. SQL_VARCHAR) or a database-specific
-                value (e.g. -151 for the SQL Server 2008 geometry data type).
+            sqltype (int, SqlTypeCode, or type): The SQL type value to convert, which can
+                be one of the defined standard constants (e.g. SQL_VARCHAR) or a
+                database-specific value (e.g. -151 for the SQL Server geometry data type).
+                SqlTypeCode objects and Python types are also accepted for backward
+                compatibility.
             func (callable): The converter function which will be called with a single
                 parameter, the value, and should return the converted value. If the value is NULL
-                then the parameter passed to the function will be None, otherwise it
-                will be a bytes object.
+                then the parameter passed to the function will be None. For string/binary
+                columns, the value will be bytes (UTF-16LE encoded for strings). For other
+                types (int, decimal.Decimal, datetime, etc.), the value will be the native
+                Python object.

         Returns:
             None
         """
+        if isinstance(sqltype, SqlTypeCode):
+            sqltype = sqltype.type_code
         with self._converters_lock:
             self._output_converters[sqltype] = func
             # Pass to the underlying connection if native implementation supports it
-            if hasattr(self._conn, "add_output_converter"):
+            # Only forward int type codes to the native layer; Python type keys are
+            # handled only in our Python-side dictionary
+            if isinstance(sqltype, int) and hasattr(self._conn, "add_output_converter"):
                 self._conn.add_output_converter(sqltype, func)
         logger.info(f"Added output converter for SQL type {sqltype}")

-    def get_output_converter(self, sqltype: Union[int, type]) -> Optional[Callable[[Any], Any]]:
+    def get_output_converter(
+        self, sqltype: Union[int, SqlTypeCode, type]
+    ) -> Optional[Callable[[Any], Any]]:
         """
         Get the output converter function for the specified SQL type.

         Thread-safe implementation that protects the converters dictionary with a lock.

         Args:
-            sqltype (int or type): The SQL type value or Python type to get the converter for
+            sqltype (int, SqlTypeCode, or type): The SQL type value or Python type to get
+                the converter for.

         Returns:
             callable or None: The converter function or None if no converter is registered
@@ -970,27 +981,43 @@ def get_output_converter(self, sqltype: Union[int, type]) -> Optional[Callable[[

         ⚠️ The returned converter function will be executed on database values.
         Only use converters from trusted sources.
         """
+        original_sqltype = sqltype
+        if isinstance(sqltype, SqlTypeCode):
+            sqltype = sqltype.type_code
         with self._converters_lock:
-            return self._output_converters.get(sqltype)
+            result = self._output_converters.get(sqltype)
+            # Fallback: try python_type key for backward compatibility
+            if result is None and isinstance(original_sqltype, SqlTypeCode):
+                result = self._output_converters.get(original_sqltype.python_type)
+            return result

-    def remove_output_converter(self, sqltype: Union[int, type]) -> None:
+    def remove_output_converter(self, sqltype: Union[int, SqlTypeCode, type]) -> None:
         """
         Remove the output converter function for the specified SQL type.

         Thread-safe implementation that protects the converters dictionary with a lock.

         Args:
-            sqltype (int or type): The SQL type value to remove the converter for
+            sqltype (int, SqlTypeCode, or type): The SQL type value to remove the converter for.
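+
+        Note:
+            Passing a SqlTypeCode removes the converter registered under its
+            integer code and, symmetrically with get_output_converter, any
+            converter registered under its mapped Python type.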
Returns: None """ + python_type_key = None + if isinstance(sqltype, SqlTypeCode): + python_type_key = sqltype.python_type + sqltype = sqltype.type_code with self._converters_lock: if sqltype in self._output_converters: del self._output_converters[sqltype] # Pass to the underlying connection if native implementation supports it - if hasattr(self._conn, "remove_output_converter"): + # Only forward int type codes to native layer; Python type keys are handled + # only in our Python-side dictionary + if isinstance(sqltype, int) and hasattr(self._conn, "remove_output_converter"): self._conn.remove_output_converter(sqltype) + # Symmetric with get_output_converter: also remove python_type key if present + if python_type_key is not None and python_type_key in self._output_converters: + del self._output_converters[python_type_key] logger.info(f"Removed output converter for SQL type {sqltype}") def clear_output_converters(self) -> None: diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 03d40c833..c24822760 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -114,7 +114,12 @@ class ConstantsDDBC(Enum): SQL_FETCH_ABSOLUTE = 5 SQL_FETCH_RELATIVE = 6 SQL_FETCH_BOOKMARK = 8 + # NOTE: The following SQL Server-specific type constants MUST stay in sync with + # the corresponding values in mssql_python/pybind/ddbc_bindings.cpp SQL_DATETIMEOFFSET = -155 + SQL_SS_TIME2 = -154 # SQL Server TIME(n) type + SQL_SS_UDT = -151 # SQL Server User-Defined Types (geometry, geography, hierarchyid) + SQL_SS_XML = -152 # SQL Server XML type SQL_C_SS_TIMESTAMPOFFSET = 0x4001 SQL_SCOPE_CURROW = 0 SQL_BEST_ROWID = 1 diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 3dd7aa283..d37eb1b45 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -20,6 +20,7 @@ from mssql_python.helpers import check_error from mssql_python.logging import logger from mssql_python import ddbc_bindings +from mssql_python.type_code import SqlTypeCode from mssql_python.exceptions import ( InterfaceError, NotSupportedError, @@ -142,6 +143,9 @@ def __init__(self, connection: "Connection", timeout: int = 0) -> None: ) self.messages = [] # Store diagnostic messages + # Store raw column metadata for converter lookups + self._column_metadata = None + def _is_unicode_string(self, param: str) -> bool: """ Check if a string contains non-ASCII characters. 
@@ -724,6 +728,14 @@ def _reset_cursor(self) -> None: logger.debug("SQLFreeHandle succeeded") self._clear_rownumber() + self._column_metadata = None + self.description = None + + # Clear any result-set-specific caches to avoid stale mappings + if hasattr(self, "_cached_column_map"): + self._cached_column_map = None + if hasattr(self, "_cached_converter_map"): + self._cached_converter_map = None # Reinitialize the statement handle self._initialize_cursor() @@ -756,6 +768,7 @@ def close(self) -> None: self.hstmt = None logger.debug("SQLFreeHandle succeeded") self._clear_rownumber() + self._column_metadata = None # Clear metadata to prevent memory leaks self.closed = True def _check_closed(self) -> None: @@ -942,8 +955,12 @@ def _initialize_description(self, column_metadata: Optional[Any] = None) -> None """Initialize the description attribute from column metadata.""" if not column_metadata: self.description = None + self._column_metadata = None # Clear metadata too return + # Store raw metadata for converter map building + self._column_metadata = column_metadata + description = [] for _, col in enumerate(column_metadata): # Get column name - lowercase it if the lowercase flag is set @@ -954,10 +971,13 @@ def _initialize_description(self, column_metadata: Optional[Any] = None) -> None column_name = column_name.lower() # Add to description tuple (7 elements as per PEP-249) + # Use SqlTypeCode for backwards-compatible type_code that works with both + # `desc[1] == str` (pandas) and `desc[1] == -9` (DB-API 2.0) + sql_type = col["DataType"] description.append( ( column_name, # name - self._map_data_type(col["DataType"]), # type_code + SqlTypeCode(sql_type), # type_code - dual compatible None, # display_size col["ColumnSize"], # internal_size col["ColumnSize"], # precision - should match ColumnSize @@ -975,6 +995,7 @@ def _build_converter_map(self): """ if ( not self.description + or not self._column_metadata or not hasattr(self.connection, "_output_converters") or not self.connection._output_converters ): @@ -982,17 +1003,20 @@ def _build_converter_map(self): converter_map = [] - for desc in self.description: - if desc is None: - converter_map.append(None) - continue - sql_type = desc[1] + for col_meta in self._column_metadata: + # Use the raw SQL type code from metadata, not the mapped Python type + sql_type = col_meta["DataType"] + python_type = SqlTypeCode._get_python_type(sql_type) converter = self.connection.get_output_converter(sql_type) - # If no converter found for the SQL type, try the WVARCHAR converter as a fallback + + # Fallback: If no converter found for SQL type code, try the mapped Python type. + # This provides backward compatibility for code that registered converters by Python type. if converter is None: - from mssql_python.constants import ConstantsDDBC + converter = self.connection.get_output_converter(python_type) - converter = self.connection.get_output_converter(ConstantsDDBC.SQL_WVARCHAR.value) + # Fallback: try SQL_WVARCHAR converter for str/bytes columns + if converter is None and python_type in (str, bytes): + converter = self.connection.get_output_converter(ddbc_sql_const.SQL_WVARCHAR.value) converter_map.append(converter) @@ -1022,41 +1046,6 @@ def _get_column_and_converter_maps(self): return column_map, converter_map - def _map_data_type(self, sql_type): - """ - Map SQL data type to Python data type. - - Args: - sql_type: SQL data type. - - Returns: - Corresponding Python data type. 
- """ - sql_to_python_type = { - ddbc_sql_const.SQL_INTEGER.value: int, - ddbc_sql_const.SQL_VARCHAR.value: str, - ddbc_sql_const.SQL_WVARCHAR.value: str, - ddbc_sql_const.SQL_CHAR.value: str, - ddbc_sql_const.SQL_WCHAR.value: str, - ddbc_sql_const.SQL_FLOAT.value: float, - ddbc_sql_const.SQL_DOUBLE.value: float, - ddbc_sql_const.SQL_DECIMAL.value: decimal.Decimal, - ddbc_sql_const.SQL_NUMERIC.value: decimal.Decimal, - ddbc_sql_const.SQL_DATE.value: datetime.date, - ddbc_sql_const.SQL_TIMESTAMP.value: datetime.datetime, - ddbc_sql_const.SQL_TIME.value: datetime.time, - ddbc_sql_const.SQL_BIT.value: bool, - ddbc_sql_const.SQL_TINYINT.value: int, - ddbc_sql_const.SQL_SMALLINT.value: int, - ddbc_sql_const.SQL_BIGINT.value: int, - ddbc_sql_const.SQL_BINARY.value: bytes, - ddbc_sql_const.SQL_VARBINARY.value: bytes, - ddbc_sql_const.SQL_LONGVARBINARY.value: bytes, - ddbc_sql_const.SQL_GUID.value: uuid.UUID, - # Add more mappings as needed - } - return sql_to_python_type.get(sql_type, str) - @property def rownumber(self) -> int: """ @@ -1369,6 +1358,7 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state except Exception as e: # pylint: disable=broad-exception-caught # If describe fails, it's likely there are no results (e.g., for INSERT) self.description = None + self._column_metadata = None # Reset rownumber for new result set (only for SELECT statements) if self.description: # If we have column descriptions, it's likely a SELECT @@ -1385,15 +1375,6 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state self._cached_column_map = None self._cached_converter_map = None - # After successful execution, initialize description if there are results - column_metadata = [] - try: - ddbc_bindings.DDBCSQLDescribeCol(self.hstmt, column_metadata) - self._initialize_description(column_metadata) - except Exception as e: - # If describe fails, it's likely there are no results (e.g., for INSERT) - self.description = None - self._reset_inputsizes() # Reset input sizes after execution # Return self for method chaining return self @@ -2425,6 +2406,7 @@ def nextset(self) -> Union[bool, None]: logger.debug("nextset: No more result sets available") self._clear_rownumber() self.description = None + self._column_metadata = None return False self._reset_rownumber() @@ -2444,6 +2426,7 @@ def nextset(self) -> Union[bool, None]: except Exception as e: # pylint: disable=broad-exception-caught # If describe fails, there might be no results in this result set self.description = None + self._column_metadata = None logger.debug( "nextset: Moved to next result set - column_count=%d", @@ -2788,12 +2771,7 @@ def rollback(self): self._connection.rollback() def __del__(self): - """ - Destructor to ensure the cursor is closed when it is no longer needed. - This is a safety net to ensure resources are cleaned up - even if close() was not called explicitly. - If the cursor is already closed, it will not raise an exception during cleanup. - """ + """Safety net to close cursor if close() was not called explicitly.""" if "closed" not in self.__dict__ or not self.closed: try: self.close() diff --git a/mssql_python/mssql_python.pyi b/mssql_python/mssql_python.pyi index dd3fd96a0..2774405fc 100644 --- a/mssql_python/mssql_python.pyi +++ b/mssql_python/mssql_python.pyi @@ -81,6 +81,33 @@ def TimeFromTicks(ticks: int) -> datetime.time: ... def TimestampFromTicks(ticks: int) -> datetime.datetime: ... def Binary(value: Union[str, bytes, bytearray]) -> bytes: ... 
+# SqlTypeCode - Dual-compatible type code for cursor.description +class SqlTypeCode: + """ + A type code that supports dual comparison with both SQL type integers and Python types. + + This class is used in cursor.description[i][1] to provide backwards compatibility + with libraries like pandas (which compare with Python types like str, int, float) + while also supporting DB-API 2.0 style integer type code comparisons. + + Examples: + >>> desc = cursor.description + >>> desc[0][1] == str # True if column is string type + >>> desc[0][1] == 12 # True if SQL_VARCHAR + >>> int(desc[0][1]) # Returns the SQL type code as integer + """ + + type_code: int + python_type: type + + def __init__(self, type_code: int) -> None: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... + def __int__(self) -> int: ... + __hash__: None # Unhashable; runtime raises TypeError with helpful message + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + # DB-API 2.0 Exception Hierarchy # https://www.python.org/dev/peps/pep-0249/#exceptions class Warning(Exception): @@ -133,7 +160,7 @@ class Row: description: List[ Tuple[ str, - Any, + Union[SqlTypeCode, type], Optional[int], Optional[int], Optional[int], @@ -163,11 +190,14 @@ class Cursor: """ # DB-API 2.0 Required Attributes + # description is a sequence of 7-item tuples: + # (name, type_code, display_size, internal_size, precision, scale, null_ok) + # type_code is SqlTypeCode which compares equal to both SQL integers and Python types description: Optional[ List[ Tuple[ str, - Any, + Union[SqlTypeCode, type], Optional[int], Optional[int], Optional[int], @@ -265,9 +295,13 @@ class Connection: ) -> None: ... def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]: ... def set_attr(self, attribute: int, value: Union[int, str, bytes, bytearray]) -> None: ... - def add_output_converter(self, sqltype: int, func: Callable[[Any], Any]) -> None: ... - def get_output_converter(self, sqltype: Union[int, type]) -> Optional[Callable[[Any], Any]]: ... - def remove_output_converter(self, sqltype: Union[int, type]) -> None: ... + def add_output_converter( + self, sqltype: Union[int, SqlTypeCode, type], func: Callable[[Any], Any] + ) -> None: ... + def get_output_converter( + self, sqltype: Union[int, SqlTypeCode, type] + ) -> Optional[Callable[[Any], Any]]: ... + def remove_output_converter(self, sqltype: Union[int, SqlTypeCode, type]) -> None: ... def clear_output_converters(self) -> None: ... def execute(self, sql: str, *args: Any) -> Cursor: ... def batch_execute( diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 2cf04fe0d..cc1b5426f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -27,6 +27,15 @@ #define MAX_DIGITS_IN_NUMERIC 64 #define SQL_MAX_NUMERIC_LEN 16 #define SQL_SS_XML (-152) +#ifndef SQL_SS_UDT +#define SQL_SS_UDT (-151) // SQL Server User-Defined Types (geometry, geography, hierarchyid) +#endif +#ifndef SQL_DATETIME2 +#define SQL_DATETIME2 (42) +#endif +#ifndef SQL_SMALLDATETIME +#define SQL_SMALLDATETIME (58) +#endif #define STRINGIFY_FOR_CASE(x) \ case x: \ diff --git a/mssql_python/type_code.py b/mssql_python/type_code.py new file mode 100644 index 000000000..4ef83c2bd --- /dev/null +++ b/mssql_python/type_code.py @@ -0,0 +1,115 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. 
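+
+Defines SqlTypeCode, the dual-compatible type code stored in
+cursor.description[i][1]: it compares equal both to the raw SQL type integer
+(DB-API 2.0 style, e.g. -9 for SQL_WVARCHAR) and to the mapped Python type
+(pandas style, e.g. str).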
+""" + +import decimal +import uuid +import datetime +from mssql_python.constants import ConstantsDDBC as ddbc_sql_const + + +class SqlTypeCode: + """ + A dual-compatible type code that compares equal to both SQL type integers and Python types. + + This class maintains backwards compatibility with code that checks + `cursor.description[i][1] == str` while also supporting DB-API 2.0 + compliant code that checks `cursor.description[i][1] == -9`. + + Examples: + >>> type_code = SqlTypeCode(-9) + >>> type_code == str # Backwards compatible with pandas, etc. + True + >>> type_code == -9 # DB-API 2.0 compliant + True + >>> int(type_code) # Get the raw SQL type code + -9 + """ + + # SQL type code to Python type mapping (class-level cache) + _type_map = None + + def __init__(self, type_code: int): + self.type_code = type_code + self.python_type = self._get_python_type(type_code) + + @classmethod + def _get_type_map(cls): + """Lazily build the SQL to Python type mapping.""" + if cls._type_map is None: + cls._type_map = { + ddbc_sql_const.SQL_CHAR.value: str, + ddbc_sql_const.SQL_VARCHAR.value: str, + ddbc_sql_const.SQL_LONGVARCHAR.value: str, + ddbc_sql_const.SQL_WCHAR.value: str, + ddbc_sql_const.SQL_WVARCHAR.value: str, + ddbc_sql_const.SQL_WLONGVARCHAR.value: str, + ddbc_sql_const.SQL_INTEGER.value: int, + ddbc_sql_const.SQL_REAL.value: float, + ddbc_sql_const.SQL_FLOAT.value: float, + ddbc_sql_const.SQL_DOUBLE.value: float, + ddbc_sql_const.SQL_DECIMAL.value: decimal.Decimal, + ddbc_sql_const.SQL_NUMERIC.value: decimal.Decimal, + ddbc_sql_const.SQL_DATE.value: datetime.date, + ddbc_sql_const.SQL_TIMESTAMP.value: datetime.datetime, + ddbc_sql_const.SQL_TIME.value: datetime.time, + ddbc_sql_const.SQL_SS_TIME2.value: datetime.time, + ddbc_sql_const.SQL_TYPE_DATE.value: datetime.date, + ddbc_sql_const.SQL_TYPE_TIME.value: datetime.time, + ddbc_sql_const.SQL_TYPE_TIMESTAMP.value: datetime.datetime, + ddbc_sql_const.SQL_TYPE_TIMESTAMP_WITH_TIMEZONE.value: datetime.datetime, + ddbc_sql_const.SQL_BIT.value: bool, + ddbc_sql_const.SQL_TINYINT.value: int, + ddbc_sql_const.SQL_SMALLINT.value: int, + ddbc_sql_const.SQL_BIGINT.value: int, + ddbc_sql_const.SQL_BINARY.value: bytes, + ddbc_sql_const.SQL_VARBINARY.value: bytes, + ddbc_sql_const.SQL_LONGVARBINARY.value: bytes, + ddbc_sql_const.SQL_GUID.value: uuid.UUID, + ddbc_sql_const.SQL_SS_UDT.value: bytes, + ddbc_sql_const.SQL_SS_XML.value: str, + ddbc_sql_const.SQL_DATETIME2.value: datetime.datetime, + ddbc_sql_const.SQL_SMALLDATETIME.value: datetime.datetime, + ddbc_sql_const.SQL_DATETIMEOFFSET.value: datetime.datetime, + } + return cls._type_map + + @classmethod + def _get_python_type(cls, sql_code: int) -> type: + """Get the Python type for a SQL type code.""" + return cls._get_type_map().get(sql_code, str) + + def __eq__(self, other): + """Compare equal to both Python types and SQL integer codes.""" + if isinstance(other, type): + return self.python_type == other + if isinstance(other, int): + return self.type_code == other + if isinstance(other, SqlTypeCode): + return self.type_code == other.type_code + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + """ + SqlTypeCode is intentionally unhashable because __eq__ allows + comparisons to both Python types and integer SQL codes, and + there is no single hash value that can be consistent with both. + """ + raise TypeError( + "SqlTypeCode is unhashable. Use int(type_code) or type_code.type_code " + "as a dict key instead. 
Example: {int(desc[1]): handler}" + ) + + def __int__(self): + return self.type_code + + def __repr__(self): + type_name = self.python_type.__name__ if self.python_type else "Unknown" + return f"SqlTypeCode({self.type_code}, {type_name})" + + def __str__(self): + return str(self.type_code) diff --git a/tests/test_002_types.py b/tests/test_002_types.py index 4828d72ea..2617f5485 100644 --- a/tests/test_002_types.py +++ b/tests/test_002_types.py @@ -16,6 +16,7 @@ TimestampFromTicks, Binary, ) +from mssql_python import SqlTypeCode def test_string_type(): @@ -1267,3 +1268,166 @@ def test_utf8_4byte_sequence_complete_coverage(): assert len(result) > 0, f"Invalid pattern should produce some output" assert True, "Complete 4-byte sequence coverage validated" + + +# ============================================================================= +# SqlTypeCode Unit Tests (DB-API 2.0 + pandas compatibility) +# ============================================================================= + + +class TestSqlTypeCode: + """ + Unit tests for SqlTypeCode class. + + SqlTypeCode provides dual compatibility: + - Compares equal to Python type objects (str, int, float, etc.) for pandas compatibility + - Compares equal to SQL integer codes for DB-API 2.0 compliance + """ + + def test_SqlTypeCode_import(self): + """Test that SqlTypeCode is importable from public API.""" + assert SqlTypeCode is not None + + def test_SqlTypeCode_equals_python_type_str(self): + """Test SqlTypeCode for SQL_WVARCHAR (-9) equals str.""" + tc = SqlTypeCode(-9) # SQL_WVARCHAR + assert tc == str, "SqlTypeCode(-9) should equal str" + assert not (tc != str), "SqlTypeCode(-9) should not be != str" + + def test_SqlTypeCode_equals_python_type_int(self): + """Test SqlTypeCode for SQL_INTEGER (4) equals int.""" + tc = SqlTypeCode(4) # SQL_INTEGER + assert tc == int, "SqlTypeCode(4) should equal int" + + def test_SqlTypeCode_equals_python_type_float(self): + """Test SqlTypeCode for SQL_REAL (7) equals float.""" + tc = SqlTypeCode(7) # SQL_REAL + assert tc == float, "SqlTypeCode(7) should equal float" + + def test_SqlTypeCode_equals_python_type_bytes(self): + """Test SqlTypeCode for SQL_BINARY (-2) equals bytes.""" + tc = SqlTypeCode(-2) # SQL_BINARY + assert tc == bytes, "SqlTypeCode(-2) should equal bytes" + + def test_SqlTypeCode_equals_sql_integer_code(self): + """Test SqlTypeCode equals its raw SQL integer code.""" + tc = SqlTypeCode(4) # SQL_INTEGER + assert tc == 4, "SqlTypeCode(4) should equal 4" + assert tc == SqlTypeCode(4).type_code, "SqlTypeCode(4) should equal its type_code" + + def test_SqlTypeCode_equals_negative_sql_code(self): + """Test SqlTypeCode with negative SQL codes (e.g., SQL_WVARCHAR = -9).""" + tc = SqlTypeCode(-9) # SQL_WVARCHAR + assert tc == -9, "SqlTypeCode(-9) should equal -9" + + def test_SqlTypeCode_dual_compatibility(self): + """Test that SqlTypeCode equals both Python type AND SQL code simultaneously.""" + tc = SqlTypeCode(4) # SQL_INTEGER + # Must satisfy BOTH comparisons - this is the key feature + assert tc == int and tc == 4, "SqlTypeCode should equal both int and 4" + + def test_SqlTypeCode_int_conversion(self): + """Test int(SqlTypeCode) returns raw SQL code.""" + tc = SqlTypeCode(-9) + assert int(tc) == -9, "int(SqlTypeCode(-9)) should return -9" + tc2 = SqlTypeCode(4) + assert int(tc2) == 4, "int(SqlTypeCode(4)) should return 4" + + def test_SqlTypeCode_unhashable(self): + """Test SqlTypeCode is intentionally unhashable due to eq/hash contract.""" + import pytest + + tc = SqlTypeCode(4) + with 
pytest.raises(TypeError) as exc_info: + hash(tc) + assert "unhashable" in str(exc_info.value).lower() + + def test_SqlTypeCode_repr(self): + """Test SqlTypeCode has informative repr.""" + tc = SqlTypeCode(4) + r = repr(tc) + assert "4" in r, "repr should contain the SQL code" + assert "SqlTypeCode" in r, "repr should contain class name" + + def test_SqlTypeCode_type_code_property(self): + """Test SqlTypeCode.type_code returns raw SQL code.""" + tc = SqlTypeCode(-9) + assert tc.type_code == -9 + tc2 = SqlTypeCode(93) # SQL_TYPE_TIMESTAMP + assert tc2.type_code == 93 + + def test_SqlTypeCode_python_type_property(self): + """Test SqlTypeCode.python_type returns mapped type.""" + tc = SqlTypeCode(4) # SQL_INTEGER + assert tc.python_type == int + tc2 = SqlTypeCode(-9) # SQL_WVARCHAR + assert tc2.python_type == str + + def test_SqlTypeCode_unknown_type_maps_to_str(self): + """Test unknown SQL codes map to str by default.""" + tc = SqlTypeCode(99999) # Unknown code + assert tc.python_type == str + assert tc == str # Should still work for comparison + + def test_SqlTypeCode_pandas_simulation(self): + """ + Simulate pandas read_sql type checking behavior. + + Pandas checks `cursor.description[i][1] == str` to determine + if a column should be treated as string data. + """ + # Simulate a description tuple like pandas receives + description = [ + ("name", SqlTypeCode(-9), None, None, None, None, None), # nvarchar + ("age", SqlTypeCode(4), None, None, None, None, None), # int + ("salary", SqlTypeCode(6), None, None, None, None, None), # float + ] + + # Pandas-style type checking + string_columns = [] + for name, type_code, *rest in description: + if type_code == str: + string_columns.append(name) + + assert string_columns == ["name"], "Only 'name' column should be detected as string" + + # Verify other types work too + for name, type_code, *rest in description: + if type_code == int: + assert name == "age" + if type_code == float: + assert name == "salary" + + def test_SqlTypeCode_dbapi_simulation(self): + """ + Simulate DB-API 2.0 style type checking with integer codes. 
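+
+        Such code compares description[i][1] directly against raw SQL type
+        integers, so SqlTypeCode must keep those comparisons working.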
+ """ + # Simulate description + description = [ + ("id", SqlTypeCode(4), None, None, None, None, None), # SQL_INTEGER + ("data", SqlTypeCode(-9), None, None, None, None, None), # SQL_WVARCHAR + ] + + # DB-API style: check raw SQL code + for name, type_code, *rest in description: + if type_code == 4: # SQL_INTEGER + assert name == "id" + if type_code == -9: # SQL_WVARCHAR + assert name == "data" + + def test_SqlTypeCode_equality_with_other_SqlTypeCode(self): + """Test SqlTypeCode equality with another SqlTypeCode.""" + tc1 = SqlTypeCode(4) + tc2 = SqlTypeCode(4) + tc3 = SqlTypeCode(-9) + + assert tc1 == tc2, "Same code SqlTypeCodes should be equal via ==" + assert tc1 != tc3, "Different code SqlTypeCodes should not be equal via !=" + + def test_SqlTypeCode_inequality(self): + """Test SqlTypeCode inequality comparisons.""" + tc = SqlTypeCode(4) + assert tc != str, "SQL_INTEGER should not equal str" + assert tc != float, "SQL_INTEGER should not equal float" + assert tc != 5, "SqlTypeCode(4) should not equal 5" + assert tc != -9, "SqlTypeCode(4) should not equal -9" diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 575496299..42088fbb8 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -16,7 +16,6 @@ from contextlib import closing import mssql_python import uuid -import re from conftest import is_azure_sql_connection # Setup test table @@ -1197,7 +1196,7 @@ def test_fetchall(cursor): def test_fetchall_lob(cursor): - """Test fetching all rows""" + """Test fetching all rows with LOB columns""" cursor.execute("SELECT * FROM #pytest_all_data_types") rows = cursor.fetchall() assert isinstance(rows, list), "fetchall should return a list" @@ -2382,16 +2381,123 @@ def test_drop_tables_for_join(cursor, db_connection): def test_cursor_description(cursor): - """Test cursor description""" + """Test cursor description with SqlTypeCode for backwards compatibility.""" cursor.execute("SELECT database_id, name FROM sys.databases;") desc = cursor.description - expected_description = [ - ("database_id", int, None, 10, 10, 0, False), - ("name", str, None, 128, 128, 0, False), - ] - assert len(desc) == len(expected_description), "Description length mismatch" - for desc, expected in zip(desc, expected_description): - assert desc == expected, f"Description mismatch: {desc} != {expected}" + + from mssql_python.constants import ConstantsDDBC as ddbc_sql_const + + # Verify length + assert len(desc) == 2, "Description should have 2 columns" + + # Test 1: DB-API 2.0 compliant - compare with SQL type codes (integers) + assert desc[0][1] == ddbc_sql_const.SQL_INTEGER.value, "database_id should be SQL_INTEGER (4)" + assert desc[1][1] == ddbc_sql_const.SQL_WVARCHAR.value, "name should be SQL_WVARCHAR (-9)" + + # Test 2: Backwards compatible - compare with Python types (for pandas, etc.) + assert desc[0][1] == int, "database_id should also compare equal to Python int" + assert desc[1][1] == str, "name should also compare equal to Python str" + + # Test 3: Can convert to int to get raw SQL code + assert int(desc[0][1]) == 4, "int(type_code) should return SQL_INTEGER (4)" + assert int(desc[1][1]) == -9, "int(type_code) should return SQL_WVARCHAR (-9)" + + # Test 4: Verify other tuple elements + assert desc[0][0] == "database_id", "First column name should be database_id" + assert desc[1][0] == "name", "Second column name should be name" + + +def test_cursor_description_pandas_compatibility(cursor): + """ + Test that cursor.description type_code works with pandas-style type checking. 
+ + Pandas and other libraries check `cursor.description[i][1] == str` to determine + column types. This test ensures SqlTypeCode maintains backwards compatibility. + """ + cursor.execute("SELECT database_id, name FROM sys.databases;") + desc = cursor.description + + # Simulate what pandas does internally when reading SQL results + # pandas checks: if description[i][1] == str: treat as string column + type_map = {} + for col_desc in desc: + col_name = col_desc[0] + type_code = col_desc[1] + + # This is how pandas-like code typically checks types + if type_code == str: + type_map[col_name] = "string" + elif type_code == int: + type_map[col_name] = "integer" + elif type_code == float: + type_map[col_name] = "float" + elif type_code == bytes: + type_map[col_name] = "bytes" + else: + type_map[col_name] = "other" + + assert type_map["database_id"] == "integer", "database_id should be detected as integer" + assert type_map["name"] == "string", "name should be detected as string" + + +def test_cursor_description_datetime_types(cursor, db_connection): + """ + Regression test for Issue #352: Ensure DATE/datetime columns return correct ODBC type codes. + + This test verifies that cursor.description properly handles date/time columns, + returning the correct ODBC 3.x type codes while maintaining backwards compatibility + with Python datetime types for pandas-style comparisons. + """ + from mssql_python.constants import ConstantsDDBC + + try: + # Create a table with various date/time types + cursor.execute(""" + CREATE TABLE #pytest_datetime_desc ( + id INT PRIMARY KEY, + date_col DATE, + time_col TIME, + datetime_col DATETIME, + datetime2_col DATETIME2 + ); + """) + db_connection.commit() + + cursor.execute( + "SELECT id, date_col, time_col, datetime_col, datetime2_col FROM #pytest_datetime_desc;" + ) + desc = cursor.description + + assert len(desc) == 5, "Should have 5 columns in description" + + # Verify column names + assert desc[0][0] == "id", "First column should be 'id'" + assert desc[1][0] == "date_col", "Second column should be 'date_col'" + assert desc[2][0] == "time_col", "Third column should be 'time_col'" + assert desc[3][0] == "datetime_col", "Fourth column should be 'datetime_col'" + assert desc[4][0] == "datetime2_col", "Fifth column should be 'datetime2_col'" + + # Test 1: DB-API 2.0 compliant - verify SQL type codes as integers + # DATE should be SQL_TYPE_DATE (91) + assert ( + int(desc[1][1]) == ConstantsDDBC.SQL_TYPE_DATE.value + ), f"DATE column should have SQL_TYPE_DATE type code ({ConstantsDDBC.SQL_TYPE_DATE.value})" + # TIME should be SQL_SS_TIME2 (-154) or SQL_TYPE_TIME (92) + time_type_code = int(desc[2][1]) + assert time_type_code in ( + ConstantsDDBC.SQL_SS_TIME2.value, + ConstantsDDBC.SQL_TYPE_TIME.value, + ), f"TIME column should have SQL_SS_TIME2 or SQL_TYPE_TIME type code, got {time_type_code}" + + # Test 2: Backwards compatible - compare with Python types (for pandas, etc.) 
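+        # SqlTypeCode.__eq__ resolves the SQL code to its mapped Python type,
+        # so these hold at the same time as the integer checks in Test 1 above.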
+ assert desc[1][1] == date, "DATE should compare equal to datetime.date" + assert desc[2][1] == time, "TIME should compare equal to datetime.time" + assert desc[3][1] == datetime, "DATETIME should compare equal to datetime.datetime" + assert desc[4][1] == datetime, "DATETIME2 should compare equal to datetime.datetime" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_datetime_desc;") + db_connection.commit() def test_parse_datetime(cursor, db_connection): @@ -8986,11 +9092,8 @@ def test_decimal_separator_fetch_regression(cursor, db_connection): finally: # Reset separator to default just in case mssql_python.setDecimalSeparator(".") - try: - cursor.execute("DROP TABLE IF EXISTS #TestDecimal") - db_connection.commit() - except Exception: - pass + cursor.execute("DROP TABLE IF EXISTS #TestDecimal") + db_connection.commit() def test_datetimeoffset_read_write(cursor, db_connection): @@ -13405,11 +13508,8 @@ def test_decimal_scientific_notation_to_varchar(cursor, db_connection, values, d ), f"{description}: Row {i} mismatch - expected {expected_val}, got {stored_val}" finally: - try: - cursor.execute(f"DROP TABLE {table_name}") - db_connection.commit() - except: - pass + cursor.execute(f"DROP TABLE IF EXISTS {table_name}") + db_connection.commit() SMALL_XML = "1" @@ -13511,13 +13611,186 @@ def test_xml_malformed_input(cursor, db_connection): ) db_connection.commit() - with pytest.raises(Exception): + with pytest.raises(mssql_python.Error): cursor.execute("INSERT INTO #pytest_xml_invalid (xml_col) VALUES (?);", INVALID_XML) finally: cursor.execute("DROP TABLE IF EXISTS #pytest_xml_invalid;") db_connection.commit() +def test_column_metadata_thread_safety_concurrent_cursors(db_connection, conn_str): + """Test thread safety of _column_metadata with concurrent cursors across threads.""" + import threading + from mssql_python import connect + + # Track results and errors from each thread + results = {} + errors = [] + lock = threading.Lock() + + def worker(thread_id, table_suffix): + # Each thread uses its own independent connection + thread_conn = None + cursor = None + try: + thread_conn = connect(conn_str) + cursor = thread_conn.cursor() + + try: + # Create a unique temp table for this thread + table_name = f"#pytest_thread_meta_{table_suffix}" + cursor.execute(f"DROP TABLE IF EXISTS {table_name};") + + # Create table with distinct column structure for this thread + cursor.execute(f""" + CREATE TABLE {table_name} ( + thread_id INT, + col_{table_suffix}_a NVARCHAR(100), + col_{table_suffix}_b INT, + col_{table_suffix}_c FLOAT + ); + """) + thread_conn.commit() + + # Insert test data + cursor.execute(f""" + INSERT INTO {table_name} VALUES + ({thread_id}, 'data_{thread_id}_1', {thread_id * 100}, {thread_id * 1.5}), + ({thread_id}, 'data_{thread_id}_2', {thread_id * 200}, {thread_id * 2.5}); + """) + thread_conn.commit() + + # Execute SELECT and verify description metadata is correct + cursor.execute(f"SELECT * FROM {table_name} ORDER BY col_{table_suffix}_b;") + + # Verify cursor has correct description for THIS query + desc = cursor.description + assert desc is not None, f"Thread {thread_id}: description should not be None" + assert len(desc) == 4, f"Thread {thread_id}: should have 4 columns" + + # Verify column names are correct for this thread's table + col_names = [d[0].lower() for d in desc] + expected_names = [ + "thread_id", + f"col_{table_suffix}_a", + f"col_{table_suffix}_b", + f"col_{table_suffix}_c", + ] + assert col_names == expected_names, f"Thread {thread_id}: column names 
should match" + + # Fetch all rows and verify data + rows = cursor.fetchall() + assert len(rows) == 2, f"Thread {thread_id}: should have 2 rows" + assert rows[0][0] == thread_id, f"Thread {thread_id}: thread_id column should match" + + # Verify _column_metadata is set (internal attribute) + assert ( + cursor._column_metadata is not None + ), f"Thread {thread_id}: _column_metadata should be set" + + # Clean up + cursor.execute(f"DROP TABLE IF EXISTS {table_name};") + thread_conn.commit() + + with lock: + results[thread_id] = { + "success": True, + "col_count": len(desc), + "row_count": len(rows), + } + + finally: + if cursor: + cursor.close() + if thread_conn: + thread_conn.close() + + except Exception as e: + with lock: + errors.append((thread_id, str(e))) + + # Create and start multiple threads + num_threads = 5 + threads = [] + + for i in range(num_threads): + t = threading.Thread(target=worker, args=(i, f"t{i}"), daemon=True) + threads.append(t) + + # Start all threads at roughly the same time + for t in threads: + t.start() + + # Wait for all threads to complete + for t in threads: + t.join(timeout=30) # 30 second timeout per thread + + # Verify threads actually finished (not just timed out) + hung_threads = [t for t in threads if t.is_alive()] + assert len(hung_threads) == 0, f"{len(hung_threads)} thread(s) still running after timeout" + + # Verify no errors occurred + assert len(errors) == 0, f"Thread errors occurred: {errors}" + + # Verify all threads completed successfully + assert len(results) == num_threads, f"Expected {num_threads} results, got {len(results)}" + + for thread_id, result in results.items(): + assert result["success"], f"Thread {thread_id} did not succeed" + assert result["col_count"] == 4, f"Thread {thread_id} had wrong column count" + assert result["row_count"] == 2, f"Thread {thread_id} had wrong row count" + + +def test_column_metadata_isolation_sequential_queries(cursor, db_connection): + """ + Test that _column_metadata is correctly updated between sequential queries. + + Verifies that each execute() call properly replaces the previous metadata, + ensuring no stale data leaks between queries. 
+    """
+    # Query 1: Simple 2-column query
+    cursor.execute("SELECT 1 as col_a, 'hello' as col_b;")
+    desc1 = cursor.description
+    meta1 = cursor._column_metadata
+    cursor.fetchall()
+
+    assert len(desc1) == 2, "First query should have 2 columns"
+    assert meta1 is not None, "_column_metadata should be set"
+
+    # Query 2: Different structure - 4 columns
+    cursor.execute("SELECT 1 as x, 2 as y, 3 as z, 4 as w;")
+    desc2 = cursor.description
+    meta2 = cursor._column_metadata
+    cursor.fetchall()
+
+    assert len(desc2) == 4, "Second query should have 4 columns"
+    assert meta2 is not None, "_column_metadata should be set"
+
+    # Verify the metadata was replaced, not appended
+    assert len(meta2) == 4, "_column_metadata should have 4 entries"
+    assert meta1 is not meta2, "_column_metadata should be a new object"
+
+    # Query 3: Back to 2 columns with different names
+    cursor.execute("SELECT 'test' as different_name, 42.5 as another_col;")
+    desc3 = cursor.description
+    meta3 = cursor._column_metadata
+    cursor.fetchall()
+
+    assert len(desc3) == 2, "Third query should have 2 columns"
+    assert len(meta3) == 2, "_column_metadata should have 2 entries"
+
+    # Verify column names are from the new query
+    col_names = [d[0].lower() for d in desc3]
+    assert col_names == [
+        "different_name",
+        "another_col",
+    ], "Column names should be from third query"
+
+
 # ==================== CODE COVERAGE TEST CASES ====================
 
 
@@ -14030,7 +14303,8 @@ def test_row_output_converter_general_exception(cursor, db_connection):
 
     # Create a custom output converter that will raise a general exception
     def failing_converter(value):
-        if value == "test_value":
+        # Driver passes string values as UTF-16LE encoded bytes to output converters
+        if value == "test_value".encode("utf-16-le"):
             raise RuntimeError("Custom converter error for testing")
         return value
 
diff --git a/tests/test_018_polars_integration.py b/tests/test_018_polars_integration.py
new file mode 100644
index 000000000..2006b5982
--- /dev/null
+++ b/tests/test_018_polars_integration.py
@@ -0,0 +1,240 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Polars integration tests for mssql-python driver (Issue #352)."""
+
+import datetime
+import platform
+
+import pytest
+
+# Skip on Alpine ARM64 — polars crashes with "Illegal instruction" during import.
+_machine = platform.machine().lower()
+_is_arm = _machine in ("aarch64", "arm64", "armv8")
+
+# Check if running on Alpine (musl libc)
+_is_alpine = False
+try:
+    with open("/etc/os-release", "r") as f:
+        _is_alpine = "alpine" in f.read().lower()
+except (FileNotFoundError, PermissionError):
+    # /etc/os-release missing or unreadable — not Alpine, continue.
+    pass
+
+if _is_arm and _is_alpine:
+    pytest.skip(
+        "Skipping polars tests on Alpine ARM64 (polars crashes during import)",
+        allow_module_level=True,
+    )
+
+# Now safe to import polars on supported platforms
+pl = pytest.importorskip("polars", reason="polars not available on this platform")
+
+
+class TestPolarsIntegration:
+    """Integration tests for polars compatibility with mssql-python."""
+
+    def test_polars_read_database_basic(self, cursor, db_connection):
+        """
+        Test polars can read basic data types via pl.read_database().
+
+        This is the exact scenario reported in issue #352.
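+
+        polars consumes the connection through the generic DB-API cursor
+        interface, so this exercises SqlTypeCode in cursor.description end to end.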
+ """ + # Create test table with various types + cursor.execute(""" + CREATE TABLE #pytest_polars_basic ( + id INT, + name NVARCHAR(100), + value FLOAT + ); + """) + cursor.execute(""" + INSERT INTO #pytest_polars_basic VALUES + (1, 'Alice', 100.5), + (2, 'Bob', 200.75), + (3, 'Charlie', 300.25); + """) + db_connection.commit() + + try: + # Use polars read_database with our connection + df = pl.read_database( + query="SELECT id, name, value FROM #pytest_polars_basic ORDER BY id", + connection=db_connection, + ) + + assert len(df) == 3, "Should have 3 rows" + assert df.columns == ["id", "name", "value"], "Column names should match" + assert df["id"].to_list() == [1, 2, 3], "id values should match" + assert df["name"].to_list() == ["Alice", "Bob", "Charlie"], "name values should match" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_polars_basic;") + db_connection.commit() + + def test_polars_read_database_with_dates(self, cursor, db_connection): + """ + Test polars can read DATE columns - the specific failure case from issue #352. + + The original error was: + ComputeError: could not append value: 2013-01-01 of type: date to the builder + """ + cursor.execute(""" + CREATE TABLE #pytest_polars_dates ( + id INT, + date_col DATE, + datetime_col DATETIME + ); + """) + cursor.execute(""" + INSERT INTO #pytest_polars_dates VALUES + (1, '2013-01-01', '2013-01-01 10:30:00'), + (2, '2024-06-15', '2024-06-15 14:45:30'), + (3, '2025-12-31', '2025-12-31 23:59:59'); + """) + db_connection.commit() + + try: + df = pl.read_database( + query="SELECT id, date_col, datetime_col FROM #pytest_polars_dates ORDER BY id", + connection=db_connection, + ) + + assert len(df) == 3, "Should have 3 rows" + assert "date_col" in df.columns, "date_col should be present" + assert "datetime_col" in df.columns, "datetime_col should be present" + + # Verify date values are correct + dates = df["date_col"].to_list() + assert dates[0] == datetime.date(2013, 1, 1), "First date should be 2013-01-01" + assert dates[1] == datetime.date(2024, 6, 15), "Second date should be 2024-06-15" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_polars_dates;") + db_connection.commit() + + def test_polars_read_database_all_common_types(self, cursor, db_connection): + """ + Test polars can read all common SQL Server data types. 
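+
+        Covers INT, BIGINT, FLOAT, DECIMAL, VARCHAR, NVARCHAR (non-ASCII),
+        BIT, DATE, DATETIME and TIME in a single row.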
+ """ + cursor.execute(""" + CREATE TABLE #pytest_polars_types ( + int_col INT, + bigint_col BIGINT, + float_col FLOAT, + decimal_col DECIMAL(10,2), + varchar_col VARCHAR(100), + nvarchar_col NVARCHAR(100), + bit_col BIT, + date_col DATE, + datetime_col DATETIME, + time_col TIME + ); + """) + cursor.execute(""" + INSERT INTO #pytest_polars_types VALUES + (42, 9223372036854775807, 3.14159, 123.45, 'hello', N'世界', 1, + '2025-01-15', '2025-01-15 10:30:00', '14:30:00'); + """) + db_connection.commit() + + try: + df = pl.read_database( + query="SELECT * FROM #pytest_polars_types", + connection=db_connection, + ) + + assert len(df) == 1, "Should have 1 row" + assert len(df.columns) == 10, "Should have 10 columns" + + # Verify all column values + row = df.row(0) + assert row[0] == 42, "int_col should be 42" + assert row[1] == 9223372036854775807, "bigint_col should be max BIGINT" + assert abs(row[2] - 3.14159) < 1e-4, "float_col should be ~3.14159" + assert row[4] == "hello", "varchar_col should be 'hello'" + assert row[5] == "世界", "nvarchar_col should be '世界' (Unicode)" + assert row[6] in (1, True), "bit_col should be 1 or True" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_polars_types;") + db_connection.commit() + + def test_polars_read_database_with_nulls(self, cursor, db_connection): + """ + Test polars can handle NULL values correctly. + """ + cursor.execute(""" + CREATE TABLE #pytest_polars_nulls ( + id INT, + nullable_str NVARCHAR(100) NULL, + nullable_int INT NULL, + nullable_date DATE NULL + ); + """) + cursor.execute(""" + INSERT INTO #pytest_polars_nulls VALUES + (1, 'has value', 100, '2025-01-01'), + (2, NULL, NULL, NULL), + (3, 'another', 200, '2025-12-31'); + """) + db_connection.commit() + + try: + df = pl.read_database( + query="SELECT * FROM #pytest_polars_nulls ORDER BY id", + connection=db_connection, + ) + + assert len(df) == 3, "Should have 3 rows" + + # Check NULL handling across all nullable columns + str_values = df["nullable_str"].to_list() + assert str_values[0] == "has value" + assert str_values[1] is None, "NULL should become None" + assert str_values[2] == "another" + + int_values = df["nullable_int"].to_list() + assert int_values[0] == 100 + assert int_values[1] is None, "NULL int should become None" + assert int_values[2] == 200 + + # Issue #352: date NULL handling was the original bug + date_values = df["nullable_date"].to_list() + assert date_values[1] is None, "NULL date should become None (Issue #352)" + assert date_values[0] is not None, "Non-NULL date should have a value" + assert date_values[2] is not None, "Non-NULL date should have a value" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_polars_nulls;") + db_connection.commit() + + def test_polars_read_database_large_result(self, cursor, db_connection): + """ + Test polars can handle larger result sets. 
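+
+        Inserts 1000 rows via 100 multi-row INSERT statements (10 rows each)
+        and verifies the full id range round-trips.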
+ """ + cursor.execute(""" + CREATE TABLE #pytest_polars_large ( + id INT, + data NVARCHAR(100) + ); + """) + + # Insert 1000 rows + for i in range(100): + values = ", ".join([f"({i*10+j}, 'row_{i*10+j}')" for j in range(10)]) + cursor.execute(f"INSERT INTO #pytest_polars_large VALUES {values};") + db_connection.commit() + + try: + df = pl.read_database( + query="SELECT * FROM #pytest_polars_large ORDER BY id", + connection=db_connection, + ) + + assert len(df) == 1000, "Should have 1000 rows" + assert df["id"].min() == 0, "Min id should be 0" + assert df["id"].max() == 999, "Max id should be 999" + + finally: + cursor.execute("DROP TABLE IF EXISTS #pytest_polars_large;") + db_connection.commit()