Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/conversion-to-julia.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ From Python, the arguments to a Julia function will be converted according to th
| `None` | `Missing` |
| `bytes` | `Vector{UInt8}`, `Vector{Int8}`, `String` |
| `str` | `String`, `Symbol`, `Char`, `Vector{UInt8}`, `Vector{Int8}` |
| `str` | `PyString` |
| `range` | `UnitRange` |
| `collections.abc.Mapping` | `Dict` |
| `collections.abc.Iterable` | `Vector`, `Set`, `Tuple`, `NamedTuple`, `Pair` |
Expand Down
1 change: 1 addition & 0 deletions docs/src/pythoncall-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ PyList
PySet
PyDict
PyIterable
PyString
PyArray
PyIO
PyTable
Expand Down
3 changes: 3 additions & 0 deletions docs/src/pythoncall.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ Python: [3, 4, 5, None, 1, 2]

There are wrappers for other container types, such as [`PyDict`](@ref) and [`PySet`](@ref).

`PyString` is a zero-copy wrapper around a Python `str`, exposing it as a Julia
`AbstractString` backed by the UTF-8 pointer cached by Python.

The wrapper [`PyArray`](@ref) provides a Julia array view of any Python array, i.e. anything
satisfying either the buffer protocol or the numpy array interface. This includes things
like `bytes`, `bytearray`, `array.array` and `numpy.ndarray`:
Expand Down
1 change: 1 addition & 0 deletions src/API/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ export PyDict
export PyIO
export PyIterable
export PyList
export PyString
export PyPandasDataFrame
export PySet
export PyTable
Expand Down
21 changes: 21 additions & 0 deletions src/API/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,27 @@ struct PyIterable{T}
PyIterable{T}(x) where {T} = new{T}(Py(x))
end

"""
PyString(x)

Wraps the Python `str` `x` as an `AbstractString` without copying.

The UTF-8 data is stored as a pointer and byte length obtained from
`PyUnicode_AsUTF8AndSize`, and remains valid as long as the underlying Python
object is alive.
"""
struct PyString <: AbstractString
py::Py
ptr::Ptr{UInt8}
nbytes::Int
function PyString(x)
py = Py(x)
PythonCall.Core.pyisstr(py) || throw(ArgumentError("PyString expects a Python `str`"))
ptr, n = PythonCall.Core.pystr_utf8_pointer(py)
new(py, ptr, n)
end
end

