From 4fce3d063039ac5a384f3d9030ecc8135c8df60e Mon Sep 17 00:00:00 2001
From: Christopher Rowley
Date: Wed, 24 Dec 2025 15:39:27 +0000
Subject: [PATCH] Add PyString wrapper for Python strings

---
Review notes (text between `---` and the first `diff --git` is not part of
the commit):
 * src/Wrap/PyString.jl: `codeunit(x, i)` now bounds-checks `i` instead of
   loading from an arbitrary address; `iterate(x, i)` returns `nothing` for
   out-of-range `i` (same as `iterate(::String, i)`); every raw read of
   `x.ptr` is wrapped in `GC.@preserve x` so the owning `Py` stays alive.
 * test/Wrap.jl: regression tests for the out-of-range cases.
 * Line breaks and indentation were restored from a whitespace-collapsed copy
   of this patch. Padding inside context lines (e.g. Markdown table cells)
   may differ from the tree, so `git apply --ignore-whitespace` may be
   needed. The blob ids on the `index` lines of the two files above predate
   the review edits.

 docs/src/conversion-to-julia.md  |   1 +
 docs/src/pythoncall-reference.md |   1 +
 docs/src/pythoncall.md           |   3 +
 src/API/exports.jl               |   1 +
 src/API/types.jl                 |  21 +++++
 src/C/pointers.jl                |   1 +
 src/Convert/pyconvert.jl         |   3 +-
 src/Convert/rules.jl             |   1 +
 src/Core/builtins.jl             |   7 ++
 src/Utils/Utils.jl               | 139 ++++++++++++++++---------------
 src/Wrap/PyString.jl             |  27 ++++++
 src/Wrap/Wrap.jl                 |   3 +-
 test/Wrap.jl                     |  29 +++++++
 13 files changed, 170 insertions(+), 67 deletions(-)
 create mode 100644 src/Wrap/PyString.jl

diff --git a/docs/src/conversion-to-julia.md b/docs/src/conversion-to-julia.md
index 67d1e904..40dc6eb8 100644
--- a/docs/src/conversion-to-julia.md
+++ b/docs/src/conversion-to-julia.md
@@ -37,6 +37,7 @@ From Python, the arguments to a Julia function will be converted according to th
 | `None` | `Missing` |
 | `bytes` | `Vector{UInt8}`, `Vector{Int8}`, `String` |
 | `str` | `String`, `Symbol`, `Char`, `Vector{UInt8}`, `Vector{Int8}` |
+| `str` | `PyString` |
 | `range` | `UnitRange` |
 | `collections.abc.Mapping` | `Dict` |
 | `collections.abc.Iterable` | `Vector`, `Set`, `Tuple`, `NamedTuple`, `Pair` |
diff --git a/docs/src/pythoncall-reference.md b/docs/src/pythoncall-reference.md
index f09d5f12..b370eb09 100644
--- a/docs/src/pythoncall-reference.md
+++ b/docs/src/pythoncall-reference.md
@@ -184,6 +184,7 @@ PyList
 PySet
 PyDict
 PyIterable
+PyString
 PyArray
 PyIO
 PyTable
diff --git a/docs/src/pythoncall.md b/docs/src/pythoncall.md
index 5bc610ff..fe08b309 100644
--- a/docs/src/pythoncall.md
+++ b/docs/src/pythoncall.md
@@ -227,6 +227,9 @@ Python: [3, 4, 5, None, 1, 2]
 
 There are wrappers for other container types, such as [`PyDict`](@ref) and [`PySet`](@ref).
 
+`PyString` is a zero-copy wrapper around a Python `str`, exposing it as a Julia
+`AbstractString` backed by the UTF-8 pointer cached by Python.
+
 The wrapper [`PyArray`](@ref) provides a Julia array view of any Python array, i.e. anything
 satisfying either the buffer protocol or the numpy array interface. This includes things
 like `bytes`, `bytearray`, `array.array` and `numpy.ndarray`:
diff --git a/src/API/exports.jl b/src/API/exports.jl
index 3d476182..ea5b823c 100644
--- a/src/API/exports.jl
+++ b/src/API/exports.jl
@@ -117,6 +117,7 @@ export PyDict
 export PyIO
 export PyIterable
 export PyList
+export PyString
 export PyPandasDataFrame
 export PySet
 export PyTable
diff --git a/src/API/types.jl b/src/API/types.jl
index 55a4211e..85e35ba7 100644
--- a/src/API/types.jl
+++ b/src/API/types.jl
@@ -169,6 +169,27 @@ struct PyIterable{T}
     PyIterable{T}(x) where {T} = new{T}(Py(x))
 end
 
