samplers: create HeadTableSampler and RandomStartTableSampler table samplers #151
Conversation
Pull request overview
This PR implements two new table sampling strategies: HeadTableSampler for retrieving the first N rows of a table, and RandomStartTableSampler for sampling consecutive rows from a random starting position.
Changes:
- Introduced abstract `TableSampler` base class for table sampling operations
- Implemented `HeadTableSampler` to sample the first N rows of a table
- Implemented `RandomStartTableSampler` to sample consecutive rows from a random starting point
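For orientation, the two strategies boil down to the queries quoted in the diffs below (a sketch, not the PR's verbatim implementation; `table_name`, `start_rowid`, and `sample_size` follow the quoted code):

```python
# HeadTableSampler: first sample_size rows of the table.
head_sql = f'SELECT * FROM {table_name} LIMIT {sample_size}'

# RandomStartTableSampler: sample_size consecutive rows starting from a
# randomly chosen rowid (DuckDB's built-in rowid pseudocolumn).
random_start_sql = f'SELECT * FROM {table_name} WHERE rowid >= {start_rowid} LIMIT {sample_size}'
```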
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| agentune/core/sampler/base.py | Added abstract TableSampler base class defining the interface for table sampling |
| agentune/core/sampler/table_samples.py | Implemented HeadTableSampler and RandomStartTableSampler concrete classes |
| tests/agentune/core/sampler/test_table_samplers.py | Comprehensive test suite for both sampler implementations |
```python
        Dataset containing the first sample_size rows
        """
        table_name = str(table.table.name)
        sql_query = f'SELECT * FROM {table_name} LIMIT {sample_size}'
```
Copilot AI (Jan 12, 2026):
SQL query is constructed using f-string interpolation without sanitization. While table_name comes from a DuckdbName object which should be safe, sample_size is a user-provided integer parameter that should be validated or parameterized to prevent potential SQL injection.
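A sketch of that suggestion (not code from this PR): DuckDB's Python API accepts `?` placeholders via `conn.execute(sql, params)`. Table names cannot be bound, so `table_name` still relies on DuckdbName quoting, but the integer can be validated and bound.

```python
# Validate the user-provided integer, then bind it instead of interpolating it.
if not isinstance(sample_size, int) or sample_size < 0:
    raise ValueError(f'sample_size must be a non-negative integer, got {sample_size!r}')
result = conn.execute(f'SELECT * FROM {table_name} LIMIT ?', [sample_size]).pl()
```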
```python
        table_name = str(table.table.name)

        # Get table size
        count_query = f'SELECT COUNT(*) as count FROM {table_name}'
```
Copilot AI (Jan 12, 2026):
SQL queries use f-string interpolation for start_rowid and sample_size parameters. These values should be validated or use parameterized queries to ensure they are safe integers and prevent potential SQL injection vectors.
```python
        # Select consecutive rows starting from the random rowid
        # Using DuckDB's built-in rowid pseudocolumn for deterministic and efficient filtering
        sql_query = f'SELECT * FROM {table_name} WHERE rowid >= {start_rowid} LIMIT {sample_size}'
```
Copilot AI (Jan 12, 2026):
SQL queries use f-string interpolation for start_rowid and sample_size parameters. These values should be validated or use parameterized queries to ensure they are safe integers and prevent potential SQL injection vectors.
Here too, best to write 'order by rowid' explicitly whenever we rely on it.
Also, I agree with Copilot: we should use parameterized queries whenever possible, as a general best practice.
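Combining both suggestions, the query might end up looking like this (a sketch under assumptions, not the PR's final code; it also seeds a local RNG so the default stays deterministic, per the seed discussion below):

```python
import random

# Get the table size, pick a deterministic random start, then select a
# consecutive block with explicit ordering and bound integer parameters.
total = conn.execute(f'SELECT COUNT(*) AS count FROM {table_name}').fetchone()[0]
rng = random.Random(random_seed)
start_rowid = rng.randint(0, max(total - sample_size, 0))
result = conn.execute(
    f'SELECT * FROM {table_name} WHERE rowid >= ? ORDER BY rowid LIMIT ?',
    [start_rowid, sample_size],
).pl()
```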
Force-pushed from f125c36 to be4f8cb.
agentune/core/sampler/base.py (outdated):
```python
class TableSampler(ABC):
    """Abstract base class for data sampling from TableWithJoinStrategies."""

    @abstractmethod
    def sample(self, table: TableWithJoinStrategies, conn: DuckDBPyConnection, sample_size: int, random_seed: int | None = None) -> Dataset:
```
The interface should accept just a DuckdbName; it doesn't need a DuckdbTable, and certainly not the one that's WithStrategies.
In the future we may want to use information from join strategies to customize the examples from each table: for conversations we may want a few complete conversations, and for lookups we may want a specific selection according to the lookup field.
After talking to Leonid, I changed it to DuckdbName | str.
agentune/core/sampler/base.py (outdated), same sample signature as above:
I think the default for random seeds should be a constant (traditionally 42) rather than None. The default behavior should be deterministic. (This also applies to the existing DataSampler. @leonidb WDYT?)
I prefer deterministic defaults too. The user can override it with None if needed.
Changed the default to 42 for TableSampler; the existing DataSampler we can handle in another PR.
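Putting the two thread outcomes together (the DuckdbName | str parameter and the 42 default), the revised interface presumably looks roughly like this; a sketch, since the final code isn't quoted in the thread:

```python
from abc import ABC, abstractmethod

from duckdb import DuckDBPyConnection

# DuckdbName and Dataset are agentune-internal types; their exact import
# paths aren't shown in this thread, hence the string annotations below.


class TableSampler(ABC):
    """Abstract base class for table sampling."""

    @abstractmethod
    def sample(self, table: 'DuckdbName | str', conn: DuckDBPyConnection,
               sample_size: int, random_seed: int | None = 42) -> 'Dataset':
        """Deterministic by default; pass random_seed=None for nondeterministic sampling."""
```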
```python
    This sampler executes a simple SELECT * query with a LIMIT clause
    to retrieve the head of the table. The order of rows is determined
    by the table's natural order (or index if present).
```
What is this 'or index if present'? I know of no such thing.
I updated the docstring.
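The updated text isn't quoted in the thread; a plausible correction, given that SQL (DuckDB included) guarantees no row order without an explicit ORDER BY, might read:

```python
# Hypothetical revised docstring, not the PR's actual wording:
"""Sample the first sample_size rows of the table.

Executes a SELECT * query with a LIMIT clause. Without an explicit
ORDER BY, DuckDB makes no ordering guarantee; in practice the result
follows the table's scan order.
"""
```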
```python
    })

    # Create actual table in DuckDB (not just register)
    qualified_name = DuckdbName.qualify(table_name, conn)
```
Tip: you can skip all this code by working with higher-level APIs. Given a dataframe `data`:

```python
DatasetSink.into_unqualified_duckdb_table(table_name).write(Dataset.from_polars(data).as_source(), conn)
```

(It's also a one-liner when using the RunContext API.)

BTW, the code as written isn't quite correct: to preserve the schema for an arbitrary dataframe you need to call restore_relation_types, which, at the API level you were using, means registering the Relation returned by Dataset.to_duckdb rather than the dataframe.
Great!! I updated the code.
```python
        result = sampler.sample(table, conn, sample_size=20)

        # Validate result
        assert result.data.height == 20
```
Tip: it's both shorter and more comprehensive (includes dtypes) to compare dataframes: `result.data.equals(expected_data.head(20))`. However, you'd need the `expected_data` dataframe.
True. I now read `expected = conn.table(str(table.table.name)).pl()`; I think this is the correct way to do it for these small-scale tests.
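Spelled out, that pattern would look like this (a sketch; `result` and `table` as in the quoted test):

```python
# Read the full table back from DuckDB as a Polars frame, then compare
# the sampled head against it, which also checks dtypes.
expected = conn.table(str(table.table.name)).pl()
assert result.data.equals(expected.head(20))
```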
```python
        # Different seeds should (very likely) produce different starting points
        ids1 = result1.data['id'].to_list()
        ids2 = result2.data['id'].to_list()
        assert ids1[0] != ids2[0] or ids1 != ids2
```
Either side of the or should be enough; why test both?
True, I changed it to compare only the first id values.
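Presumably leaving just (assumed final form):

```python
# Different seeds should (very likely) pick different starting rows.
assert ids1[0] != ids2[0]
```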
Force-pushed from be4f8cb to abd4a31, then from 0e825c5 to a1a9ec2.
Reviewed by me, per the comments above.
What does this PR do?
Implements HeadTableSampler and RandomStartTableSampler for table sampling.
Changes
Related Issues
Fixes https://github.com/SparkBeyond/ao-core/issues/87