"""
PyList{T=Py}([x])

Expand Down
1 change: 1 addition & 0 deletions src/C/pointers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ const CAPI_FUNC_SIGS = Dict{Symbol,Pair{Tuple,Type}}(
:PyComplex_AsCComplex => (PyPtr,) => Py_complex,
# STR
:PyUnicode_DecodeUTF8 => (Ptr{Cchar}, Py_ssize_t, Ptr{Cchar}) => PyPtr,
:PyUnicode_AsUTF8AndSize => (PyPtr, Ptr{Py_ssize_t}) => Ptr{Cchar},
:PyUnicode_AsUTF8String => (PyPtr,) => PyPtr,
:PyUnicode_InternInPlace => (Ptr{PyPtr},) => Cvoid,
# BYTES
Expand Down
3 changes: 2 additions & 1 deletion src/Convert/pyconvert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ function pyconvert_rule_fast(::Type{T}, x::Py) where {T}
pyisfloat(x) && return pyconvert_return(T(pyfloat_asdouble(x)))
elseif (T == ComplexF64)
pyiscomplex(x) && return pyconvert_return(T(pycomplex_ascomplex(x)))
elseif (T == String) | (T == Char) | (T == Symbol)
elseif (T == String) | (T == Char) | (T == Symbol) | (T == PyString)
pyisstr(x) && return pyconvert_rule_str(T, x)
elseif (T == Vector{UInt8}) | (T == Base.CodeUnits{UInt8,String})
pyisbytes(x) && return pyconvert_rule_bytes(T, x)
Expand Down Expand Up @@ -473,6 +473,7 @@ function init_pyconvert()
pyconvert_add_rule("builtins:float", Nothing, pyconvert_rule_float, priority)
pyconvert_add_rule("builtins:float", Missing, pyconvert_rule_float, priority)
pyconvert_add_rule("numbers:Complex", Number, pyconvert_rule_complex, priority)
pyconvert_add_rule("builtins:str", PyString, pyconvert_rule_str, priority)
pyconvert_add_rule("numbers:Integral", Number, pyconvert_rule_int, priority)
pyconvert_add_rule("builtins:str", Symbol, pyconvert_rule_str, priority)
pyconvert_add_rule("builtins:str", Char, pyconvert_rule_str, priority)
Expand Down
1 change: 1 addition & 0 deletions src/Convert/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ pyconvert_rule_str(::Type{Char}, x::Py) = begin
pyconvert_unconverted()
end
end
pyconvert_rule_str(::Type{PyString}, x::Py) = pyconvert_return(PyString(x))

### bytes

Expand Down
7 changes: 7 additions & 0 deletions src/Core/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,13 @@ pystr_asUTF8vector(x::Py) =
pystr_asstring(x::Py) =
(b = pystr_asUTF8bytes(x); ans = pybytes_asUTF8string(b); pydel!(b); ans)

function pystr_utf8_pointer(x::Py)
n = Ref{C.Py_ssize_t}()
p = C.PyUnicode_AsUTF8AndSize(x, n)
p == C_NULL && pythrow()
Ptr{UInt8}(p), Int(n[])
end

function pystr_intern!(x::Py)
ptr = Ref(getptr(x))
C.PyUnicode_InternInPlace(ptr)
Expand Down
139 changes: 74 additions & 65 deletions src/Utils/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,109 +180,79 @@ size_to_fstrides(elsz::Integer, sz::Tuple{Vararg{Integer}}) =
size_to_cstrides(elsz::Integer, sz::Tuple{Vararg{Integer}}) =
isempty(sz) ? () : (size_to_cstrides(elsz * sz[end], sz[1:end-1])..., elsz)

struct StaticString{T,N} <: AbstractString
codeunits::NTuple{N,T}
StaticString{T,N}(codeunits::NTuple{N,T}) where {T,N} = new{T,N}(codeunits)
end

function Base.String(x::StaticString{T,N}) where {T,N}
ts = x.codeunits
n = N
while n > 0 && iszero(ts[n])
n -= 1
function utf8_allzeros(codeunit::F, n::Int, i::Int) where {F}
@inbounds for j in i:n
iszero(codeunit(j)) || return false
end
cs = T[ts[i] for i = 1:n]
transcode(String, cs)
return true
end

function Base.convert(::Type{StaticString{T,N}}, x::AbstractString) where {T,N}
ts = transcode(T, convert(String, x))
n = length(ts)
n > N && throw(InexactError(:convert, StaticString{T,N}, x))
n > 0 && iszero(ts[n]) && throw(InexactError(:convert, StaticString{T,N}, x))
z = zero(T)
cs = ntuple(i -> i > n ? z : @inbounds(ts[i]), N)
StaticString{T,N}(cs)
end

StaticString{T,N}(x::AbstractString) where {T,N} = convert(StaticString{T,N}, x)

Base.ncodeunits(x::StaticString{T,N}) where {T,N} = N

Base.codeunit(x::StaticString, i::Integer) = x.codeunits[i]

Base.codeunit(x::StaticString{T}) where {T} = T

function Base.isvalid(x::StaticString{UInt8,N}, i::Int) where {N}
if i < 1 || i > N
function utf8_isvalid(codeunit::F, n::Int, i::Int; zeroterminated::Bool = false) where {F}
if i < 1 || i > n
return false
end
cs = x.codeunits
c = @inbounds cs[i]
if all(iszero, (cs[j] for j = i:N))
return false
elseif (c & 0x80) == 0x00
zeroterminated && utf8_allzeros(codeunit, n, i) && return false
c = @inbounds codeunit(i)
if (c & 0x80) == 0x00
return true
elseif (c & 0x40) == 0x00
return false
elseif (c & 0x20) == 0x00
return @inbounds (i ≤ N - 1) && ((cs[i+1] & 0xC0) == 0x80)
return @inbounds (i ≤ n - 1) && ((codeunit(i + 1) & 0xC0) == 0x80)
elseif (c & 0x10) == 0x00
return @inbounds (i ≤ N - 2) &&
((cs[i+1] & 0xC0) == 0x80) &&
((cs[i+2] & 0xC0) == 0x80)
return @inbounds (i ≤ n - 2) &&
((codeunit(i + 1) & 0xC0) == 0x80) &&
((codeunit(i + 2) & 0xC0) == 0x80)
elseif (c & 0x08) == 0x00
return @inbounds (i ≤ N - 3) &&
((cs[i+1] & 0xC0) == 0x80) &&
((cs[i+2] & 0xC0) == 0x80) &&
((cs[i+3] & 0xC0) == 0x80)
return @inbounds (i ≤ n - 3) &&
((codeunit(i + 1) & 0xC0) == 0x80) &&
((codeunit(i + 2) & 0xC0) == 0x80) &&
((codeunit(i + 3) & 0xC0) == 0x80)
else
return false
end
return false
end

function Base.iterate(x::StaticString{UInt8,N}, i::Int = 1) where {N}
i > N && return
cs = x.codeunits
c = @inbounds cs[i]
if all(iszero, (cs[j] for j = i:N))
function utf8_iterate(x, codeunit::F, n::Int, i::Int = 1; zeroterminated::Bool = false) where {F}
i > n && return
c = @inbounds codeunit(i)
if zeroterminated && utf8_allzeros(codeunit, n, i)
return
elseif (c & 0x80) == 0x00
return (reinterpret(Char, UInt32(c) << 24), i + 1)
elseif (c & 0x40) == 0x00
nothing
elseif (c & 0x20) == 0x00
if @inbounds (i ≤ N - 1) && ((cs[i+1] & 0xC0) == 0x80)
if @inbounds (i ≤ n - 1) && ((codeunit(i + 1) & 0xC0) == 0x80)
return (
reinterpret(Char, (UInt32(cs[i]) << 24) | (UInt32(cs[i+1]) << 16)),
reinterpret(Char, (UInt32(codeunit(i)) << 24) | (UInt32(codeunit(i + 1)) << 16)),
i + 2,
)
end
elseif (c & 0x10) == 0x00
if @inbounds (i ≤ N - 2) && ((cs[i+1] & 0xC0) == 0x80) && ((cs[i+2] & 0xC0) == 0x80)
if @inbounds (i ≤ n - 2) && ((codeunit(i + 1) & 0xC0) == 0x80) && ((codeunit(i + 2) & 0xC0) == 0x80)
return (
reinterpret(
Char,
(UInt32(cs[i]) << 24) |
(UInt32(cs[i+1]) << 16) |
(UInt32(cs[i+2]) << 8),
(UInt32(codeunit(i)) << 24) |
(UInt32(codeunit(i + 1)) << 16) |
(UInt32(codeunit(i + 2)) << 8),
),
i + 3,
)
end
elseif (c & 0x08) == 0x00
if @inbounds (i ≤ N - 3) &&
((cs[i+1] & 0xC0) == 0x80) &&
((cs[i+2] & 0xC0) == 0x80) &&
((cs[i+3] & 0xC0) == 0x80)
if @inbounds (i ≤ n - 3) &&
((codeunit(i + 1) & 0xC0) == 0x80) &&
((codeunit(i + 2) & 0xC0) == 0x80) &&
((codeunit(i + 3) & 0xC0) == 0x80)
return (
reinterpret(
Char,
(UInt32(cs[i]) << 24) |
(UInt32(cs[i+1]) << 16) |
(UInt32(cs[i+2]) << 8) |
UInt32(cs[i+3]),
(UInt32(codeunit(i)) << 24) |
(UInt32(codeunit(i + 1)) << 16) |
(UInt32(codeunit(i + 2)) << 8) |
UInt32(codeunit(i + 3)),
),
i + 4,
)
Expand All @@ -291,6 +261,45 @@ function Base.iterate(x::StaticString{UInt8,N}, i::Int = 1) where {N}
throw(StringIndexError(x, i))
end

struct StaticString{T,N} <: AbstractString
codeunits::NTuple{N,T}
StaticString{T,N}(codeunits::NTuple{N,T}) where {T,N} = new{T,N}(codeunits)
end

function Base.String(x::StaticString{T,N}) where {T,N}
ts = x.codeunits
n = N
while n > 0 && iszero(ts[n])
n -= 1
end
cs = T[ts[i] for i = 1:n]
transcode(String, cs)
end

function Base.convert(::Type{StaticString{T,N}}, x::AbstractString) where {T,N}
ts = transcode(T, convert(String, x))
n = length(ts)
n > N && throw(InexactError(:convert, StaticString{T,N}, x))
n > 0 && iszero(ts[n]) && throw(InexactError(:convert, StaticString{T,N}, x))
z = zero(T)
cs = ntuple(i -> i > n ? z : @inbounds(ts[i]), N)
StaticString{T,N}(cs)
end

StaticString{T,N}(x::AbstractString) where {T,N} = convert(StaticString{T,N}, x)

Base.ncodeunits(x::StaticString{T,N}) where {T,N} = N

Base.codeunit(x::StaticString, i::Integer) = x.codeunits[i]

Base.codeunit(x::StaticString{T}) where {T} = T

Base.isvalid(x::StaticString{UInt8,N}, i::Int) where {N} =
utf8_isvalid(j -> @inbounds(x.codeunits[j]), N, i; zeroterminated = true)

Base.iterate(x::StaticString{UInt8,N}, i::Int = 1) where {N} =
utf8_iterate(x, j -> @inbounds(x.codeunits[j]), N, i; zeroterminated = true)

function Base.isvalid(x::StaticString{UInt32,N}, i::Int) where {N}
i < 1 && return false
cs = x.codeunits
Expand Down
15 changes: 15 additions & 0 deletions src/Wrap/PyString.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
ispy(::PyString) = true
Py(x::PyString) = x.py

Base.ncodeunits(x::PyString) = x.nbytes
Base.codeunit(::PyString) = UInt8
Base.codeunit(::Type{PyString}) = UInt8
Base.codeunit(x::PyString, i::Integer) = @inbounds unsafe_load(x.ptr + (i - 1))

Base.isvalid(x::PyString, i::Int) =
Utils.utf8_isvalid(j -> @inbounds(unsafe_load(x.ptr + (j - 1))), x.nbytes, i)

Base.iterate(x::PyString, i::Int = 1) =
Utils.utf8_iterate(x, j -> @inbounds(unsafe_load(x.ptr + (j - 1))), x.nbytes, i)

Base.String(x::PyString) = unsafe_string(x.ptr, x.nbytes)
3 changes: 2 additions & 1 deletion src/Wrap/Wrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using ..Convert
using ..PyMacro

import ..PythonCall:
PyArray, PyDict, PyIO, PyIterable, PyList, PyPandasDataFrame, PySet, PyTable
PyArray, PyDict, PyIO, PyIterable, PyList, PyPandasDataFrame, PySet, PyString, PyTable

using Base: @propagate_inbounds
using Tables: Tables
Expand All @@ -23,6 +23,7 @@ using UnsafePointers: UnsafePtr
import ..Core: Py, ispy

include("PyIterable.jl")
include("PyString.jl")
include("PyDict.jl")
include("PyList.jl")
include("PySet.jl")
Expand Down
23 changes: 23 additions & 0 deletions test/Wrap.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,26 @@
@testitem "PyString" begin
py_s = pystr("héllo 🌍")
y = PyString(py_s)

@test y isa PyString
@test PythonCall.ispy(y)
@test Py(y) === py_s

ptr, n = PythonCall.Core.pystr_utf8_pointer(py_s)
@test y.ptr == ptr
@test y.nbytes == n

expected = "héllo 🌍"
@test String(y) == expected
@test length(y) == length(expected)
@test collect(y) == collect(expected)
@test collect(codeunits(y)) == collect(codeunits(expected))

z = PyString("abc")
@test String(z) == "abc"
@test ncodeunits(z) == 3
end

@testitem "PyArray" begin
x = pyimport("array").array("i", pylist([1, 2, 3]))
y = PyArray(x)
Expand Down
Loading