
Dataset Splitting

Functions for splitting DerivaML datasets into training and testing subsets with full provenance tracking. Supports random, stratified, and custom selection strategies.

Generic dataset splitting for DerivaML.

This module provides functions to split a DerivaML dataset into training, testing, and optionally validation subsets with full provenance tracking. It works with any DerivaML catalog and any registered element type.

The splitting API follows scikit-learn conventions (test_size, train_size, val_size, shuffle, seed, stratify) while integrating with DerivaML's dataset hierarchy, execution provenance, and versioning.
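Under these conventions a float in (0, 1) is a fraction of the member count and an int is an absolute count. The sketch below shows one way such arguments could be resolved to integer partition sizes. The helper name `resolve_partition_sizes` is hypothetical. The rounding rule (round the test and validation shares up, give training the remainder, as scikit-learn's `train_test_split` does) is an assumption, not necessarily what deriva-ml applies.

```python
from __future__ import annotations

import math


def _to_count(size: float | int, n: int) -> int:
    """Turn a fraction (float in (0, 1)) or an absolute count (int) into a count."""
    if isinstance(size, float):
        if not 0.0 < size < 1.0:
            raise ValueError(f"fractional size must be in (0, 1), got {size}")
        return math.ceil(size * n)  # assumption: held-out shares round up, as in scikit-learn
    return int(size)


def resolve_partition_sizes(
    n: int,
    test_size: float | int = 0.2,
    train_size: float | int | None = None,
    val_size: float | int | None = None,
) -> dict[str, int]:
    """Hypothetical helper: map sklearn-style size arguments to partition counts."""
    n_test = _to_count(test_size, n)
    n_val = _to_count(val_size, n) if val_size is not None else 0
    if train_size is None:
        n_train = n - n_test - n_val  # training gets the complement
    elif isinstance(train_size, float):
        n_train = math.floor(train_size * n)
    else:
        n_train = int(train_size)
    if n_train <= 0 or n_train + n_test + n_val > n:
        raise ValueError("requested partitions do not fit in the dataset")
    sizes = {"Training": n_train, "Testing": n_test}
    if val_size is not None:
        sizes["Validation"] = n_val
    return sizes
```

For a 500-member dataset, `test_size=0.2, val_size=0.1` resolves to 350 / 100 / 50 under these assumptions.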

Splitting Strategies

Random (default): Shuffles members and splits at the partition boundaries. No denormalization required.

Stratified: Maintains class distribution across splits using scikit-learn's stratified splitting. Requires specifying a column to stratify by from the denormalized DataFrame.

Custom: Users can provide a SelectionFunction callable for arbitrary selection logic (balanced labels, filtered subsets, etc.).
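To make the stratified strategy concrete, the sketch below keeps each class's share roughly constant across partitions. It is a simplified, hand-rolled stand-in, not scikit-learn's actual stratified algorithm: it shuffles each class separately and sends a proportional slice of every class to the test partition. The function name `stratified_two_way` is hypothetical, and only NumPy is assumed.

```python
import numpy as np


def stratified_two_way(labels, test_fraction: float, seed: int) -> dict[str, np.ndarray]:
    """Hypothetical sketch: per-class proportional train/test split of row positions."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    train_parts, test_parts = [], []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)  # row positions belonging to this class
        rng.shuffle(idx)
        n_test = int(round(test_fraction * len(idx)))  # this class's share of the test set
        test_parts.append(idx[:n_test])
        train_parts.append(idx[n_test:])
    return {
        "Training": np.sort(np.concatenate(train_parts)),
        "Testing": np.sort(np.concatenate(test_parts)),
    }
```

Unlike the random strategy, this needs a label for every row. That is why the real stratified strategy works from the denormalized DataFrame and needs `include_tables`.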

Example

split_dataset runs inside an Execution the caller has already opened. The caller's workflow identifies the code making the splitting decision, and deriva-ml never invents a workflow on the caller's behalf. The function is therefore safe to call from environments without a git checkout (notebook kernels, MCP servers, scheduled jobs), provided the caller supplies a workflow that accurately records the provenance of its code::

from deriva_ml import DerivaML
from deriva_ml.dataset.split import split_dataset
from deriva_ml.execution import ExecutionConfiguration

ml = DerivaML("localhost", "9")

workflow = ml.create_workflow(
    name="My splitting script",
    workflow_type="Dataset_Split",
    description="80/20 train/test for sleep-stage classifier v3",
)
config = ExecutionConfiguration(workflow=workflow)

with ml.create_execution(config) as exe:
    result = split_dataset(ml, "28D0", exe, test_size=0.2, seed=42)
exe.commit_output_assets(clean_folder=True)

Three-way train/val/test split (run inside an open execution such as exe above)::

result = split_dataset(
    ml, "28D0", exe,
    test_size=0.2,
    val_size=0.1,
    seed=42,
)

Stratified split::

result = split_dataset(
    ml, "28D0", exe,
    test_size=0.2,
    stratify_by_column="Image_Class.Name",
    include_tables=["Image", "Image_Class"],
)

Custom selection function::

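import numpy as np

def my_selector(df, partition_sizes, seed):
    # Custom logic goes here; this placeholder just shuffles row positions.
    order = np.random.default_rng(seed).permutation(len(df))
    n_train = partition_sizes["Training"]
    n_test = partition_sizes["Testing"]
    return {"Training": order[:n_train], "Testing": order[n_train : n_train + n_test]}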

result = split_dataset(
    ml, "28D0", exe,
    test_size=100,
    selection_fn=my_selector,
    include_tables=["Image", "Image_Classification"],
)

See Also
  • sklearn.model_selection.train_test_split
  • Dataset.get_denormalized_as_dataframe
  • Dataset.list_dataset_members

PartitionInfo

Bases: BaseModel

Information about a single partition (Training, Testing, or Validation).

Source code in src/deriva_ml/dataset/split.py
class PartitionInfo(BaseModel):
    """Information about a single partition (Training, Testing, or Validation)."""

    rid: str
    version: str
    count: int

SelectionFunction

Bases: Protocol

Protocol for custom partition selection functions.

A selection function receives the denormalized dataset DataFrame and returns a dict mapping partition names to integer index arrays into the DataFrame rows.

The function is responsible for:

  • Deciding which records go into each partition
  • Ensuring the sizes match the requested partition_sizes
  • Implementing any balancing or stratification logic

Parameters:

  • df (pd.DataFrame, required): Denormalized DataFrame from dataset.get_denormalized_as_dataframe(). Columns use dot notation Table.column (e.g., Image.RID, Image_Class.Name); see denormalize_column_name().
  • partition_sizes (dict[str, int], required): Dict mapping partition names (e.g., "Training", "Testing", "Validation") to the number of records for each.
  • seed (int, required): Random seed for reproducibility.

Returns:

  • dict[str, np.ndarray]: Dict mapping partition names to numpy arrays of integer indices into the DataFrame.
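The selection function itself is responsible for honoring partition_sizes, so it can help to check its output before passing it to split_dataset. The checker below is a hypothetical helper, not part of deriva-ml. It relies only on the return shape described above (partition name to integer index array). It also assumes that partitions should not share row positions.

```python
import numpy as np


def check_selection(
    selection: dict[str, np.ndarray],
    partition_sizes: dict[str, int],
    n_rows: int,
) -> None:
    """Hypothetical helper: raise ValueError if a selector's output breaks the contract."""
    if set(selection) != set(partition_sizes):
        raise ValueError(f"partition names {sorted(selection)} != {sorted(partition_sizes)}")
    seen: set[int] = set()
    for name, idx in selection.items():
        idx = np.asarray(idx)
        if len(idx) != partition_sizes[name]:
            raise ValueError(f"{name}: expected {partition_sizes[name]} rows, got {len(idx)}")
        if idx.size and (idx.min() < 0 or idx.max() >= n_rows):
            raise ValueError(f"{name}: index out of range for {n_rows} rows")
        overlap = seen.intersection(idx.tolist())  # assumption: row positions are disjoint
        if overlap:
            raise ValueError(f"{name}: rows {sorted(overlap)[:5]} already assigned")
        seen.update(idx.tolist())
```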

Example

def balanced_selector(df, partition_sizes, seed):
    rng = np.random.default_rng(seed)
    # ... balance classes ...
    return {"Training": train_indices, "Testing": test_indices}
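One way to flesh out that skeleton is sketched below, under stated assumptions. The label column name `Image_Class.Name` is borrowed from the examples above, and the factory name `make_balanced_selector` is hypothetical. Here "balanced" means the Training partition draws the same number of rows from each class (topped up from leftovers when the count does not divide evenly). The remaining rows are shuffled and cut into the other partitions in order.

```python
import numpy as np


def make_balanced_selector(label_column: str = "Image_Class.Name"):
    """Hypothetical factory returning a class-balanced, SelectionFunction-shaped callable."""

    def balanced_selector(df, partition_sizes: dict[str, int], seed: int) -> dict[str, np.ndarray]:
        rng = np.random.default_rng(seed)
        labels = np.asarray(df[label_column])  # a DataFrame column or any list-like column
        classes = np.unique(labels)
        if len(classes) == 0:
            raise ValueError("no rows to select from")
        n_train = partition_sizes["Training"]
        per_class = n_train // len(classes)

        # Shuffle each class's row positions independently.
        pools = {c: rng.permutation(np.flatnonzero(labels == c)) for c in classes}
        if any(len(pool) < per_class for pool in pools.values()):
            raise ValueError("a class has too few rows for a balanced Training partition")

        # Equal draw per class, then top up from the shuffled leftovers if needed.
        train = np.concatenate([pools[c][:per_class] for c in classes])
        leftovers = rng.permutation(np.concatenate([pools[c][per_class:] for c in classes]))
        extra = n_train - len(train)
        if extra > len(leftovers):
            raise ValueError("not enough rows for the Training partition")
        train = np.concatenate([train, leftovers[:extra]])
        rest = leftovers[extra:]

        # Remaining partitions (Testing, Validation, ...) are cut from the leftovers in order.
        others = {name: size for name, size in partition_sizes.items() if name != "Training"}
        if sum(others.values()) > len(rest):
            raise ValueError("not enough rows left for the non-training partitions")
        result, offset = {"Training": train}, 0
        for name, size in others.items():
            result[name] = rest[offset : offset + size]
            offset += size
        return result

    return balanced_selector
```

Such a callable would be passed as selection_fn, with include_tables covering whichever table supplies the label column.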

Source code in src/deriva_ml/dataset/split.py
@runtime_checkable
class SelectionFunction(Protocol):
    """Protocol for custom partition selection functions.

    A selection function receives the denormalized dataset DataFrame and
    returns a dict mapping partition names to integer index arrays into
    the DataFrame rows.

    The function is responsible for:

    - Deciding which records go into each partition
    - Ensuring the sizes match the requested partition_sizes
    - Implementing any balancing or stratification logic

    Args:
        df: Denormalized DataFrame from ``dataset.get_denormalized_as_dataframe()``.
            Columns use dot notation ``Table.column`` (e.g., ``Image.RID``,
            ``Image_Class.Name``) — see :func:`denormalize_column_name`.
        partition_sizes: Dict mapping partition names (e.g., "Training",
            "Testing", "Validation") to the number of records for each.
        seed: Random seed for reproducibility.

    Returns:
        Dict mapping partition names to numpy arrays of integer indices
        into the DataFrame.

    Example:
        >>> def balanced_selector(df, partition_sizes, seed):  # doctest: +SKIP
        ...     rng = np.random.default_rng(seed)
        ...     # ... balance classes ...
        ...     return {"Training": train_indices, "Testing": test_indices}
    """

    def __call__(
        self,
        df: pd.DataFrame,
        partition_sizes: dict[str, int],
        seed: int,
    ) -> dict[str, np.ndarray]: ...
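Because SelectionFunction is decorated with @runtime_checkable, isinstance() can be applied to candidate selectors. Python's runtime protocol checks only confirm that the protocol's members exist (here, `__call__`). They do not check parameters or return types. The snippet below demonstrates this with a local stand-in, `SelectionFunctionLike`, that mirrors the protocol above. Its annotations are loosened to `Any` so it needs no pandas or NumPy import.

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class SelectionFunctionLike(Protocol):
    """Local stand-in with the same shape as SelectionFunction."""

    def __call__(self, df: Any, partition_sizes: dict[str, int], seed: int) -> dict[str, Any]: ...


def first_rows_selector(df, partition_sizes, seed):
    # Deterministic toy selector: hands out row positions in order (ignores the seed).
    result, offset = {}, 0
    for name, size in partition_sizes.items():
        result[name] = list(range(offset, offset + size))
        offset += size
    return result


def wrong_arity(x):  # does not match the protocol's parameters
    return x


accepts_good = isinstance(first_rows_selector, SelectionFunctionLike)
accepts_wrong = isinstance(wrong_arity, SelectionFunctionLike)  # also True: signature unchecked
rejects_int = not isinstance(42, SelectionFunctionLike)
```

A signature mismatch therefore surfaces only when the function is actually called. A quick call on a tiny stand-in frame, or a static type checker, is the more reliable check.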

SplitResult

Bases: BaseModel

Result of a dataset split operation.

Source code in src/deriva_ml/dataset/split.py
class SplitResult(BaseModel):
    """Result of a dataset split operation."""

    source: str
    split: PartitionInfo
    training: PartitionInfo
    testing: PartitionInfo
    validation: PartitionInfo | None = None
    strategy: str
    element_table: str
    seed: int
    dry_run: bool = False
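Code that consumes a SplitResult has to handle two cases: the optional validation field and dry-run results. According to the SubsampleResult notes, SplitResult uses the same "(dry run)" placeholders for rid and version. The helpers below are hypothetical sketches. To keep them runnable without a catalog (or pydantic), they are exercised with small dataclass stand-ins that mirror the PartitionInfo and SplitResult field names; all values are illustrative.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Partition:  # stand-in mirroring PartitionInfo's fields
    rid: str
    version: str
    count: int


@dataclass
class Result:  # stand-in mirroring the SplitResult fields used below
    source: str
    split: Partition
    training: Partition
    testing: Partition
    validation: Partition | None = None
    dry_run: bool = False


def partition_counts(result) -> dict[str, int]:
    """Hypothetical helper: member counts per partition, skipping a missing validation set."""
    counts = {"Training": result.training.count, "Testing": result.testing.count}
    if result.validation is not None:
        counts["Validation"] = result.validation.count
    return counts


def created_rids(result) -> list[str]:
    """Hypothetical helper: RIDs of datasets a real split created; empty for a dry run."""
    if result.dry_run:
        return []  # rid fields hold "(dry run)" placeholders, not catalog RIDs
    parts = [result.split, result.training, result.validation, result.testing]
    return [p.rid for p in parts if p is not None]
```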

SubsampleResult

Bases: BaseModel

Result of a subsample() operation.

Mirrors SplitResult but carries a single output (subsample) rather than a Split parent plus per-partition children. Dry-run instances have rid / version set to "(dry run)" placeholders, matching SplitResult's convention.

Attributes:

  • source (str): RID of the source dataset that was sampled.
  • subsample (PartitionInfo): PartitionInfo for the produced dataset.
  • strategy (str): Human-readable strategy ("random" or "stratified by ...").
  • element_table (str): The element table the sample was drawn from.
  • seed (int): Random seed used.
  • dry_run (bool): True when the result represents a plan rather than a created dataset.

Source code in src/deriva_ml/dataset/split.py
class SubsampleResult(BaseModel):
    """Result of a :func:`subsample` operation.

    Mirrors :class:`SplitResult` but carries a single output
    (``subsample``) rather than a Split parent + per-partition
    children. Dry-run instances have ``rid`` / ``version`` set to
    ``"(dry run)"`` placeholders, matching :class:`SplitResult`'s
    convention.

    Attributes:
        source: RID of the source dataset that was sampled.
        subsample: :class:`PartitionInfo` for the produced dataset.
        strategy: Human-readable strategy (``"random"`` or
            ``"stratified by ..."``).
        element_table: The element table the sample was drawn from.
        seed: Random seed used.
        dry_run: ``True`` when the result represents a plan rather
            than a created dataset.
    """

    source: str
    subsample: PartitionInfo
    strategy: str
    element_table: str
    seed: int
    dry_run: bool = False

main

main() -> int

CLI entry point for deriva-ml-split-dataset.

Parses command-line arguments, connects to a DerivaML catalog, and splits the specified dataset into training, testing, and optionally validation subsets.

Returns:

  • int: Exit code, 0 for success and 1 for failure.

Source code in src/deriva_ml/dataset/split.py
def main() -> int:
    """CLI entry point for ``deriva-ml-split-dataset``.

    Parses command-line arguments, connects to a DerivaML catalog, and
    splits the specified dataset into training, testing, and optionally
    validation subsets.

    Returns:
        Exit code: 0 for success, 1 for failure.
    """
    import argparse
    import sys
    import textwrap

    parser = argparse.ArgumentParser(
        description="Split a DerivaML dataset into training/testing/validation subsets",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent("""\
        Examples:
            # Simple random 80/20 split
            deriva-ml-split-dataset --hostname localhost --catalog-id 9 \\
                --dataset-rid 28D0

            # Three-way train/val/test split
            deriva-ml-split-dataset --hostname localhost --catalog-id 9 \\
                --dataset-rid 28D0 --val-size 0.1

            # Stratified split by class label (stratify on the vocab
            # table's Name column, reached transparently through the
            # Image_Classification feature)
            deriva-ml-split-dataset --hostname localhost --catalog-id 9 \\
                --dataset-rid 28D0 \\
                --stratify-by-column Image_Class.Name \\
                --include-tables Image,Image_Class

            # Fixed-count split
            deriva-ml-split-dataset --hostname localhost --catalog-id 9 \\
                --dataset-rid 28D0 --train-size 400 --test-size 100

            # Dry run (show plan without modifying catalog)
            deriva-ml-split-dataset --hostname localhost --catalog-id 9 \\
                --dataset-rid 28D0 --dry-run

        For more information, see:
            https://github.com/informatics-isi-edu/deriva-ml
        """),
    )

    # Connection parameters
    parser.add_argument(
        "--hostname",
        required=True,
        help="Deriva server hostname (e.g., localhost, ml.derivacloud.org)",
    )
    parser.add_argument(
        "--catalog-id",
        required=True,
        help="Catalog ID to connect to",
    )
    parser.add_argument(
        "--domain-schema",
        help="Domain schema name (auto-detected if not provided)",
    )

    # Source dataset
    parser.add_argument(
        "--dataset-rid",
        required=True,
        help="RID of the source dataset to split",
    )

    # Split parameters (scikit-learn conventions)
    parser.add_argument(
        "--test-size",
        type=float,
        default=0.2,
        help="Test set size as fraction (0-1) or absolute count (default: 0.2)",
    )
    parser.add_argument(
        "--train-size",
        type=float,
        default=None,
        help="Train set size as fraction (0-1) or absolute count (default: complement of test-size and val-size)",
    )
    parser.add_argument(
        "--val-size",
        type=float,
        default=None,
        help="Validation set size as fraction (0-1) or absolute count (default: None, no validation split)",
    )
    parser.add_argument(
        "--no-shuffle",
        action="store_true",
        help="Do not shuffle before splitting",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)",
    )
    parser.add_argument(
        "--stratify-by-column",
        help="Column name in denormalized DataFrame (dot notation) for stratified "
        "splitting (e.g., Image_Class.Name). Requires --include-tables.",
    )
    parser.add_argument(
        "--stratify-missing",
        choices=["error", "drop", "include"],
        default="error",
        help="Policy for null values in the stratify column: "
        "'error' (default) raises, 'drop' excludes nulls, "
        "'include' treats nulls as a separate class.",
    )

    # DerivaML parameters
    parser.add_argument(
        "--element-table",
        help="Element table to split (e.g., Image). Auto-detected if omitted.",
    )
    parser.add_argument(
        "--include-tables",
        help="Comma-separated tables for denormalization (e.g., Image,Image_Class). Required for stratified splitting.",
    )
    parser.add_argument(
        "--row-per",
        help="Explicit leaf table for denormalization. Defaults to --element-table when stratifying (issue #174).",
    )
    parser.add_argument(
        "--via",
        help="Comma-separated tables forced into the join chain without "
        "contributing columns (denormalizer via= parameter). Use to "
        "disambiguate path ambiguity (Rule 6) without polluting output.",
    )
    parser.add_argument(
        "--ignore-unrelated-anchors",
        action="store_true",
        help="Silently drop dataset anchors with no FK path to any requested table (denormalizer Rule 8 escape hatch).",
    )
    parser.add_argument(
        "--partition-by",
        choices=["element", "row"],
        default=None,
        help="Partition unit: 'element' dedupes per element_table RID before "
        "partitioning (disjoint at the element level); 'row' partitions "
        "denormalized rows directly (element RIDs may overlap). Required when "
        "--row-per is set and differs from --element-table; auto-defaults to "
        "'element' otherwise.",
    )
    parser.add_argument(
        "--training-types",
        default="Labeled",
        help="Comma-separated additional dataset types for training set (default: Labeled)",
    )
    parser.add_argument(
        "--testing-types",
        default="Labeled",
        help="Comma-separated additional dataset types for testing set (default: Labeled)",
    )
    parser.add_argument(
        "--validation-types",
        default="Labeled",
        help="Comma-separated additional dataset types for validation set (default: Labeled)",
    )
    parser.add_argument(
        "--description",
        default="",
        help="Description for the parent split dataset",
    )
    parser.add_argument(
        "--workflow-type",
        default="Dataset_Split",
        help="Workflow type vocabulary term (default: Dataset_Split)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print plan without modifying catalog",
    )
    parser.add_argument(
        "--show-urls",
        action="store_true",
        help="Show Chaise web interface URLs for created datasets",
    )

    args = parser.parse_args()

    # Configure logging
    handler = logging.StreamHandler(sys.stderr)
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    root_logger = get_logger()
    root_logger.addHandler(handler)
    root_logger.setLevel(logging.INFO)

    sys.stdout.reconfigure(line_buffering=True)
    sys.stderr.reconfigure(line_buffering=True)

    try:
        from deriva_ml import DerivaML
        from deriva_ml.execution import ExecutionConfiguration

        # Connect
        logger.info(f"Connecting to {args.hostname}, catalog {args.catalog_id}")
        ml = DerivaML(
            hostname=args.hostname,
            catalog_id=str(args.catalog_id),
            domain_schemas={args.domain_schema} if args.domain_schema else None,
        )
        logger.info(f"Connected, domain schema: {ml.default_schema}")

        # Parse comma-separated lists
        include_tables = [t.strip() for t in args.include_tables.split(",")] if args.include_tables else None
        training_types = [t.strip() for t in args.training_types.split(",")] if args.training_types else None
        testing_types = [t.strip() for t in args.testing_types.split(",")] if args.testing_types else None
        validation_types = [t.strip() for t in args.validation_types.split(",")] if args.validation_types else None
        via = [t.strip() for t in args.via.split(",")] if args.via else None

        # Dry-run: skip workflow/execution overhead entirely. split_dataset's
        # dry-run path doesn't touch the catalog and doesn't need a live
        # execution -- pass a sentinel so the type-check is satisfied and the
        # early-return at the top of split_dataset fires before any execution
        # methods are called.
        if args.dry_run:
            result = split_dataset(
                ml=ml,
                source_dataset_rid=args.dataset_rid,
                execution=None,  # type: ignore[arg-type]  -- dry-run returns before use
                test_size=args.test_size,
                train_size=args.train_size,
                val_size=args.val_size,
                shuffle=not args.no_shuffle,
                seed=args.seed,
                stratify_by_column=args.stratify_by_column,
                stratify_missing=args.stratify_missing,
                split_description=args.description,
                training_types=training_types,
                testing_types=testing_types,
                validation_types=validation_types,
                element_table=args.element_table,
                include_tables=include_tables,
                row_per=args.row_per,
                via=via,
                ignore_unrelated_anchors=args.ignore_unrelated_anchors,
                partition_by=args.partition_by,
                dry_run=True,
            )
        else:
            # The CLI itself is the caller -- it lives in a git checkout
            # of deriva-ml, so its workflow URL/checksum come from this
            # script's git context (via the Workflow validator's
            # built-in introspection). The MCP server, by contrast,
            # would never reach this code path -- it opens its own
            # execution from a caller-supplied workflow_rid.
            workflow = ml.create_workflow(
                name=f"deriva-ml-split-dataset CLI: {args.dataset_rid}",
                workflow_type=args.workflow_type,
                description="Split dataset via the deriva-ml-split-dataset CLI",
            )
            with ml.create_execution(
                ExecutionConfiguration(
                    workflow=workflow,
                    description=args.description or f"Split of {args.dataset_rid}",
                )
            ) as exe:
                result = split_dataset(
                    ml=ml,
                    source_dataset_rid=args.dataset_rid,
                    execution=exe,
                    test_size=args.test_size,
                    train_size=args.train_size,
                    val_size=args.val_size,
                    shuffle=not args.no_shuffle,
                    seed=args.seed,
                    stratify_by_column=args.stratify_by_column,
                    stratify_missing=args.stratify_missing,
                    split_description=args.description,
                    training_types=training_types,
                    testing_types=testing_types,
                    validation_types=validation_types,
                    element_table=args.element_table,
                    include_tables=include_tables,
                    row_per=args.row_per,
                    via=via,
                    ignore_unrelated_anchors=args.ignore_unrelated_anchors,
                    partition_by=args.partition_by,
                    dry_run=False,
                )
            exe.commit_output_assets(clean_folder=True)

        # Print summary
        if args.dry_run:
            print(f"\n{'=' * 60}")
            print("  DRY RUN - No changes will be made")
            print(f"{'=' * 60}")
            print(f"  Source dataset:  {result.source}")
            print(f"  Element table:   {result.element_table}")
            print(f"  Strategy:        {result.strategy}")
            print(f"  Seed:            {result.seed}")
            print(f"  Training size:   {result.training.count}")
            if result.validation:
                print(f"  Validation size: {result.validation.count}")
            print(f"  Testing size:    {result.testing.count}")
            print(f"{'=' * 60}\n")
        else:
            print(f"\n{'=' * 60}")
            print("  SPLIT COMPLETE")
            print(f"{'=' * 60}")
            print(f"  Source dataset:  {result.source}")
            print(f"  Split dataset:   {result.split.rid} (v{result.split.version})")
            print(f"  Training:        {result.training.rid} (v{result.training.version})")
            if result.validation:
                print(f"  Validation:      {result.validation.rid} (v{result.validation.version})")
            print(f"  Testing:         {result.testing.rid} (v{result.testing.version})")

            if args.show_urls:
                print()
                print("  Chaise URLs:")
                for name, info in [
                    ("split", result.split),
                    ("training", result.training),
                    ("validation", result.validation),
                    ("testing", result.testing),
                ]:
                    if info is None:
                        continue
                    try:
                        url = ml.cite(info.rid, current=True)
                        print(f"    {name}: {url}")
                    except Exception:
                        # URL lookup is best-effort; skip datasets that cannot be cited.
                        pass

            print(f"{'=' * 60}\n")

        return 0

    except Exception as e:
        logger.error(f"Split failed: {e}")
        return 1
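The --partition-by help text above distinguishes two partition units. Denormalization can yield several rows per element, for example one row per feature value, so an Image with two labels appears twice. Partitioning rows directly can then place the same element in both Training and Testing. Splitting distinct element RIDs first keeps the partitions disjoint at the element level. The sketch below illustrates the difference with plain Python lists. The helper names and data are hypothetical, and the column name Image.RID simply follows the dot notation described above.

```python
import random


def element_rids(rows: list[dict], positions: list[int], rid_column: str = "Image.RID") -> set[str]:
    """Element RIDs covered by a list of denormalized row positions."""
    return {rows[i][rid_column] for i in positions}


def split_by_element(
    rows: list[dict], n_test: int, seed: int, rid_column: str = "Image.RID"
) -> dict[str, list[int]]:
    """'element'-style sketch: dedupe to distinct RIDs, split those, then map back to rows."""
    rids = sorted({r[rid_column] for r in rows})
    random.Random(seed).shuffle(rids)
    test_rids = set(rids[:n_test])
    return {
        "Training": [i for i, r in enumerate(rows) if r[rid_column] not in test_rids],
        "Testing": [i for i, r in enumerate(rows) if r[rid_column] in test_rids],
    }


# Illustrative denormalized rows: image 1-A carries two labels, so it appears twice.
rows = [
    {"Image.RID": "1-A", "Image_Class.Name": "cat"},
    {"Image.RID": "1-A", "Image_Class.Name": "indoor"},
    {"Image.RID": "1-B", "Image_Class.Name": "dog"},
    {"Image.RID": "1-C", "Image_Class.Name": "cat"},
]

# 'row'-style: partitioning row positions directly may separate rows 0 and 1 ...
row_split = {"Training": [0, 2], "Testing": [1, 3]}
row_overlap = element_rids(rows, row_split["Training"]) & element_rids(rows, row_split["Testing"])

# ... whereas splitting distinct RIDs keeps every image on one side.
elem_split = split_by_element(rows, n_test=1, seed=0)
elem_overlap = element_rids(rows, elem_split["Training"]) & element_rids(rows, elem_split["Testing"])
```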

random_split

random_split(
    df: DataFrame,
    partition_sizes: dict[str, int],
    seed: int,
) -> dict[str, np.ndarray]

Random split into N partitions.

Shuffles the DataFrame indices and splits at partition boundaries. This is the default selector used by split_dataset() when neither stratify_by_column nor selection_fn is supplied. In that case the unified selector pipeline hands this function a synthetic DataFrame whose only column is the element-table RID, and it returns one random partition per name in partition_sizes.

Parameters:

  • df (DataFrame, required): Source DataFrame.
  • partition_sizes (dict[str, int], required): Dict mapping partition names to counts.
  • seed (int, required): Random seed for reproducibility.

Returns:

  • dict[str, ndarray]: Dict mapping partition names to index arrays.

Source code in src/deriva_ml/dataset/split.py
def random_split(
    df: pd.DataFrame,
    partition_sizes: dict[str, int],
    seed: int,
) -> dict[str, np.ndarray]:
    """Random split into N partitions.

    Shuffles the DataFrame indices and splits at partition boundaries.
    This is the **default selector** used by :func:`split_dataset` when
    neither ``stratify_by_column`` nor ``selection_fn`` is supplied —
    the unified selector pipeline produces one random partition per
    name in ``partition_sizes`` by handing this function the synthetic
    dataframe whose only column is the element-table RID.

    Args:
        df: Source DataFrame.
        partition_sizes: Dict mapping partition names to counts.
        seed: Random seed for reproducibility.

    Returns:
        Dict mapping partition names to index arrays.
    """
    rng = np.random.default_rng(seed)
    total_needed = sum(partition_sizes.values())
    indices = np.arange(len(df))
    rng.shuffle(indices)
    indices = indices[:total_needed]

    result = {}
    offset = 0
    for name, size in partition_sizes.items():
        result[name] = indices[offset : offset + size]
        offset += size
    return result
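Because random_split uses only len(df) and positional indices, its behaviour is easy to check in isolation. The snippet below reproduces its body, so it runs without deriva-ml installed, and uses a plain list as a stand-in for a 10-row DataFrame. It checks three properties visible in the code. Partitions come out in the order of partition_sizes and do not overlap, and the same seed gives the same split. The code also shows two edge cases: rows beyond sum(partition_sizes.values()) are dropped, and requesting more rows than exist yields shorter partitions rather than an error.

```python
import numpy as np


def random_split(df, partition_sizes: dict[str, int], seed: int) -> dict[str, np.ndarray]:
    # Same logic as random_split above; only len(df) is used.
    rng = np.random.default_rng(seed)
    total_needed = sum(partition_sizes.values())
    indices = np.arange(len(df))
    rng.shuffle(indices)
    indices = indices[:total_needed]
    result = {}
    offset = 0
    for name, size in partition_sizes.items():
        result[name] = indices[offset : offset + size]
        offset += size
    return result


rows = list(range(10))  # stand-in for a 10-row DataFrame
sizes = {"Training": 7, "Validation": 1, "Testing": 2}
parts = random_split(rows, sizes, seed=42)
again = random_split(rows, sizes, seed=42)  # same seed, same split
```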

split_dataset

split_dataset(
    ml: DerivaML,
    source_dataset_rid: str,
    execution: Execution,
    *,
    test_size: float | int = 0.2,
    train_size: float | int | None = None,
    val_size: float | int | None = None,
    shuffle: bool = True,
    seed: int = 42,
    stratify_by_column: str | None = None,
    stratify_missing: str = "error",
    split_description: str = "",
    training_types: list[str] | None = None,
    testing_types: list[str] | None = None,
    validation_types: list[str] | None = None,
    element_table: str | None = None,
    include_tables: list[str] | None = None,
    selection_fn: SelectionFunction | None = None,
    dry_run: bool = False,
    row_per: str | None = None,
    via: list[str] | None = None,
    ignore_unrelated_anchors: bool = False,
    partition_by: Literal["element", "row"] | None = None,
) -> SplitResult

Split a DerivaML dataset into training, testing, and optionally validation subsets.

Creates a new dataset hierarchy in the catalog::

Split (parent, type: "Split")
+-- Training (child, type: "Training", + training_types)
+-- Validation (child, type: "Validation", + validation_types)  # if val_size
+-- Testing (child, type: "Testing", + testing_types)

All operations are performed within an execution context for full provenance tracking.

This function is generic and works with any DerivaML dataset that has registered element types.

Provenance: the source dataset's relationship to the split. The new Split is a standalone, self-contained dataset hierarchy. The dataset identified by source_dataset_rid is NOT a parent of the Split, and the Split is NOT nested under the source: there is no Dataset_Dataset edge between them, and neither source.list_dataset_children() nor list_dataset_relations(source) will list the Split. This is intentional. The source is an input the split consumed, not a container the split lives inside; nesting the Split under the source would re-partition the source's own members and change the source's version on every split.

The derivation is instead recorded as execution provenance. split_dataset registers source_dataset_rid as an input of execution (via Execution.add_input_dataset()), and the Split, Training, Testing, and Validation datasets as that execution's outputs. The walkable path is therefore source -> (input of) -> execution -> (output) -> split: execution.list_input_datasets() returns the source, and a lineage walk (deriva_ml_get_lineage) reaches the splits from the source and vice versa. The SplitResult.source field returned by this call also carries the source RID for immediate use.

Membership consequence: the Training, Testing, and Validation partitions are carved from the source's elements, so they share element rows with the source (and, in a two-way split, Training ∪ Testing reconstructs the source's element set). The train/eval relationship therefore lives in shared membership, not in a parent/child lineage edge. Evaluating a model trained on the source against one of these partitions would leak, because every evaluation element was also seen in training. Reason about overlap via member sets, not via the dataset hierarchy.
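Because the relationship lives in shared membership, a leakage check compares member RIDs directly. The sketch below assumes the element RIDs of each dataset have already been collected into Python sets. Dataset.list_dataset_members is one possible source, but its exact return shape is not shown here. The function name and RIDs are hypothetical.

```python
def membership_report(train_rids: set[str], eval_rids: set[str]) -> dict[str, object]:
    """Hypothetical helper: quantify element overlap between training and evaluation sets."""
    shared = train_rids & eval_rids
    return {
        "shared": sorted(shared),
        "eval_fraction_seen_in_training": len(shared) / len(eval_rids) if eval_rids else 0.0,
        "leaks": bool(shared),
    }


# Illustrative RIDs for a two-way split of a five-element source.
source = {"1-A", "1-B", "1-C", "1-D", "1-E"}
training = {"1-A", "1-B", "1-C", "1-D"}
testing = source - training  # Training and Testing partition the source

ok = membership_report(training, testing)  # trained on Training, evaluated on Testing
bad = membership_report(source, testing)   # trained on the whole source, evaluated on Testing
```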

Role types do not inherit from the source and do not propagate to children. The Training / Testing / Validation tags on the partition children are assigned from each partition's position in the split, not copied from the source's dataset_types. A source tagged Testing (because it is a testing corpus) produces a Training partition tagged Training (because that partition is the training half of the split). This is intentional: role-axis types describe a dataset's role in its immediate context, not a property the operation should preserve. See the "Datasets — types and partitions" subsection of CONTEXT.md for the canonical three-axis (role / content / origin) framing. The training_types, testing_types, and validation_types arguments exist precisely so the caller can propagate content-axis types (e.g., Labeled) onto the children when that propagation is meaningful.

Parameters:

Name Type Description Default
ml DerivaML

Connected DerivaML instance.

required
source_dataset_rid str

RID of the source dataset to split.

required
execution Execution

A live :class:Execution the caller has already opened (typically via with ml.create_execution(config) as exe:). All datasets created by this split — the parent Split row and the Training / Validation / Testing children — are attributed to this execution, which in turn is attributed to the execution's workflow. The caller owns execution provenance: their workflow URL and checksum identify the code making the splitting decision, and deriva-ml never invents a workflow on the caller's behalf. The caller is responsible for committing the execution (exe.commit_output_assets() / context-manager exit). split_dataset will write a split_config.json artifact into exe.working_dir that the caller's upload will pick up.

required
test_size float | int

If float (0-1), fraction of data for testing. If int, absolute number of test samples. Default: 0.2.

0.2
train_size float | int | None

If float (0-1), fraction of data for training. If int, absolute number of training samples. If None, complement of test_size (and val_size). Default: None.

None
val_size float | int | None

If float (0-1), fraction of data for validation. If int, absolute number of validation samples. If None, no validation split is created (two-way split). Default: None.

None
shuffle bool

Whether to shuffle before splitting. Default: True. Ignored when using stratified or custom selection functions (they handle their own shuffling).

True
seed int

Random seed for reproducibility. Default: 42.

42
stratify_by_column str | None

Column name for stratified splitting. Must be a column in the denormalized DataFrame using dot notation (e.g., Image_Class.Name). Use :meth:Dataset.list_denormalized_columns to discover available columns. Mutually exclusive with selection_fn.

None
stratify_missing str

Policy for null values in the stratify column. "error" (default) raises if any nulls exist, "drop" excludes rows with nulls, "include" treats nulls as a separate class. Only used when stratify_by_column is set.

'error'
split_description str

Description for the parent Split dataset.

''
training_types list[str] | None

Additional dataset types for the training set beyond "Training" (e.g., ["Labeled"]). Default: None.

None
testing_types list[str] | None

Additional dataset types for the testing set beyond "Testing" (e.g., ["Labeled"]). Default: None.

None
validation_types list[str] | None

Additional dataset types for the validation set beyond "Validation" (e.g., ["Labeled"]). Default: None. Ignored when val_size is None.

None
element_table str | None

Name of the element table to split (e.g., "Image"). If None, auto-detected from the source dataset's members.

None
include_tables list[str] | None

Tables to include when denormalizing for the selection function. Required when using stratify_by_column or a custom selection_fn.

None
selection_fn SelectionFunction | None

Custom selection function conforming to the SelectionFunction protocol. Mutually exclusive with stratify_by_column.

None
dry_run bool

If True, return the planned split without modifying the catalog.

False
row_per str | None

Explicit leaf table for denormalization (passed through to :meth:Dataset.get_denormalized_as_dataframe). When stratify_by_column or selection_fn is set and row_per is None, defaults to element_table — the natural anchor when partitioning element rows. Set explicitly to override (e.g., when projecting a feature value table's columns through a feature-association bridge and you want one row per feature value). When row_per != element_table the partition unit becomes ambiguous; partition_by must then be set explicitly.

None
via list[str] | None

Tables forced into the join chain without contributing columns (denormalizer via= parameter). Useful for resolving path ambiguity (Rule 6) without adding columns to the output.

None
ignore_unrelated_anchors bool

If True, silently drop dataset anchors whose table has no FK path to any requested table. Pass-through to the denormalizer (Rule 8) — useful when the source dataset has heterogeneous member tables and only a subset participates in the split.

False
partition_by Literal['element', 'row'] | None

Explicit declaration of the partition unit when row_per is set and differs from element_table. Either "element" (the unit is an element_table RID: rows are deduplicated to one per element before partitioning, and the stratify column must agree within each element) or "row" (the unit is a single denormalized row, so the same element RID may legitimately appear in multiple partitions). Auto-defaults to "element" when row_per is None or equals element_table (the unambiguous case). Required, with no default, when row_per is set and differs from element_table. See the "When to use partition_by='element' vs partition_by='row'" section below.

None
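The size parameters follow scikit-learn's conventions. The sketch below shows one way float fractions and int counts could resolve into concrete partition sizes. `resolve_sizes` is a hypothetical helper, not part of deriva-ml, and the rounding rule (round test and validation counts up, training down, as scikit-learn's `train_test_split` does) is an assumption about how `split_dataset` resolves fractions.

```python
import math


def resolve_sizes(n, test_size=0.2, train_size=None, val_size=None):
    """Hypothetical helper: turn fraction/count sizes into partition counts."""

    def to_count(size, rounding):
        # Floats are fractions of n; ints are absolute counts.
        if isinstance(size, float):
            if not 0.0 < size < 1.0:
                raise ValueError(f"fractional size must be in (0, 1), got {size}")
            # round() first so float noise (e.g. 7.000000000000001) cannot bump ceil().
            return rounding(round(size * n, 9))
        return int(size)

    test = to_count(test_size, math.ceil)
    val = 0 if val_size is None else to_count(val_size, math.ceil)
    # With no train_size, training takes whatever test and validation leave over.
    train = n - test - val if train_size is None else to_count(train_size, math.floor)

    if train <= 0 or test <= 0 or train + test + val > n:
        raise ValueError(
            f"invalid sizes for {n} members: train={train}, test={test}, val={val}"
        )
    sizes = {"Training": train, "Testing": test}
    if val_size is not None:
        sizes["Validation"] = val
    return sizes


print(resolve_sizes(10))                                  # 80/20 two-way split
print(resolve_sizes(10, test_size=0.2, val_size=0.1))     # three-way split
print(resolve_sizes(500, test_size=100, train_size=400))  # fixed counts
```

Rejecting oversubscribed sizes up front mirrors the documented ValueError for invalid sizes, so no catalog rows are created for an impossible split.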

When to use partition_by='element' vs partition_by='row': The (row_per, element_table) pair encodes two independent choices that the old API conflated:

- ``element_table``: which catalog entity each partition
  collects (Image, Subject, Trial, ...).
- ``row_per``: how the denormalized dataframe shapes
  its rows (one per element_table RID, one per
  feature-value, one per visit, ...).

When ``row_per`` equals ``element_table`` (or is unset) the
two intents collapse: one element RID = one row, the
selector partitions rows, and the resulting partitions are
naturally disjoint at the element level. This is the
unambiguous case and ``partition_by`` auto-defaults to
``"element"``.

When ``row_per`` differs from ``element_table`` the same
element RID can have multiple denormalized rows (the 1:N
feature case). The selector now faces a real architectural
choice the caller must make explicitly:

``partition_by="element"`` — partition the *elements*. The
dataframe is deduplicated to one row per element_table RID
before the selector runs. Partitions are guaranteed
disjoint at the element-RID level. Use this when downstream
consumers (training loaders, ROC analysis, accuracy
metrics) operate at the element level, as most ML
evaluations do. Requires within-element agreement on any
selector-read column: stratifying on
``Image_Classification.Image_Class`` only makes sense if
every Image_RID has one class. When multiple annotators
disagree per image, resolve them upstream (the deriva-ml
pattern is a separate consensus feature that records the
resolved label per element, written by your adjudication
workflow) and stratify on the consensus feature, not on
the raw annotator rows. ``split_dataset`` enforces this
with a within-element uniformity check that names the
offending RIDs.

``partition_by="row"`` — partition the *rows*. No dedupe,
no uniformity check. Element RIDs may appear in multiple
partitions; this is the expected shape for legitimate
per-row use cases such as per-annotation statistics (each
annotator-image pair scored independently) or time-series
splits within a subject. The caller is responsible for
ensuring partition disjointness at whatever granularity
downstream consumers actually need.

Migration note: callers that previously relied on the
implicit-row-partition behavior of
``row_per=<feature_table>`` get a ``ValueError`` at the
call site directing them to choose. Adding
``partition_by="row"`` restores the prior behavior;
``partition_by="element"`` switches to the safer
per-element semantics (and almost always what the caller
meant).
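A minimal sketch of the two checks described above, written against plain lists of row dictionaries rather than a DataFrame. `collapse_to_elements` mirrors the `partition_by="element"` dedupe plus within-element uniformity check; `elements_in_multiple_partitions` is the kind of disjointness check a `partition_by="row"` caller might run on its own output. Both helpers, the column names, and the toy rows are hypothetical illustrations, not deriva-ml's implementation.

```python
from collections import defaultdict


def collapse_to_elements(rows, element_col, stratify_col):
    """Hypothetical: dedupe rows to one per element, requiring label agreement."""
    labels = defaultdict(set)
    for row in rows:
        labels[row[element_col]].add(row[stratify_col])
    # Name every element whose rows disagree, like the uniformity check above.
    conflicts = sorted(rid for rid, values in labels.items() if len(values) > 1)
    if conflicts:
        raise ValueError(f"elements disagree on {stratify_col!r}: {conflicts}")
    return {rid: values.pop() for rid, values in labels.items()}


def elements_in_multiple_partitions(partitions, element_col):
    """Hypothetical: element RIDs whose rows landed in more than one partition."""
    seen = defaultdict(set)
    for name, rows in partitions.items():
        for row in rows:
            seen[row[element_col]].add(name)
    return {rid: sorted(names) for rid, names in seen.items() if len(names) > 1}


rows = [
    {"Image": "1-A", "Image_Class.Name": "Normal"},
    {"Image": "1-A", "Image_Class.Name": "Normal"},  # second annotator agrees
    {"Image": "1-B", "Image_Class.Name": "Abnormal"},
]
print(collapse_to_elements(rows, "Image", "Image_Class.Name"))
# A row-level split can send the two 1-A rows to different partitions:
print(elements_in_multiple_partitions(
    {"Training": rows[:1] + rows[2:], "Testing": rows[1:2]}, "Image"))
```

The second call reports `1-A` in both partitions, which is exactly the element-level leakage that `partition_by="element"` rules out and that `partition_by="row"` leaves to the caller.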

Returns:

Type Description
SplitResult

SplitResult with partition info for the split, training, testing, and optionally validation datasets.

Raises:

Type Description
ValueError

If sizes are invalid, the source dataset has no members, or parameters conflict.

Example

split_dataset always runs inside an Execution the caller has already opened, so the execution argument is required. The examples share the setup below; each assumes exe is the live execution opened with ml.create_execution(config), as in the first example::

from deriva_ml import DerivaML
from deriva_ml.dataset.split import split_dataset
from deriva_ml.execution import ExecutionConfiguration

ml = DerivaML("localhost", "9")
workflow = ml.create_workflow(
    name="My splitting script",
    workflow_type="Dataset_Split",
)
config = ExecutionConfiguration(workflow=workflow)

Simple random 80/20 split::

with ml.create_execution(config) as exe:
    result = split_dataset(ml, "28D0", exe, test_size=0.2, seed=42)
print(f"Training: {result.training.rid} ({result.training.count} samples)")
print(f"Testing:  {result.testing.rid} ({result.testing.count} samples)")
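The random strategy shuffles the members and cuts the shuffled list at the partition boundaries. A minimal sketch of that idea, assuming the partition sizes are already resolved; `random_split` is hypothetical and uses Python's `random.Random`, not whatever RNG deriva-ml uses internally, so the same seed will not reproduce deriva-ml's actual assignment.

```python
import random


def random_split(members, sizes, seed=42, shuffle=True):
    """Hypothetical: seeded shuffle, then consecutive slices per partition."""
    order = list(members)
    if shuffle:
        # A private Random instance keeps the split independent of global RNG state.
        random.Random(seed).shuffle(order)
    partitions, start = {}, 0
    for name, size in sizes.items():
        partitions[name] = order[start:start + size]
        start += size
    return partitions


members = [f"RID-{i}" for i in range(10)]
split = random_split(members, {"Training": 8, "Testing": 2}, seed=42)
print({name: len(rids) for name, rids in split.items()})
```

Because each partition is a disjoint slice of one shuffled list, the partitions can never overlap, and rerunning with the same seed yields the same assignment.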

Three-way train/val/test split::

result = split_dataset(
    ml, "28D0", exe,
    test_size=0.2,
    val_size=0.1,
    seed=42,
)
print(f"Validation: {result.validation.rid} ({result.validation.count} samples)")

Fixed-count split with labeled types::

result = split_dataset(
    ml, "28D0", exe,
    test_size=100,
    train_size=400,
    seed=42,
    training_types=["Labeled"],
    testing_types=["Labeled"],
)

Stratified split preserving class distribution (one row per Image, projecting the Image_Class vocab term as a column)::

# Image and Image_Class are linked by the feature-
# association table Execution_Image_Image_Classification,
# which is a transparent bridge for the denormalizer.
# Pass the **vocab/value table** (``Image_Class``) in
# ``include_tables``, not the feature-name shorthand
# (``Image_Classification``): the shorthand resolves to
# the feature-association table, which is downstream of
# Image and would trip Rule 5 against the auto-defaulted
# ``row_per="Image"``. Stratify on the dotted column
# against the vocab table.
result = split_dataset(
    ml, "28D0", exe,
    test_size=0.2,
    stratify_by_column="Image_Class.Name",
    include_tables=["Image", "Image_Class"],
    element_table="Image",
    partition_by="element",
)
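split_dataset uses scikit-learn's stratified splitting for this. As a dependency-free illustration of what "preserving class distribution" means, the sketch below shuffles each class separately and sends the same fraction of every class to testing. `stratified_split` and the toy labels are hypothetical, and the allocation is simpler than scikit-learn's: rounding each class independently means the total test count can drift slightly from `test_size` when there are many small classes.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_fraction, seed=42):
    """Hypothetical: per-class shuffle so each class keeps its share in Testing."""
    by_class = defaultdict(list)
    for rid, label in labels.items():
        by_class[label].append(rid)
    rng = random.Random(seed)
    split = {"Training": [], "Testing": []}
    # Iterate classes and members in sorted order so the result depends only on seed.
    for label in sorted(by_class):
        rids = sorted(by_class[label])
        rng.shuffle(rids)
        n_test = round(len(rids) * test_fraction)
        split["Testing"].extend(rids[:n_test])
        split["Training"].extend(rids[n_test:])
    return split


labels = {f"N-{i}": "Normal" for i in range(10)}
labels.update({f"A-{i}": "Abnormal" for i in range(5)})
split = stratified_split(labels, test_fraction=0.2)
print(sorted(labels[rid] for rid in split["Testing"]))
```

With 10 Normal and 5 Abnormal members, a 0.2 test fraction puts 2 Normal and 1 Abnormal in Testing, keeping the 2:1 ratio of the source.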

Override row_per to project one row per feature value instead — per-annotation statistics. Because row_per differs from element_table, partition_by must be set explicitly. "row" accepts that the same Image RID may appear in multiple partitions (its multiple annotation rows can land independently); "element" would dedupe to one row per Image before partitioning and would raise if annotators disagreed::

# Per-annotation statistics — element RIDs may legitimately
# appear in multiple partitions because each annotator-image
# pair is its own observation. The feature-name shorthand
# ``Image_Classification`` resolves to the feature-
# association table; setting ``row_per`` to that table
# explicitly makes the per-observation intent visible.
# Stratify on the FK column on the feature-association
# table (the resolver does not pull the vocab table into
# the join when the shorthand is used with an explicit
# feature-assoc ``row_per``).
result = split_dataset(
    ml, "28D0", exe,
    test_size=0.2,
    stratify_by_column="Execution_Image_Image_Classification.Image_Class",
    include_tables=["Image", "Image_Classification"],
    row_per="Execution_Image_Image_Classification",
    partition_by="row",
)

Note: to get "one row per element with a feature value projected as a column," pass the vocab/value table in include_tables (as in the first stratified example above), not the feature-name shorthand. Rule 5 of the denormalizer rejects the shorthand combined with row_per=<element> because the feature-association table the shorthand resolves to is strictly downstream of the element — aggregation is not supported. To partition by feature observation instead (per-annotation statistics), use the shorthand together with an explicit row_per=<feature-assoc-table> and partition_by="row" as in the second example above.

Stratified split dropping rows with missing labels::

result = split_dataset(
    ml, "28D0", exe,
    test_size=0.2,
    stratify_by_column="Image_Class.Name",
    stratify_missing="drop",
    include_tables=["Image", "Image_Class"],
    element_table="Image",
    partition_by="element",
)
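The three stratify_missing policies can be sketched as a filter applied before stratification. `apply_missing_policy` and the `"<missing>"` class name are hypothetical; the parameter description above only says that "include" treats nulls as a separate class, not what that class is called.

```python
MISSING_CLASS = "<missing>"  # hypothetical name for the null class


def apply_missing_policy(labels, policy="error"):
    """Hypothetical: handle None stratify values the way stratify_missing describes."""
    missing = sorted(rid for rid, label in labels.items() if label is None)
    if policy == "error":
        if missing:
            raise ValueError(f"{len(missing)} rows lack a stratify value: {missing}")
        return dict(labels)
    if policy == "drop":
        # In this sketch, dropped rows are simply excluded from what gets stratified.
        return {rid: label for rid, label in labels.items() if label is not None}
    if policy == "include":
        return {rid: MISSING_CLASS if label is None else label for rid, label in labels.items()}
    raise ValueError(f"unknown stratify_missing policy: {policy!r}")


labels = {"1-A": "Normal", "1-B": None, "1-C": "Abnormal"}
print(apply_missing_policy(labels, "drop"))
print(apply_missing_policy(labels, "include"))
```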

Custom selection function for balanced sampling::

import numpy as np

def balanced_selector(df, partition_sizes, seed):
    rng = np.random.default_rng(seed)
    label_col = "Image_Class.Name"
    classes = df[label_col].unique()
    result = {name: [] for name in partition_sizes}
    for cls in classes:
        cls_indices = df.index[df[label_col] == cls].to_numpy()
        rng.shuffle(cls_indices)
        offset = 0
        for name, size in partition_sizes.items():
            per_class = size // len(classes)
            result[name].extend(cls_indices[offset:offset + per_class])
            offset += per_class
    return {name: np.array(idx) for name, idx in result.items()}

result = split_dataset(
    ml, "28D0", exe,
    test_size=100,
    selection_fn=balanced_selector,
    include_tables=["Image", "Image_Class"],
    element_table="Image",
    partition_by="element",
)
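Because a custom selector replaces the built-in logic, it can be worth checking its output before trusting it. The sketch below assumes, as balanced_selector does, that a selector returns a mapping from partition name to row positions in the denormalized DataFrame; `check_selection` is a hypothetical helper, not part of the SelectionFunction protocol. Note that balanced_selector returns df.index labels, which equal row positions only when the DataFrame has a default RangeIndex.

```python
def check_selection(selection, n_rows, expected_names):
    """Hypothetical sanity check on a selector's {partition: row positions} output."""
    if set(selection) != set(expected_names):
        raise ValueError(
            f"expected partitions {sorted(expected_names)}, got {sorted(selection)}"
        )
    owner = {}
    for name, positions in selection.items():
        for pos in positions:
            pos = int(pos)  # accept NumPy integers as well as plain ints
            if not 0 <= pos < n_rows:
                raise ValueError(f"{name}: row position {pos} out of range 0..{n_rows - 1}")
            if pos in owner:
                raise ValueError(f"row {pos} assigned to both {owner[pos]} and {name}")
            owner[pos] = name
    # Rows left unassigned are allowed: balanced sampling may skip some.
    return {name: len(positions) for name, positions in selection.items()}


print(check_selection({"Training": [0, 2, 3], "Testing": [1]}, 4, ["Training", "Testing"]))
```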

Dry run to preview the split plan without modifying the catalog::

result = split_dataset(
    ml, "28D0", exe,
    test_size=0.2,
    dry_run=True,
)
print(f"Would create: {result.training.count} train, "
      f"{result.testing.count} test")

Use returned RIDs to create a hydra-zen configuration::

from deriva_ml.dataset import DatasetSpecConfig

result = split_dataset(ml, "28D0", exe, test_size=0.2, seed=42)
split_config = DatasetSpecConfig(
    rid=result.split.rid,
    version=result.split.version,
)

Train directly from the split partitions (composition with framework adapters)::

import PIL.Image

result = split_dataset(ml, "28D0", exe, test_size=0.2, seed=42)
train_bag = ml.lookup_dataset(result.training.rid).download_dataset_bag(
    version=result.training.version
)
test_bag = ml.lookup_dataset(result.testing.rid).download_dataset_bag(
    version=result.testing.version
)
train_ds = train_bag.as_torch_dataset(
    element_type="Image",
    sample_loader=PIL.Image.open,
    targets=["Glaucoma_Grade"],
)
# Each partition bag feeds independently into PyTorch / TensorFlow;
# the split hierarchy IS the train/val/test partitioning.

See Also

DatasetBag.as_torch_dataset, DatasetBag.as_tf_dataset: Build framework-native datasets from any partition bag; both accept the same targets / target_transform / missing parameters. DatasetBag.restructure_assets: Class-folder layout for ImageFolder-style consumers.

Source code in src/deriva_ml/dataset/split.py
def split_dataset(
    ml: DerivaML,
    source_dataset_rid: str,
    execution: Execution,
    *,
    # scikit-learn compatible parameters
    test_size: float | int = 0.2,
    train_size: float | int | None = None,
    val_size: float | int | None = None,
    shuffle: bool = True,
    seed: int = 42,
    stratify_by_column: str | None = None,
    stratify_missing: str = "error",
    # DerivaML-specific parameters
    split_description: str = "",
    training_types: list[str] | None = None,
    testing_types: list[str] | None = None,
    validation_types: list[str] | None = None,
    element_table: str | None = None,
    include_tables: list[str] | None = None,
    selection_fn: SelectionFunction | None = None,
    dry_run: bool = False,
    # Denormalization-control parameters (issue #174)
    row_per: str | None = None,
    via: list[str] | None = None,
    ignore_unrelated_anchors: bool = False,
    # Partition-unit selector
    partition_by: Literal["element", "row"] | None = None,
) -> SplitResult:
    """Split a DerivaML dataset into training, testing, and optionally validation subsets.

    Creates a new dataset hierarchy in the catalog::

        Split (parent, type: "Split")
        +-- Training (child, type: "Training", + training_types)
        +-- Validation (child, type: "Validation", + validation_types)  # if val_size
        +-- Testing (child, type: "Testing", + testing_types)

    All operations are performed within an execution context for
    full provenance tracking.

    This function is generic and works with any DerivaML dataset
    that has registered element types.

    Provenance — the source dataset's relationship to the split:
        The new Split is a **standalone, self-contained** dataset
        hierarchy. The ``source_dataset_rid`` you pass in is **NOT**
        a parent of the Split and the Split is **NOT** nested under
        the source: there is no ``Dataset_Dataset`` edge between them,
        and ``source.list_dataset_children()`` /
        ``list_dataset_relations(source)`` will **not** list the Split.
        That is intentional — the source is an *input* the split
        *consumed*, not a container the split lives inside (nesting
        the Split under the source would re-partition the source's own
        members and flip the source's version on every split).

        The derivation is instead recorded as **execution provenance**:
        ``split_dataset`` registers ``source_dataset_rid`` as an input
        of ``execution`` (via :meth:`Execution.add_input_dataset`), and
        the Split / Training / Testing / Validation datasets as that
        execution's outputs. So the walkable path is
        ``source -> (input of) -> execution -> (output) -> split``:
        ``execution.list_input_datasets()`` returns the source, and a
        lineage walk (``deriva_ml_get_lineage``) reaches the splits
        from the source and vice versa. The ``SplitResult.source``
        field returned by this call also carries the source RID for
        immediate use.

        Membership consequence: the Training / Testing / Validation
        partitions are carved from the source's elements, so they
        **share element rows with the source** (and, in a two-way
        split, ``Training`` ∪ ``Testing`` reconstructs the source's
        element set). The train/eval relationship therefore lives in
        *shared membership*, not in a parent/child lineage edge —
        evaluating a model trained on the source against one of these
        partitions would leak. Reason about overlap via member sets,
        not via the dataset hierarchy.

        **Role types are not inherited from the source and are not
        propagated to the children.** The Training / Testing / Validation
        tags on the partition children are assigned based on the
        partition's position in the split, **not** copied from the
        source's ``dataset_types``. A source tagged ``Testing``
        (because it is a testing corpus) produces a Training partition
        tagged ``Training`` (because that partition is the training
        half of the split). This is intentional: role-axis types
        describe a dataset's role in its *immediate context*, not a
        property the operation should preserve. See CONTEXT.md's
        ``Datasets — types and partitions`` subsection for the
        canonical three-axis (role / content / origin) framing — the
        ``training_types`` / ``testing_types`` / ``validation_types``
        arguments exist precisely so the caller can propagate
        *content-axis* types (e.g., ``Labeled``) onto the children
        when that propagation is meaningful.

    Args:
        ml: Connected DerivaML instance.
        source_dataset_rid: RID of the source dataset to split.
        execution: A live :class:`Execution` the caller has already
            opened (typically via ``with ml.create_execution(config) as
            exe:``). All datasets created by this split — the parent
            Split row and the Training / Validation / Testing children —
            are attributed to *this* execution, which in turn is
            attributed to the execution's workflow. The caller owns
            execution provenance: their workflow URL and checksum
            identify the code making the splitting decision, and
            deriva-ml never invents a workflow on the caller's behalf.
            The caller is responsible for committing the execution
            (``exe.commit_output_assets()`` / context-manager exit).
            ``split_dataset`` will write a ``split_config.json``
            artifact into ``exe.working_dir`` that the caller's upload
            will pick up.
        test_size: If float (0-1), fraction of data for testing.
            If int, absolute number of test samples. Default: 0.2.
        train_size: If float (0-1), fraction of data for training.
            If int, absolute number of training samples.
            If None, complement of test_size (and val_size). Default: None.
        val_size: If float (0-1), fraction of data for validation.
            If int, absolute number of validation samples.
            If None, no validation split is created (two-way split).
            Default: None.
        shuffle: Whether to shuffle before splitting. Default: True.
            Ignored when ``stratify_by_column`` or ``selection_fn`` is
            set; those strategies handle their own shuffling.
        seed: Random seed for reproducibility. Default: 42.
        stratify_by_column: Column name for stratified splitting.
            Must be a column in the denormalized DataFrame using dot notation
            (e.g., ``Image_Class.Name``). Use
            :meth:`Dataset.list_denormalized_columns` to discover available columns.
            Mutually exclusive with ``selection_fn``.
        stratify_missing: Policy for null values in the stratify column.
            ``"error"`` (default) raises if any nulls exist,
            ``"drop"`` excludes rows with nulls,
            ``"include"`` treats nulls as a separate class.
            Only used when ``stratify_by_column`` is set.
        split_description: Description for the parent Split dataset.
        training_types: Additional dataset types for the training set
            beyond "Training" (e.g., ``["Labeled"]``). Default: None.
        testing_types: Additional dataset types for the testing set
            beyond "Testing" (e.g., ``["Labeled"]``). Default: None.
        validation_types: Additional dataset types for the validation set
            beyond "Validation" (e.g., ``["Labeled"]``). Default: None.
            Ignored when val_size is None.
        element_table: Name of the element table to split (e.g., "Image").
            If None, auto-detected from the source dataset's members.
        include_tables: Tables to include when denormalizing for the
            selection function. Required when using ``stratify_by_column``
            or a custom ``selection_fn``.
        selection_fn: Custom selection function conforming to the
            ``SelectionFunction`` protocol. Mutually exclusive with
            ``stratify_by_column``.
        dry_run: If True, return the planned split without modifying the catalog.
        row_per: Explicit leaf table for denormalization (passed
            through to :meth:`Dataset.get_denormalized_as_dataframe`).
            When ``stratify_by_column`` or ``selection_fn`` is set and
            ``row_per`` is None, defaults to ``element_table`` — the
            natural anchor when partitioning element rows. Set
            explicitly to override (e.g., when projecting a feature
            value table's columns through a feature-association
            bridge and you want one row per feature value). When
            ``row_per != element_table`` the partition unit becomes
            ambiguous; ``partition_by`` must then be set explicitly.
        via: Tables forced into the join chain without contributing
            columns (denormalizer ``via=`` parameter). Useful for
            resolving path ambiguity (Rule 6) without adding columns
            to the output.
        ignore_unrelated_anchors: If True, silently drop dataset
            anchors whose table has no FK path to any requested
            table. Pass-through to the denormalizer (Rule 8) — useful
            when the source dataset has heterogeneous member tables
            and only a subset participates in the split.
        partition_by: Explicit declaration of the partition unit
            when ``row_per`` is set and differs from ``element_table``.
            Either ``"element"`` (the unit is an element_table RID:
            rows are deduplicated to one per element before
            partitioning, and the stratify column must agree within
            each element) or ``"row"`` (the unit is a single
            denormalized row, so the same element RID may
            legitimately appear in multiple partitions).
            Auto-defaults to ``"element"`` when ``row_per`` is
            ``None`` or equals ``element_table`` (the unambiguous
            case). Required, with no default, when ``row_per`` is set
            and differs from ``element_table``. See the
            "When to use ``partition_by='element'`` vs
            ``partition_by='row'``" section below.

    When to use ``partition_by='element'`` vs ``partition_by='row'``:
        The (``row_per``, ``element_table``) pair encodes two
        independent choices that the old API conflated:

        - ``element_table``: which catalog entity each partition
          collects (Image, Subject, Trial, ...).
        - ``row_per``: how the denormalized dataframe shapes
          its rows (one per element_table RID, one per
          feature-value, one per visit, ...).

        When ``row_per`` equals ``element_table`` (or is unset) the
        two intents collapse: one element RID = one row, the
        selector partitions rows, and the resulting partitions are
        naturally disjoint at the element level. This is the
        unambiguous case and ``partition_by`` auto-defaults to
        ``"element"``.

        When ``row_per`` differs from ``element_table`` the same
        element RID can have multiple denormalized rows (the 1:N
        feature case). The selector now faces a real architectural
        choice the caller must make explicitly:

        ``partition_by="element"`` — partition the *elements*. The
        dataframe is deduplicated to one row per element_table RID
        before the selector runs. Partitions are guaranteed
        disjoint at the element-RID level. Use this when downstream
        consumers (training loaders, ROC analysis, accuracy
        metrics) operate at the element level, as most ML
        evaluations do. Requires within-element agreement on any
        selector-read column: stratifying on
        ``Image_Classification.Image_Class`` only makes sense if
        every Image_RID has one class. When multiple annotators
        disagree per image, resolve them upstream (the deriva-ml
        pattern is a separate consensus feature that records the
        resolved label per element, written by your adjudication
        workflow) and stratify on the consensus feature, not on
        the raw annotator rows. ``split_dataset`` enforces this
        with a within-element uniformity check that names the
        offending RIDs.

        ``partition_by="row"`` — partition the *rows*. No dedupe,
        no uniformity check. Element RIDs may appear in multiple
        partitions; this is the expected shape for legitimate
        per-row use cases such as per-annotation statistics (each
        annotator-image pair scored independently) or time-series
        splits within a subject. The caller is responsible for
        ensuring partition disjointness at whatever granularity
        downstream consumers actually need.

        Migration note: callers that previously relied on the
        implicit-row-partition behavior of
        ``row_per=<feature_table>`` get a ``ValueError`` at the
        call site directing them to choose. Adding
        ``partition_by="row"`` restores the prior behavior;
        ``partition_by="element"`` switches to the safer
        per-element semantics (and almost always what the caller
        meant).

    Returns:
        SplitResult with partition info for split, training, testing,
        and optionally validation datasets.

    Raises:
        ValueError: If sizes are invalid, the source dataset has no
            members, or parameters conflict.

    Example:
        ``split_dataset`` always runs inside an Execution the caller has
        already opened, so the ``execution`` argument is required. The
        examples share the setup below; each assumes ``exe`` is the live
        execution opened with ``ml.create_execution(config)``, as in the
        first example::

            from deriva_ml import DerivaML
            from deriva_ml.dataset.split import split_dataset
            from deriva_ml.execution import ExecutionConfiguration

            ml = DerivaML("localhost", "9")
            workflow = ml.create_workflow(
                name="My splitting script",
                workflow_type="Dataset_Split",
            )
            config = ExecutionConfiguration(workflow=workflow)

        Simple random 80/20 split::

            with ml.create_execution(config) as exe:
                result = split_dataset(ml, "28D0", exe, test_size=0.2, seed=42)
            print(f"Training: {result.training.rid} ({result.training.count} samples)")
            print(f"Testing:  {result.testing.rid} ({result.testing.count} samples)")

        Three-way train/val/test split::

            result = split_dataset(
                ml, "28D0", exe,
                test_size=0.2,
                val_size=0.1,
                seed=42,
            )
            print(f"Validation: {result.validation.rid} ({result.validation.count} samples)")

        Fixed-count split with labeled types::

            result = split_dataset(
                ml, "28D0", exe,
                test_size=100,
                train_size=400,
                seed=42,
                training_types=["Labeled"],
                testing_types=["Labeled"],
            )

        Stratified split preserving class distribution (one row per
        Image, projecting the Image_Class vocab term as a column)::

            # Image and Image_Class are linked by the feature-
            # association table Execution_Image_Image_Classification,
            # which is a transparent bridge for the denormalizer.
            # Pass the **vocab/value table** (``Image_Class``) in
            # ``include_tables``, not the feature-name shorthand
            # (``Image_Classification``): the shorthand resolves to
            # the feature-association table, which is downstream of
            # Image and would trip Rule 5 against the auto-defaulted
            # ``row_per="Image"``. Stratify on the dotted column
            # against the vocab table.
            result = split_dataset(
                ml, "28D0", exe,
                test_size=0.2,
                stratify_by_column="Image_Class.Name",
                include_tables=["Image", "Image_Class"],
                element_table="Image",
                partition_by="element",
            )

        Override ``row_per`` to project one row per feature value
        instead — *per-annotation* statistics. Because ``row_per``
        differs from ``element_table``, ``partition_by`` must be set
        explicitly. ``"row"`` accepts that the same Image RID may
        appear in multiple partitions (its multiple annotation
        rows can land independently); ``"element"`` would dedupe
        to one row per Image before partitioning and would raise
        if annotators disagreed::

            # Per-annotation statistics — element RIDs may legitimately
            # appear in multiple partitions because each annotator-image
            # pair is its own observation. The feature-name shorthand
            # ``Image_Classification`` resolves to the feature-
            # association table; setting ``row_per`` to that table
            # explicitly makes the per-observation intent visible.
            # Stratify on the FK column on the feature-association
            # table (the resolver does not pull the vocab table into
            # the join when the shorthand is used with an explicit
            # feature-assoc ``row_per``).
            result = split_dataset(
                ml, "28D0", exe,
                test_size=0.2,
                stratify_by_column="Execution_Image_Image_Classification.Image_Class",
                include_tables=["Image", "Image_Classification"],
                row_per="Execution_Image_Image_Classification",
                partition_by="row",
            )

        Note: to get "one row per element with a feature value
        projected as a column," pass the vocab/value table in
        ``include_tables`` (as in the first stratified example
        above), not the feature-name shorthand. Rule 5 of the
        denormalizer rejects the shorthand combined with
        ``row_per=<element>`` because the feature-association table
        the shorthand resolves to is strictly downstream of the
        element — aggregation is not supported. To partition by
        feature *observation* instead (per-annotation statistics),
        use the shorthand together with an explicit
        ``row_per=<feature-assoc-table>`` and ``partition_by="row"``
        as in the second example above.

        Stratified split dropping rows with missing labels::

            result = split_dataset(
                ml, "28D0", exe,
                test_size=0.2,
                stratify_by_column="Image_Class.Name",
                stratify_missing="drop",
                include_tables=["Image", "Image_Class"],
                element_table="Image",
                partition_by="element",
            )

        Custom selection function for balanced sampling::

            import numpy as np

            def balanced_selector(df, partition_sizes, seed):
                rng = np.random.default_rng(seed)
                label_col = "Image_Class.Name"
                classes = df[label_col].unique()
                result = {name: [] for name in partition_sizes}
                for cls in classes:
                    cls_indices = df.index[df[label_col] == cls].to_numpy()
                    rng.shuffle(cls_indices)
                    offset = 0
                    for name, size in partition_sizes.items():
                        per_class = size // len(classes)
                        result[name].extend(cls_indices[offset:offset + per_class])
                        offset += per_class
                return {name: np.array(idx) for name, idx in result.items()}

            result = split_dataset(
                ml, "28D0", exe,
                test_size=100,
                selection_fn=balanced_selector,
                include_tables=["Image", "Image_Class"],
                element_table="Image",
                partition_by="element",
            )

        Dry run to preview the split plan without modifying the catalog::

            result = split_dataset(
                ml, "28D0", exe,
                test_size=0.2,
                dry_run=True,
            )
            print(f"Would create: {result.training.count} train, "
                  f"{result.testing.count} test")

        Use returned RIDs to create a hydra-zen configuration::

            from deriva_ml.dataset import DatasetSpecConfig

            result = split_dataset(ml, "28D0", exe, test_size=0.2, seed=42)
            split_config = DatasetSpecConfig(
                rid=result.split.rid,
                version=result.split.version,
            )

        Train directly from the split partitions (composition with
        framework adapters)::

            import PIL.Image

            result = split_dataset(ml, "28D0", exe, test_size=0.2, seed=42)
            train_bag = ml.lookup_dataset(result.training.rid).download_dataset_bag(
                version=result.training.version
            )
            test_bag = ml.lookup_dataset(result.testing.rid).download_dataset_bag(
                version=result.testing.version
            )
            train_ds = train_bag.as_torch_dataset(
                element_type="Image",
                sample_loader=PIL.Image.open,
                targets=["Glaucoma_Grade"],
            )
            # Each partition bag feeds independently into PyTorch / TensorFlow;
            # the split hierarchy IS the train/val/test partitioning.

    See Also:
        ``DatasetBag.as_torch_dataset``, ``DatasetBag.as_tf_dataset``:
            Build framework-native datasets from any partition bag; same
            ``targets`` / ``target_transform`` / ``missing`` vocabulary.
        ``DatasetBag.restructure_assets``:
            Class-folder layout for ``ImageFolder``-style consumers.
    """
    # After the Ds-split extraction, this function dispatches to three
    # helpers (defined above):
    #
    # 1. ``_validate_split_inputs`` — argument-shape checks.
    # 2. ``_compute_partitions`` — pure read path (members, sizes,
    #    selection); used by both dry-run and live paths.
    # 3. ``_create_split_hierarchy`` — catalog-writing path
    #    (parent/child datasets, member assignment).

    effective_partition_by = _validate_split_inputs(
        stratify_by_column=stratify_by_column,
        selection_fn=selection_fn,
        include_tables=include_tables,
        row_per=row_per,
        element_table=element_table,
        partition_by=partition_by,
    )

    logger.info(f"Looking up source dataset: {source_dataset_rid}")
    source_ds = ml.lookup_dataset(source_dataset_rid)

    partition_rids, partition_sizes, strategy_desc, element_table = _compute_partitions(
        source_ds=source_ds,
        source_dataset_rid=source_dataset_rid,
        element_table=element_table,
        test_size=test_size,
        train_size=train_size,
        val_size=val_size,
        shuffle=shuffle,
        seed=seed,
        stratify_by_column=stratify_by_column,
        stratify_missing=stratify_missing,
        include_tables=include_tables,
        selection_fn=selection_fn,
        row_per=row_per,
        via=via,
        ignore_unrelated_anchors=ignore_unrelated_anchors,
        partition_by=effective_partition_by,
    )

    # Dry-run early return — no catalog writes.
    if dry_run:
        return SplitResult(
            source=source_dataset_rid,
            split=PartitionInfo(rid="(dry run)", version="(dry run)", count=0),
            training=PartitionInfo(
                rid="(dry run)",
                version="(dry run)",
                count=partition_sizes["Training"],
            ),
            testing=PartitionInfo(
                rid="(dry run)",
                version="(dry run)",
                count=partition_sizes["Testing"],
            ),
            validation=(
                PartitionInfo(
                    rid="(dry run)",
                    version="(dry run)",
                    count=partition_sizes["Validation"],
                )
                if "Validation" in partition_sizes
                else None
            ),
            strategy=strategy_desc,
            element_table=element_table,
            seed=seed,
            dry_run=True,
        )

    # Ensure dataset-type vocabulary terms exist (Training, Testing,
    # Validation, Split, Labeled, Unlabeled). Workflow-type vocabulary is
    # the caller's concern — they registered the workflow that owns
    # ``execution`` and chose its type.
    _ensure_dataset_types(ml)

    # Mirror the per-child tag set built in ``_create_split_hierarchy``
    # — every Split child carries ``Split_Partition``. Recorded in the
    # ``split_config.json`` artifact so the config reflects what the
    # operation actually applied.
    train_types = ["Training", "Split_Partition"] + (training_types or [])
    test_types = ["Testing", "Split_Partition"] + (testing_types or [])
    val_types = (["Validation", "Split_Partition"] + (validation_types or [])) if val_size is not None else []

    split_params = {
        "source_dataset_rid": source_dataset_rid,
        "test_size": test_size,
        "train_size": train_size,
        "val_size": val_size,
        "partition_sizes": partition_sizes,
        "shuffle": shuffle,
        "seed": seed,
        "stratify_by_column": stratify_by_column,
        "stratify_missing": stratify_missing,
        "element_table": element_table,
        "include_tables": include_tables,
        "row_per": row_per,
        "via": via,
        "ignore_unrelated_anchors": ignore_unrelated_anchors,
        "partition_by": effective_partition_by,
        "training_types": train_types,
        "testing_types": test_types,
        "validation_types": val_types if val_types else None,
        "strategy": strategy_desc,
    }

    return _create_split_hierarchy(
        ml=ml,
        execution=execution,
        source_dataset_rid=source_dataset_rid,
        partition_rids=partition_rids,
        partition_sizes=partition_sizes,
        strategy_desc=strategy_desc,
        element_table=element_table,
        seed=seed,
        split_description=split_description,
        training_types=training_types,
        testing_types=testing_types,
        validation_types=validation_types,
        val_size=val_size,
        split_params=split_params,
    )
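The split_params dictionary above is what gets recorded in the split_config.json artifact, and that record is what makes a split replayable. The round-trip below is a minimal stdlib sketch of the idea, not deriva-ml's own code: save_split_config and load_split_config are hypothetical helpers, the temporary directory stands in for execution.working_dir, and the parameter values are illustrative.

```python
import json
import tempfile
from pathlib import Path
from typing import Any


def save_split_config(working_dir: Path, params: dict[str, Any]) -> Path:
    """Write split parameters as JSON into a working directory (hypothetical helper)."""
    path = Path(working_dir) / "split_config.json"
    path.write_text(json.dumps(params, indent=2))
    return path


def load_split_config(path: Path) -> dict[str, Any]:
    """Read recorded parameters back, e.g. to request the same split again."""
    return json.loads(Path(path).read_text())


# Illustrative values only, shaped like the split_params dict above.
params = {
    "source_dataset_rid": "28D0",
    "test_size": 0.2,
    "train_size": None,
    "val_size": None,
    "seed": 42,
    "partition_by": "element",
    "training_types": ["Training", "Split_Partition"],
    "testing_types": ["Testing", "Split_Partition"],
    "validation_types": None,
}

with tempfile.TemporaryDirectory() as tmp:  # stands in for execution.working_dir
    replay = load_split_config(save_split_config(Path(tmp), params))

# JSON preserves None, numbers, strings, and lists, so nothing is lost here.
assert replay == params
```

Tuples would come back as lists after a JSON round-trip, so parameters meant to survive unchanged are best stored as lists, which is what split_params already does.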

stratified_split

stratified_split(
    stratify_column: str,
    missing: str = "error",
) -> SelectionFunction

Create a stratified selection function.

Returns a selection function that maintains the class distribution of the specified column across all partitions. Delegates to scikit-learn's train_test_split for the actual stratification.

For two-way splits, performs a single stratified split. For three-way splits (Training/Validation/Testing), first separates the test set, then splits the remainder into training and validation.
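In the three-way case the source below computes the first split's test fraction as Testing / (Training + Validation + Testing) and the second split's as Validation / (Training + Validation). With sizes 60/20/20 that is 20/100 = 0.2, then 20/80 = 0.25 on the remaining 80 rows. The sketch below follows the same peel-off order in plain Python. It swaps scikit-learn's train_test_split for a hypothetical per-class shuffle with largest-remainder rounding, so it will not reproduce sklearn's exact assignments, and it assumes the three sizes add up to the number of rows.

```python
import math
import random
from collections import defaultdict


def _allocate(counts: dict[str, int], n: int) -> dict[str, int]:
    """Spread n draws over classes in proportion to their counts (largest remainder)."""
    total = sum(counts.values())
    quotas = {c: n * k / total for c, k in counts.items()}
    alloc = {c: math.floor(q) for c, q in quotas.items()}
    leftover = n - sum(alloc.values())
    # Give the remaining draws to the classes with the largest fractional parts.
    for c in sorted(quotas, key=lambda c: quotas[c] - alloc[c], reverse=True)[:leftover]:
        alloc[c] += 1
    return alloc


def stratified_peel(indices, labels, n_out, seed):
    """Peel n_out indices off, class by class; return (remainder, peeled)."""
    by_class = defaultdict(list)
    for i, label in zip(indices, labels):
        by_class[label].append(i)
    alloc = _allocate({c: len(v) for c, v in by_class.items()}, n_out)
    rng = random.Random(seed)
    remainder, peeled = [], []
    for c in sorted(by_class):
        members = list(by_class[c])
        rng.shuffle(members)
        peeled.extend(members[: alloc[c]])
        remainder.extend(members[alloc[c] :])
    return remainder, peeled


def three_way_split(labels, sizes, seed=42):
    """Testing first, then Validation from what is left; the rest is Training.

    Assumes sizes["Training"] + sizes["Validation"] + sizes["Testing"] == len(labels).
    """
    rest, test = stratified_peel(range(len(labels)), labels, sizes["Testing"], seed)
    train, val = stratified_peel(rest, [labels[i] for i in rest], sizes["Validation"], seed)
    return {"Training": train, "Validation": val, "Testing": test}


labels = ["cat"] * 60 + ["dog"] * 40
parts = three_way_split(labels, {"Training": 60, "Validation": 20, "Testing": 20})
```

Each partition keeps the 60/40 class ratio: Testing and Validation each get 12 "cat" and 8 "dog" rows, and Training gets 36 and 24.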

Parameters:

stratify_column (str, required): Column name in the denormalized DataFrame to stratify by, in dot notation (e.g., Image_Class.Name).

missing (str, default 'error'): Policy for handling null/NaN values in the stratify column.

- "error" (default): Raise ValueError if any values are missing. Reports the count and percentage of nulls.
- "drop": Silently exclude rows with missing values from the split. Only rows with valid stratify values are assigned to partitions.
- "include": Treat null/NaN as a distinct class label ("__missing__"). Missing-value rows are distributed across partitions proportionally like any other class.
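The three missing-value policies can be sketched without pandas. apply_missing_policy is a hypothetical helper that works on a plain list of labels; None and float NaN both count as missing, and "__missing__" is the placeholder class named above. As a design choice, the sketch returns the original positions of the rows it keeps, so any indices computed afterwards can still be mapped back to the unfiltered rows.

```python
import math

MISSING_LABEL = "__missing__"


def _is_missing(value) -> bool:
    """None and float NaN both count as missing."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def apply_missing_policy(labels, missing="error"):
    """Return (kept_positions, labels_for_kept_rows) under the chosen policy."""
    if missing not in ("error", "drop", "include"):
        raise ValueError(f"missing must be 'error', 'drop', or 'include', got {missing!r}")
    null_positions = [i for i, v in enumerate(labels) if _is_missing(v)]
    if null_positions and missing == "error":
        pct = len(null_positions) / len(labels) * 100
        raise ValueError(f"{len(null_positions)} missing values ({pct:.1f}% of {len(labels)} rows)")
    if missing == "drop":
        # Keep the original positions so results can be mapped back.
        kept = [i for i, v in enumerate(labels) if not _is_missing(v)]
        return kept, [labels[i] for i in kept]
    # "include" (or "error" with nothing missing): keep every row, relabel nulls.
    return list(range(len(labels))), [MISSING_LABEL if _is_missing(v) else v for v in labels]


print(apply_missing_policy(["Normal", None, "Glaucoma", float("nan")], "drop"))
# ([0, 2], ['Normal', 'Glaucoma'])
```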

Returns:

SelectionFunction: A SelectionFunction that performs stratified splitting.

Raises:

ValueError: If missing="error" and the stratify column contains null values.

Example

>>> selector = stratified_split("Image_Class.Name")  # doctest: +SKIP
>>> partitions = selector(df, {"Training": 400, "Testing": 100}, seed=42)  # doctest: +SKIP

>>> # Drop rows with missing labels
>>> selector = stratified_split("Image_Class.Name", missing="drop")  # doctest: +SKIP
>>> partitions = selector(df, {"Training": 300, "Testing": 100}, seed=42)  # doctest: +SKIP
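A selection function such as the one returned here receives the denormalized rows, a mapping from partition name to requested size, and a seed, and returns row indices per partition. The sketch below uses a plain list in place of the DataFrame. shuffle_selector and check_partitions are hypothetical helpers, and the rules the checker enforces (names match, indices in range, no index used twice, sizes no larger than requested) are assumptions about what a well-formed result looks like, not checks deriva-ml is documented to perform.

```python
import random
from typing import Mapping, Sequence


def shuffle_selector(rows: Sequence, partition_sizes: Mapping[str, int], seed: int) -> dict[str, list[int]]:
    """Toy selector with the (rows, partition_sizes, seed) shape; rows stands in for a DataFrame."""
    positions = list(range(len(rows)))
    random.Random(seed).shuffle(positions)
    out, offset = {}, 0
    for name, size in partition_sizes.items():
        out[name] = positions[offset : offset + size]
        offset += size
    return out


def check_partitions(partitions, partition_sizes, n_rows, exact=True):
    """Raise ValueError if a selector's output is not a valid partitioning."""
    if set(partitions) != set(partition_sizes):
        raise ValueError(f"partition names {sorted(partitions)} != {sorted(partition_sizes)}")
    seen: set[int] = set()
    for name, idx in partitions.items():
        idx = [int(i) for i in idx]
        if any(i < 0 or i >= n_rows for i in idx):
            raise ValueError(f"{name}: index out of range for {n_rows} rows")
        if len(set(idx)) != len(idx):
            raise ValueError(f"{name}: repeated indices")
        if seen.intersection(idx):
            raise ValueError(f"{name}: overlaps an earlier partition")
        seen.update(idx)
        wanted = partition_sizes[name]
        # exact=False tolerates selectors (e.g. balanced ones) that may return fewer rows.
        if len(idx) > wanted or (exact and len(idx) != wanted):
            raise ValueError(f"{name}: got {len(idx)} rows, requested {wanted}")


sizes = {"Training": 7, "Testing": 3}
parts = shuffle_selector(list(range(10)), sizes, seed=42)
check_partitions(parts, sizes, n_rows=10)  # passes silently
```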

Source code in src/deriva_ml/dataset/split.py
def stratified_split(
    stratify_column: str,
    missing: str = "error",
) -> SelectionFunction:
    """Create a stratified selection function.

    Returns a selection function that maintains the class distribution
    of the specified column across all partitions. Delegates to
    scikit-learn's ``train_test_split`` for the actual stratification.

    For two-way splits, performs a single stratified split. For three-way
    splits (Training/Validation/Testing), first separates the test set,
    then splits the remainder into training and validation.

    Args:
        stratify_column: Column name in the denormalized DataFrame to
            stratify by, in dot notation (e.g., ``Image_Class.Name``).
        missing: Policy for handling null/NaN values in the stratify column.
            - ``"error"`` (default): Raise ``ValueError`` if any values
              are missing. Reports the count and percentage of nulls.
            - ``"drop"``: Silently exclude rows with missing values from
              the split. Only rows with valid stratify values are assigned
              to partitions.
            - ``"include"``: Treat null/NaN as a distinct class label
              (``"__missing__"``). Missing-value rows are distributed
              across partitions proportionally like any other class.

    Returns:
        A ``SelectionFunction`` that performs stratified splitting.

    Raises:
        ValueError: If ``missing="error"`` and the stratify column
            contains null values.

    Example:
        >>> selector = stratified_split("Image_Class.Name")  # doctest: +SKIP
        >>> partitions = selector(df, {"Training": 400, "Testing": 100}, seed=42)  # doctest: +SKIP

        >>> # Drop rows with missing labels
        >>> selector = stratified_split("Image_Class.Name", missing="drop")  # doctest: +SKIP
        >>> partitions = selector(df, {"Training": 300, "Testing": 100}, seed=42)  # doctest: +SKIP
    """
    if missing not in ("error", "drop", "include"):
        raise ValueError(f"missing must be 'error', 'drop', or 'include', got '{missing}'")

    def _stratified_split(
        df: pd.DataFrame,
        partition_sizes: dict[str, int],
        seed: int,
    ) -> dict[str, np.ndarray]:
        from sklearn.model_selection import train_test_split as sklearn_split

        total_needed = sum(partition_sizes.values())

        if stratify_column not in df.columns:
            available = [c for c in df.columns if not c.startswith("_")]
            raise ValueError(
                f"Column '{stratify_column}' not found in denormalized DataFrame. Available columns: {available}"
            )

        # Handle missing values in the stratify column
        null_mask = df[stratify_column].isna()
        null_count = null_mask.sum()

        if null_count > 0:
            null_pct = null_count / len(df) * 100
            if missing == "error":
                raise ValueError(
                    f"Column '{stratify_column}' has {null_count} missing values "
                    f"({null_pct:.1f}% of {len(df)} rows). "
                    f"Use stratify_missing='drop' to exclude these rows, "
                    f"or 'include' to treat nulls as a separate class."
                )
            elif missing == "drop":
                logger.info(f"Dropping {null_count} rows ({null_pct:.1f}%) with missing values in '{stratify_column}'")
                df = df[~null_mask].reset_index(drop=True)
            elif missing == "include":
                logger.info(
                    f"Treating {null_count} missing values ({null_pct:.1f}%) in "
                    f"'{stratify_column}' as class '__missing__'"
                )
                df = df.copy()
                df[stratify_column] = df[stratify_column].fillna("__missing__")

        if total_needed > len(df):
            raise ValueError(
                f"Requested {total_needed} samples but dataset has {len(df)} records"
                + (
                    f" (after dropping {null_count} rows with missing values)"
                    if null_count > 0 and missing == "drop"
                    else ""
                )
            )

        indices = np.arange(len(df))

        # If we need a subset of the data, first do a stratified sample
        if total_needed < len(df):
            _, subset_indices = sklearn_split(
                indices,
                test_size=total_needed,
                stratify=df[stratify_column].values,
                random_state=seed,
            )
            sub_df = df.iloc[subset_indices]
        else:
            subset_indices = indices
            sub_df = df

        # Partition names in the order we'll peel them off
        partition_names = list(partition_sizes.keys())

        if len(partition_names) == 2:
            # Two-way split: single stratified split
            test_name = partition_names[1]
            train_name = partition_names[0]
            test_fraction = partition_sizes[test_name] / total_needed
            train_idx, test_idx = sklearn_split(
                np.arange(len(sub_df)),
                test_size=test_fraction,
                stratify=sub_df[stratify_column].values,
                random_state=seed,
            )
            return {
                train_name: subset_indices[train_idx],
                test_name: subset_indices[test_idx],
            }
        else:
            # Three-way split: peel off Testing first, then split remainder
            # into Training and Validation.
            test_size = partition_sizes["Testing"]
            test_fraction = test_size / total_needed
            remainder_idx, test_idx = sklearn_split(
                np.arange(len(sub_df)),
                test_size=test_fraction,
                stratify=sub_df[stratify_column].values,
                random_state=seed,
            )

            remainder_df = sub_df.iloc[remainder_idx]
            remainder_total = partition_sizes["Training"] + partition_sizes["Validation"]
            val_fraction = partition_sizes["Validation"] / remainder_total
            train_idx, val_idx = sklearn_split(
                np.arange(len(remainder_df)),
                test_size=val_fraction,
                stratify=remainder_df[stratify_column].values,
                random_state=seed,
            )

            return {
                "Training": subset_indices[remainder_idx[train_idx]],
                "Validation": subset_indices[remainder_idx[val_idx]],
                "Testing": subset_indices[test_idx],
            }

    return _stratified_split

subsample

subsample(
    ml: "DerivaML",
    source_dataset_rid: str,
    execution: "Execution",
    *,
    size: int | float,
    seed: int = 42,
    stratify_by_column: str | None = None,
    stratify_missing: Literal["error", "drop", "include"] = "error",
    element_table: str | None = None,
    include_tables: list[str] | None = None,
    via: list[str] | None = None,
    row_per: str | None = None,
    ignore_unrelated_anchors: bool = False,
    partition_by: Literal["element", "row"] | None = None,
    dataset_types: list[str] | None = None,
    description: str | None = None,
    dry_run: bool = False,
) -> SubsampleResult

Create a stratified subsample of source_dataset_rid.

Returns one new dataset whose member set is a random subset of the source's members (stratified when stratify_by_column is set). The source relationship is recorded as execution provenance only: the source is an input of execution, and the subsample is an output. No Dataset_Dataset edge is created between source and subsample (mirroring split_dataset's design call; see the "Datasets — types and partitions" subsection of CONTEXT.md for the canonical framing).

Mirrors sklearn's resample(stratify=y, replace=False, n_samples=N) semantics: stratified sample without replacement.
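Stratified sampling without replacement gives each class a share of the N draws proportional to its size, then draws that many distinct members from each class. A stdlib sketch of the idea follows. stratified_sample is a hypothetical helper, and its largest-remainder rounding is one reasonable way to make the per-class counts sum to exactly N, not necessarily how sklearn or deriva-ml rounds.

```python
import math
import random
from collections import Counter, defaultdict


def stratified_sample(labels, n, seed=42):
    """Draw n row positions without replacement, keeping class proportions."""
    if not 0 < n < len(labels):
        raise ValueError(f"n must be between 1 and {len(labels) - 1}, got {n}")
    by_class = defaultdict(list)
    for pos, label in enumerate(labels):
        by_class[label].append(pos)
    quotas = {c: n * len(p) / len(labels) for c, p in by_class.items()}
    alloc = {c: math.floor(q) for c, q in quotas.items()}
    # Largest-remainder rounding so the per-class counts sum to exactly n.
    leftover = n - sum(alloc.values())
    for c in sorted(quotas, key=lambda c: quotas[c] - alloc[c], reverse=True)[:leftover]:
        alloc[c] += 1
    rng = random.Random(seed)
    picked = []
    for c in sorted(by_class):
        picked.extend(rng.sample(by_class[c], alloc[c]))  # no position drawn twice
    return sorted(picked)


labels = ["Normal"] * 600 + ["Glaucoma"] * 300 + ["Suspect"] * 100
picked = stratified_sample(labels, 40, seed=7)
print(Counter(labels[i] for i in picked))
# Counter({'Normal': 24, 'Glaucoma': 12, 'Suspect': 4})
```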

See split_dataset for the meaning of stratify_by_column, element_table, include_tables, via, row_per, and partition_by; they pass through to the same denormalization machinery.
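The difference between partition_by="element" and partition_by="row" shows up when an element has several denormalized rows, for example one per annotation. In "element" mode the rows are collapsed to one per element RID, which fails if they disagree; in "row" mode each row is a unit of its own, so one RID can land in more than one partition. dedupe_to_elements is a hypothetical helper illustrating the "element" behavior described in split_dataset's docstring, and the RIDs are made up.

```python
from collections import defaultdict


def dedupe_to_elements(rows, element_key, value_key):
    """Collapse denormalized rows to one per element RID; raise if they disagree."""
    values = defaultdict(set)
    order = []  # first-seen order of element RIDs
    for row in rows:
        rid = row[element_key]
        if rid not in values:
            order.append(rid)
        values[rid].add(row[value_key])
    conflicts = {rid: sorted(v) for rid, v in values.items() if len(v) > 1}
    if conflicts:
        raise ValueError(f"Elements with conflicting {value_key} values: {conflicts}")
    return [{element_key: rid, value_key: next(iter(values[rid]))} for rid in order]


rows = [
    {"Image": "1-A", "Image_Class": "Normal"},
    {"Image": "1-A", "Image_Class": "Normal"},  # a second annotation that agrees
    {"Image": "1-B", "Image_Class": "Glaucoma"},
]
print(dedupe_to_elements(rows, "Image", "Image_Class"))
# [{'Image': '1-A', 'Image_Class': 'Normal'}, {'Image': '1-B', 'Image_Class': 'Glaucoma'}]
```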

Role types do not inherit from the source and do not propagate to the subsample. The subsample's role-axis types — Training, Testing, Validation — come exclusively from the caller's dataset_types argument. The Subsample origin-axis tag is always applied automatically (deduplicated defensively if the caller also passes it).
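The type assembly is an order-preserving deduplication that starts from the Subsample tag and never consults the source dataset's types. Restated as a standalone function (subsample_types is a hypothetical name; the loop matches the one in the source listing below):

```python
def subsample_types(dataset_types=None):
    """Subsample first, then the caller's types in order, without duplicates."""
    types = ["Subsample"]
    for t in dataset_types or []:
        if t not in types:
            types.append(t)
    return types


print(subsample_types(["Training", "Subsample", "Labeled", "Training"]))
# ['Subsample', 'Training', 'Labeled']
```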

Parameters:

ml (DerivaML, required): Connected DerivaML instance.

source_dataset_rid (str, required): The dataset to sample from.

execution (Execution, required): The caller's open Execution; the subsample is attributed to it for provenance, and the source is recorded as an input of this execution via Execution.add_input_dataset.

size (int | float, required): If a float in (0, 1), the fraction of the source to sample. If an int, the absolute sample count. Mirrors the shape of sklearn train_test_split's test_size.

seed (int, default 42): Random seed for reproducibility.

stratify_by_column (str | None, default None): Optional column for stratified sampling (preserves class proportions). When None, the subsample is a uniform random sample. Requires include_tables.

stratify_missing (Literal['error', 'drop', 'include'], default 'error'): How to handle nulls in the stratify column ("error" / "drop" / "include"). Same semantics as split_dataset.

element_table (str | None, default None): Element table to sample. When None, auto-detected from the source dataset's members.

include_tables (list[str] | None, default None): Tables to include when denormalizing for the stratify column. Required when stratify_by_column is set.

via (list[str] | None, default None): Tables forced into the join chain without contributing columns. Pass-through to the denormalizer.

row_per (str | None, default None): Explicit leaf table for denormalization. When row_per != element_table, partition_by must be set explicitly.

ignore_unrelated_anchors (bool, default False): Pass-through to the denormalizer.

partition_by (Literal['element', 'row'] | None, default None): Explicit partition unit. "element" (the default when unambiguous) dedupes to one row per element_table RID before sampling; "row" samples denormalized rows directly. See split_dataset's discussion for the trade-offs.

dataset_types (list[str] | None, default None): Caller-supplied additional dataset types (typically content-axis types like "Labeled" or role-axis types like "Training"). "Subsample" is always included (as the first type) and is deduplicated defensively if the caller also passes it.

description (str | None, default None): Description for the output dataset. When None, an auto-description is generated.

dry_run (bool, default False): If True, return the planned outputs without mutating the catalog.
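How size might turn into a member count is sketched below. resolve_sample_size is hypothetical. The validity bounds (size <= 0 or >= total is rejected) come from the Raises section, while rounding a fraction up with math.ceil is an assumption modeled on how scikit-learn treats a float test_size.

```python
import math


def resolve_sample_size(size, total):
    """Turn subsample's size argument into a member count (hypothetical helper)."""
    if isinstance(size, bool):
        raise TypeError("size must be an int or a float, not a bool")
    if isinstance(size, float):
        if not 0.0 < size < 1.0:
            raise ValueError(f"a fractional size must be in (0, 1), got {size}")
        count = math.ceil(size * total)  # assumption: round up, like sklearn's float test_size
    elif isinstance(size, int):
        count = size
    else:
        raise TypeError(f"size must be an int or a float, got {type(size).__name__}")
    # Documented bounds: size <= 0 or >= total raises ValueError.
    if count <= 0 or count >= total:
        raise ValueError(f"size resolves to {count}; it must be between 1 and {total - 1}")
    return count


print(resolve_sample_size(0.25, 1000), resolve_sample_size(400, 1000))
# 250 400
```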

Returns:

SubsampleResult: A SubsampleResult carrying the new dataset's RID, version, and member count (or "(dry run)" placeholders when dry_run=True).

Raises:

ValueError: Argument-shape errors (size <= 0 or >= total, stratify_by_column without include_tables, ambiguous partition_by, etc.).

Example

Take 400 stratified samples from a Training dataset::

with ml.create_execution(cfg) as exe:
    small = subsample(
        ml, training_rid, exe,
        size=400,
        stratify_by_column="Image_Class.Name",
        element_table="Image",
        include_tables=["Image", "Image_Class"],
        dataset_types=["Training", "Labeled"],
    )
exe.commit_output_assets()

>>> from deriva_ml.dataset.split import subsample  # doctest: +SKIP

Source code in src/deriva_ml/dataset/split.py
def subsample(
    ml: "DerivaML",
    source_dataset_rid: str,
    execution: "Execution",
    *,
    size: int | float,
    seed: int = 42,
    stratify_by_column: str | None = None,
    stratify_missing: Literal["error", "drop", "include"] = "error",
    element_table: str | None = None,
    include_tables: list[str] | None = None,
    via: list[str] | None = None,
    row_per: str | None = None,
    ignore_unrelated_anchors: bool = False,
    partition_by: Literal["element", "row"] | None = None,
    dataset_types: list[str] | None = None,
    description: str | None = None,
    dry_run: bool = False,
) -> SubsampleResult:
    """Create a stratified subsample of ``source_dataset_rid``.

    Returns one new dataset whose member set is a random subset of
    the source's members (stratified when ``stratify_by_column`` is
    set). The source relationship is
    recorded as **execution provenance only** — the source is an
    input of ``execution``; the subsample is an output. No
    ``Dataset_Dataset`` edge is created between source and subsample
    (mirroring ``split_dataset``'s design call; see CONTEXT.md's
    ``Datasets — types and partitions`` subsection for the canonical
    framing).

    Mirrors sklearn's ``resample(stratify=y, replace=False,
    n_samples=N)`` semantics: stratified sample without replacement.

    See :func:`split_dataset` for the meaning of
    ``stratify_by_column``, ``element_table``, ``include_tables``,
    ``via``, ``row_per``, and ``partition_by`` — they pass through
    to the same denormalization machinery.

    Role types do not inherit from the source and do not propagate to
    the subsample. The subsample's role-axis types — Training,
    Testing, Validation — come exclusively from the caller's
    ``dataset_types`` argument. The ``Subsample`` origin-axis tag is
    always applied automatically (deduplicated defensively if the
    caller also passes it).

    Args:
        ml: Connected :class:`DerivaML` instance.
        source_dataset_rid: The dataset to sample from.
        execution: The caller's open :class:`Execution`; the
            subsample is attributed to it for provenance, and the
            source is recorded as an input of this execution via
            :meth:`Execution.add_input_dataset`.
        size: If float in ``(0, 1)``, fraction of source to sample.
            If int, absolute sample count. Mirrors sklearn
            ``train_test_split``'s shape for ``test_size``.
        seed: Random seed for reproducibility. Default: 42.
        stratify_by_column: Optional column for stratified sampling
            (preserves class proportions). When ``None``, the
            subsample is a uniform random sample. Requires
            ``include_tables``.
        stratify_missing: How to handle nulls in the stratify column
            (``"error"`` / ``"drop"`` / ``"include"``). Same
            semantics as :func:`split_dataset`.
        element_table: Element table to sample. When ``None``,
            auto-detected from the source dataset's members.
        include_tables: Tables to include when denormalizing for
            the stratify column. Required when
            ``stratify_by_column`` is set.
        via: Tables forced into the join chain without contributing
            columns. Pass-through to the denormalizer.
        row_per: Explicit leaf table for denormalization. When
            ``row_per != element_table``, ``partition_by`` must be
            set explicitly.
        ignore_unrelated_anchors: Pass-through to the denormalizer.
        partition_by: Explicit partition unit. ``"element"`` (the
            default when unambiguous) dedupes to one row per
            element_table RID before sampling; ``"row"`` samples
            denormalized rows directly. See :func:`split_dataset`'s
            discussion for the trade-offs.
        dataset_types: Caller-supplied additional dataset types
            (typically content-axis types like ``"Labeled"`` or
            role-axis types like ``"Training"``). ``"Subsample"`` is
            always included (as the first type) and is deduplicated
            defensively if the caller also passes it.
        description: Description for the output dataset. When
            ``None``, an auto-description is generated.
        dry_run: If ``True``, return the planned outputs without
            mutating the catalog.

    Returns:
        :class:`SubsampleResult` carrying the new dataset's RID,
        version, and member count (or ``"(dry run)"`` placeholders
        when ``dry_run=True``).

    Raises:
        ValueError: Argument-shape errors (``size`` <= 0 or
            >= total, ``stratify_by_column`` without
            ``include_tables``, ambiguous ``partition_by``, etc.).

    Example:
        Take 400 stratified samples from a Training dataset::

            with ml.create_execution(cfg) as exe:
                small = subsample(
                    ml, training_rid, exe,
                    size=400,
                    stratify_by_column="Image_Class.Name",
                    element_table="Image",
                    include_tables=["Image", "Image_Class"],
                    dataset_types=["Training", "Labeled"],
                )
            exe.commit_output_assets()

        >>> from deriva_ml.dataset.split import subsample  # doctest: +SKIP
    """
    effective_partition_by = _validate_subsample_inputs(
        size=size,
        stratify_by_column=stratify_by_column,
        include_tables=include_tables,
        row_per=row_per,
        element_table=element_table,
        partition_by=partition_by,
    )

    logger.info(f"Looking up source dataset: {source_dataset_rid}")
    source_ds = ml.lookup_dataset(source_dataset_rid)

    sample_rids, sample_size, strategy_desc, resolved_element_table = _compute_subsample(
        source_ds=source_ds,
        source_dataset_rid=source_dataset_rid,
        element_table=element_table,
        size=size,
        seed=seed,
        stratify_by_column=stratify_by_column,
        stratify_missing=stratify_missing,
        include_tables=include_tables,
        row_per=row_per,
        via=via,
        ignore_unrelated_anchors=ignore_unrelated_anchors,
        partition_by=effective_partition_by,
    )

    # Dry-run early return — mirrors ``split_dataset``'s dry-run
    # contract (no catalog mutations, no source-input edge, RIDs and
    # versions are ``"(dry run)"`` placeholders).
    if dry_run:
        return SubsampleResult(
            source=source_dataset_rid,
            subsample=PartitionInfo(rid="(dry run)", version="(dry run)", count=sample_size),
            strategy=strategy_desc,
            element_table=resolved_element_table,
            seed=seed,
            dry_run=True,
        )

    _ensure_dataset_types(ml)

    # ``Subsample`` origin tag is always applied. Defensively dedupe
    # in case the caller passes it themselves (spec §7 R4).
    types_in_order: list[str] = ["Subsample"]
    for t in dataset_types or []:
        if t not in types_in_order:
            types_in_order.append(t)
    subsample_types = types_in_order

    auto_description = f"Subsample of dataset {source_dataset_rid} ({strategy_desc}, n={sample_size}, seed={seed})"

    # Persist the subsample parameters so the operation is replayable.
    # Mirrors ``split_dataset``'s ``split_config.json`` artifact —
    # written into ``execution.working_dir`` and picked up by the
    # caller's eventual ``commit_output_assets``.
    sub_params: dict[str, Any] = {
        "source_dataset_rid": source_dataset_rid,
        "size": size,
        "sample_size": sample_size,
        "seed": seed,
        "stratify_by_column": stratify_by_column,
        "stratify_missing": stratify_missing,
        "element_table": resolved_element_table,
        "include_tables": include_tables,
        "row_per": row_per,
        "via": via,
        "ignore_unrelated_anchors": ignore_unrelated_anchors,
        "partition_by": effective_partition_by,
        "dataset_types": subsample_types,
        "strategy": strategy_desc,
    }
    params_file = Path(execution.working_dir) / "subsample_config.json"
    params_file.write_text(json.dumps(sub_params, indent=2))
    logger.info("  Saved subsample parameters to %s", params_file)

    # Create the output dataset.
    subsample_ds = execution.create_dataset(
        description=description or auto_description,
        dataset_types=subsample_types,
    )
    logger.info("  Created Subsample dataset: %s", subsample_ds.dataset_rid)

    # Record the source as an execution input — same provenance shape
    # as ``split_dataset``. No ``Dataset_Dataset`` edge between source
    # and subsample (the source is consumed, not nested).
    execution.add_input_dataset(source_dataset_rid)
    logger.info("  Recorded source dataset %s as execution input", source_dataset_rid)

    # Add members to the subsample (batched, same shape as
    # ``_create_split_hierarchy``).
    batch_size = 500
    logger.info("  Adding %d members to Subsample dataset...", len(sample_rids))
    for i in range(0, len(sample_rids), batch_size):
        batch = sample_rids[i : i + batch_size]
        subsample_ds.add_dataset_members({resolved_element_table: batch}, validate=False)
        added = min(i + batch_size, len(sample_rids))
        if added % 2000 == 0 or added >= len(sample_rids):
            logger.info("    Added %d/%d", added, len(sample_rids))
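    # Progress arithmetic (illustrative count): with batch_size = 500 and
    # 4,300 RIDs, the batches end at 500, 1000, ..., 4000 and then 4300, so
    # progress is logged at 2000, 4000 and 4300 (every fourth batch plus the
    # final partial one). The ``added % 2000 == 0`` milestone fires every
    # fourth batch only because 2000 is a multiple of batch_size; if
    # batch_size changes to a value that does not divide 2000, intermediate
    # progress lines become rarer (only at common multiples of both).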

    subsample_ds_info = ml.lookup_dataset(subsample_ds.dataset_rid)
    return SubsampleResult(
        source=source_dataset_rid,
        subsample=PartitionInfo(
            rid=subsample_ds.dataset_rid,
            version=str(subsample_ds_info.current_version),
            count=sample_size,
        ),
        strategy=strategy_desc,
        element_table=resolved_element_table,
        seed=seed,
    )