+"""
+    PyString(x)
+
+Wraps the Python `str` `x` as an `AbstractString` without copying.
+
+The UTF-8 data is stored as a pointer and byte length obtained from
+`PyUnicode_AsUTF8AndSize`, and remains valid as long as the underlying Python
+object is alive.
+"""
+struct PyString <: AbstractString
+    py::Py
+    ptr::Ptr{UInt8}
+    nbytes::Int
+    function PyString(x)
+        py = Py(x)
+        PythonCall.Core.pyisstr(py) || throw(ArgumentError("PyString expects a Python `str`"))
+        ptr, n = PythonCall.Core.pystr_utf8_pointer(py)
+        new(py, ptr, n)
+    end
+end
+
 """
     PyList{T=Py}([x])
 
diff --git a/src/C/pointers.jl b/src/C/pointers.jl
index d7441960..bc12865f 100644
--- a/src/C/pointers.jl
+++ b/src/C/pointers.jl
@@ -139,6 +139,7 @@ const CAPI_FUNC_SIGS = Dict{Symbol,Pair{Tuple,Type}}(
     :PyComplex_AsCComplex => (PyPtr,) => Py_complex,
     # STR
     :PyUnicode_DecodeUTF8 => (Ptr{Cchar}, Py_ssize_t, Ptr{Cchar}) => PyPtr,
+    :PyUnicode_AsUTF8AndSize => (PyPtr, Ptr{Py_ssize_t}) => Ptr{Cchar},
     :PyUnicode_AsUTF8String => (PyPtr,) => PyPtr,
     :PyUnicode_InternInPlace => (Ptr{PyPtr},) => Cvoid,
     # BYTES
diff --git a/src/Convert/pyconvert.jl b/src/Convert/pyconvert.jl
index 2b03eb72..15937702 100644
--- a/src/Convert/pyconvert.jl
+++ b/src/Convert/pyconvert.jl
@@ -321,7 +321,7 @@ function pyconvert_rule_fast(::Type{T}, x::Py) where {T}
         pyisfloat(x) && return pyconvert_return(T(pyfloat_asdouble(x)))
     elseif (T == ComplexF64)
         pyiscomplex(x) && return pyconvert_return(T(pycomplex_ascomplex(x)))
-    elseif (T == String) | (T == Char) | (T == Symbol)
+    elseif (T == String) | (T == Char) | (T == Symbol) | (T == PyString)
         pyisstr(x) && return pyconvert_rule_str(T, x)
     elseif (T == Vector{UInt8}) | (T == Base.CodeUnits{UInt8,String})
         pyisbytes(x) && return pyconvert_rule_bytes(T, x)
@@ -473,6 +473,7 @@ function init_pyconvert()
     pyconvert_add_rule("builtins:float", Nothing, pyconvert_rule_float, priority)
     pyconvert_add_rule("builtins:float", Missing, pyconvert_rule_float, priority)
     pyconvert_add_rule("numbers:Complex", Number, pyconvert_rule_complex, priority)
+    pyconvert_add_rule("builtins:str", PyString, pyconvert_rule_str, priority)
     pyconvert_add_rule("numbers:Integral", Number, pyconvert_rule_int, priority)
     pyconvert_add_rule("builtins:str", Symbol, pyconvert_rule_str, priority)
     pyconvert_add_rule("builtins:str", Char, pyconvert_rule_str, priority)
diff --git a/src/Convert/rules.jl b/src/Convert/rules.jl
index ac104d5e..6daa66af 100644
--- a/src/Convert/rules.jl
+++ b/src/Convert/rules.jl
@@ -48,6 +48,7 @@ pyconvert_rule_str(::Type{Char}, x::Py) = begin
         pyconvert_unconverted()
     end
 end
+pyconvert_rule_str(::Type{PyString}, x::Py) = pyconvert_return(PyString(x))
 
 ### bytes
 
diff --git a/src/Core/builtins.jl b/src/Core/builtins.jl
index 7e772662..73a7726d 100644
--- a/src/Core/builtins.jl
+++ b/src/Core/builtins.jl
@@ -573,6 +573,13 @@ pystr_asUTF8vector(x::Py) =
 pystr_asstring(x::Py) =
     (b = pystr_asUTF8bytes(x); ans = pybytes_asUTF8string(b); pydel!(b); ans)
 
+function pystr_utf8_pointer(x::Py)
+    n = Ref{C.Py_ssize_t}()
+    p = C.PyUnicode_AsUTF8AndSize(x, n)
+    p == C_NULL && pythrow()
+    Ptr{UInt8}(p), Int(n[])
+end
+
 function pystr_intern!(x::Py)
     ptr = Ref(getptr(x))
     C.PyUnicode_InternInPlace(ptr)
diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl
index 6e7b3f4d..419d96e2 100644
--- a/src/Utils/Utils.jl
+++ b/src/Utils/Utils.jl
@@ -180,109 +180,79 @@ size_to_fstrides(elsz::Integer, sz::Tuple{Vararg{Integer}}) =
 size_to_cstrides(elsz::Integer, sz::Tuple{Vararg{Integer}}) =
     isempty(sz) ? () : (size_to_cstrides(elsz * sz[end], sz[1:end-1])..., elsz)
 
-struct StaticString{T,N} <: AbstractString
-    codeunits::NTuple{N,T}
-    StaticString{T,N}(codeunits::NTuple{N,T}) where {T,N} = new{T,N}(codeunits)
-end
-
-function Base.String(x::StaticString{T,N}) where {T,N}
-    ts = x.codeunits
-    n = N
-    while n > 0 && iszero(ts[n])
-        n -= 1
+function utf8_allzeros(codeunit::F, n::Int, i::Int) where {F}
+    @inbounds for j in i:n
+        iszero(codeunit(j)) || return false
     end
-    cs = T[ts[i] for i = 1:n]
-    transcode(String, cs)
+    return true
 end
 
-function Base.convert(::Type{StaticString{T,N}}, x::AbstractString) where {T,N}
-    ts = transcode(T, convert(String, x))
-    n = length(ts)
-    n > N && throw(InexactError(:convert, StaticString{T,N}, x))
-    n > 0 && iszero(ts[n]) && throw(InexactError(:convert, StaticString{T,N}, x))
-    z = zero(T)
-    cs = ntuple(i -> i > n ? z : @inbounds(ts[i]), N)
-    StaticString{T,N}(cs)
-end
-
-StaticString{T,N}(x::AbstractString) where {T,N} = convert(StaticString{T,N}, x)
-
-Base.ncodeunits(x::StaticString{T,N}) where {T,N} = N
-
-Base.codeunit(x::StaticString, i::Integer) = x.codeunits[i]
-
-Base.codeunit(x::StaticString{T}) where {T} = T
-
-function Base.isvalid(x::StaticString{UInt8,N}, i::Int) where {N}
-    if i < 1 || i > N
+function utf8_isvalid(codeunit::F, n::Int, i::Int; zeroterminated::Bool = false) where {F}
+    if i < 1 || i > n
         return false
     end
-    cs = x.codeunits
-    c = @inbounds cs[i]
-    if all(iszero, (cs[j] for j = i:N))
-        return false
-    elseif (c & 0x80) == 0x00
+    zeroterminated && utf8_allzeros(codeunit, n, i) && return false
+    c = @inbounds codeunit(i)
+    if (c & 0x80) == 0x00
         return true
     elseif (c & 0x40) == 0x00
         return false
     elseif (c & 0x20) == 0x00
-        return @inbounds (i ≤ N - 1) && ((cs[i+1] & 0xC0) == 0x80)
+        return @inbounds (i ≤ n - 1) && ((codeunit(i + 1) & 0xC0) == 0x80)
     elseif (c & 0x10) == 0x00
-        return @inbounds (i ≤ N - 2) &&
-                         ((cs[i+1] & 0xC0) == 0x80) &&
-                         ((cs[i+2] & 0xC0) == 0x80)
+        return @inbounds (i ≤ n - 2) &&
+                         ((codeunit(i + 1) & 0xC0) == 0x80) &&
+                         ((codeunit(i + 2) & 0xC0) == 0x80)
     elseif (c & 0x08) == 0x00
-        return @inbounds (i ≤ N - 3) &&
-                         ((cs[i+1] & 0xC0) == 0x80) &&
-                         ((cs[i+2] & 0xC0) == 0x80) &&
-                         ((cs[i+3] & 0xC0) == 0x80)
+        return @inbounds (i ≤ n - 3) &&
+                         ((codeunit(i + 1) & 0xC0) == 0x80) &&
+                         ((codeunit(i + 2) & 0xC0) == 0x80) &&
+                         ((codeunit(i + 3) & 0xC0) == 0x80)
     else
         return false
     end
-    return false
 end
 
-function Base.iterate(x::StaticString{UInt8,N}, i::Int = 1) where {N}
-    i > N && return
-    cs = x.codeunits
-    c = @inbounds cs[i]
-    if all(iszero, (cs[j] for j = i:N))
+function utf8_iterate(x, codeunit::F, n::Int, i::Int = 1; zeroterminated::Bool = false) where {F}
+    i > n && return
+    c = @inbounds codeunit(i)
+    if zeroterminated && utf8_allzeros(codeunit, n, i)
         return
     elseif (c & 0x80) == 0x00
         return (reinterpret(Char, UInt32(c) << 24), i + 1)
     elseif (c & 0x40) == 0x00
         nothing
     elseif (c & 0x20) == 0x00
-        if @inbounds (i ≤ N - 1) && ((cs[i+1] & 0xC0) == 0x80)
+        if @inbounds (i ≤ n - 1) && ((codeunit(i + 1) & 0xC0) == 0x80)
             return (
-                reinterpret(Char, (UInt32(cs[i]) << 24) | (UInt32(cs[i+1]) << 16)),
+                reinterpret(Char, (UInt32(codeunit(i)) << 24) | (UInt32(codeunit(i + 1)) << 16)),
                 i + 2,
             )
         end
     elseif (c & 0x10) == 0x00
-        if @inbounds (i ≤ N - 2) && ((cs[i+1] & 0xC0) == 0x80) && ((cs[i+2] & 0xC0) == 0x80)
+        if @inbounds (i ≤ n - 2) && ((codeunit(i + 1) & 0xC0) == 0x80) && ((codeunit(i + 2) & 0xC0) == 0x80)
             return (
                 reinterpret(
                     Char,
-                    (UInt32(cs[i]) << 24) |
-                    (UInt32(cs[i+1]) << 16) |
-                    (UInt32(cs[i+2]) << 8),
+                    (UInt32(codeunit(i)) << 24) |
+                    (UInt32(codeunit(i + 1)) << 16) |
+                    (UInt32(codeunit(i + 2)) << 8),
                 ),
                 i + 3,
             )
         end
     elseif (c & 0x08) == 0x00
-        if @inbounds (i ≤ N - 3) &&
-                     ((cs[i+1] & 0xC0) == 0x80) &&
-                     ((cs[i+2] & 0xC0) == 0x80) &&
-                     ((cs[i+3] & 0xC0) == 0x80)
+        if @inbounds (i ≤ n - 3) &&
+                     ((codeunit(i + 1) & 0xC0) == 0x80) &&
+                     ((codeunit(i + 2) & 0xC0) == 0x80) &&
+                     ((codeunit(i + 3) & 0xC0) == 0x80)
             return (
                 reinterpret(
                     Char,
-                    (UInt32(cs[i]) << 24) |
-                    (UInt32(cs[i+1]) << 16) |
-                    (UInt32(cs[i+2]) << 8) |
-                    UInt32(cs[i+3]),
+                    (UInt32(codeunit(i)) << 24) |
+                    (UInt32(codeunit(i + 1)) << 16) |
+                    (UInt32(codeunit(i + 2)) << 8) |
+                    UInt32(codeunit(i + 3)),
                 ),
                 i + 4,
             )
@@ -291,6 +261,45 @@ function Base.iterate(x::StaticString{UInt8,N}, i::Int = 1) where {N}
     throw(StringIndexError(x, i))
 end
 
+struct StaticString{T,N} <: AbstractString
+    codeunits::NTuple{N,T}
+    StaticString{T,N}(codeunits::NTuple{N,T}) where {T,N} = new{T,N}(codeunits)
+end
+
+function Base.String(x::StaticString{T,N}) where {T,N}
+    ts = x.codeunits
+    n = N
+    while n > 0 && iszero(ts[n])
+        n -= 1
+    end
+    cs = T[ts[i] for i = 1:n]
+    transcode(String, cs)
+end
+
+function Base.convert(::Type{StaticString{T,N}}, x::AbstractString) where {T,N}
+    ts = transcode(T, convert(String, x))
+    n = length(ts)
+    n > N && throw(InexactError(:convert, StaticString{T,N}, x))
+    n > 0 && iszero(ts[n]) && throw(InexactError(:convert, StaticString{T,N}, x))
+    z = zero(T)
+    cs = ntuple(i -> i > n ? z : @inbounds(ts[i]), N)
+    StaticString{T,N}(cs)
+end
+
+StaticString{T,N}(x::AbstractString) where {T,N} = convert(StaticString{T,N}, x)
+
+Base.ncodeunits(x::StaticString{T,N}) where {T,N} = N
+
+Base.codeunit(x::StaticString, i::Integer) = x.codeunits[i]
+
+Base.codeunit(x::StaticString{T}) where {T} = T
+
+Base.isvalid(x::StaticString{UInt8,N}, i::Int) where {N} =
+    utf8_isvalid(j -> @inbounds(x.codeunits[j]), N, i; zeroterminated = true)
+
+Base.iterate(x::StaticString{UInt8,N}, i::Int = 1) where {N} =
+    utf8_iterate(x, j -> @inbounds(x.codeunits[j]), N, i; zeroterminated = true)
+
 function Base.isvalid(x::StaticString{UInt32,N}, i::Int) where {N}
     i < 1 && return false
     cs = x.codeunits
diff --git a/src/Wrap/PyString.jl b/src/Wrap/PyString.jl
new file mode 100644
index 00000000..d9d3d7bb
--- /dev/null
+++ b/src/Wrap/PyString.jl
@@ -0,0 +1,27 @@
+ispy(::PyString) = true
+Py(x::PyString) = x.py
+
+Base.ncodeunits(x::PyString) = x.nbytes
+Base.codeunit(::PyString) = UInt8
+Base.codeunit(::Type{PyString}) = UInt8
+
+# `x.ptr` is a raw pointer into a buffer that belongs to the Python `str` held
+# in `x.py`. `unsafe_load` performs no checks of its own, so each read below
+# (a) is restricted to `1:x.nbytes` and (b) runs under `GC.@preserve x` so the
+# owning `Py` cannot be finalized while the pointer is being dereferenced.
+@inline function Base.codeunit(x::PyString, i::Integer)
+    @boundscheck checkbounds(x, i)
+    GC.@preserve x unsafe_load(x.ptr, i)
+end
+
+Base.isvalid(x::PyString, i::Int) =
+    GC.@preserve x Utils.utf8_isvalid(j -> unsafe_load(x.ptr, j), x.nbytes, i)
+
+function Base.iterate(x::PyString, i::Int = 1)
+    # `utf8_iterate` only rejects `i > n`; stop here for any out-of-range index
+    # (as `iterate(::String, i)` does) so `i < 1` never reaches `unsafe_load`.
+    1 ≤ i ≤ x.nbytes || return nothing
+    GC.@preserve x Utils.utf8_iterate(x, j -> unsafe_load(x.ptr, j), x.nbytes, i)
+end
+
+Base.String(x::PyString) = GC.@preserve x unsafe_string(x.ptr, x.nbytes)
diff --git a/src/Wrap/Wrap.jl b/src/Wrap/Wrap.jl
index e67129cb..9d82f852 100644
--- a/src/Wrap/Wrap.jl
+++ b/src/Wrap/Wrap.jl
@@ -14,7 +14,7 @@ using ..Convert
 using ..PyMacro
 
 import ..PythonCall:
-    PyArray, PyDict, PyIO, PyIterable, PyList, PyPandasDataFrame, PySet, PyTable
+    PyArray, PyDict, PyIO, PyIterable, PyList, PyPandasDataFrame, PySet, PyString, PyTable
 
 using Base: @propagate_inbounds
 using Tables: Tables
@@ -23,6 +23,7 @@ using UnsafePointers: UnsafePtr
 import ..Core: Py, ispy
 
 include("PyIterable.jl")
+include("PyString.jl")
 include("PyDict.jl")
 include("PyList.jl")
 include("PySet.jl")
diff --git a/test/Wrap.jl b/test/Wrap.jl
index 9e23f8a3..477ea921 100644
--- a/test/Wrap.jl
+++ b/test/Wrap.jl
@@ -1,3 +1,29 @@
+@testitem "PyString" begin
+    py_s = pystr("héllo 🌍")
+    y = PyString(py_s)
+
+    @test y isa PyString
+    @test PythonCall.ispy(y)
+    @test Py(y) === py_s
+
+    ptr, n = PythonCall.Core.pystr_utf8_pointer(py_s)
+    @test y.ptr == ptr
+    @test y.nbytes == n
+
+    expected = "héllo 🌍"
+    @test String(y) == expected
+    @test length(y) == length(expected)
+    @test collect(y) == collect(expected)
+    @test collect(codeunits(y)) == collect(codeunits(expected))
+
+    # Out-of-range indices must never be dereferenced.
+    @test_throws BoundsError codeunit(y, 0)
+    @test_throws BoundsError codeunit(y, ncodeunits(y) + 1)
+    @test iterate(y, 0) === nothing
+    @test iterate(y, ncodeunits(y) + 1) === nothing
+
+    z = PyString("abc")
+    @test String(z) == "abc"
+    @test ncodeunits(z) == 3
+end
+
 @testitem "PyArray" begin
     x = pyimport("array").array("i", pylist([1, 2, 3]))
     y = PyArray(x)