diff --git a/fgpyo/vcf/__init__.py b/fgpyo/vcf/__init__.py new file mode 100644 index 00000000..4b350974 --- /dev/null +++ b/fgpyo/vcf/__init__.py @@ -0,0 +1,118 @@ +""" +Classes for generating VCF and records for testing +-------------------------------------------------- + +This module contains utility classes for the generation of VCF files and variant records, for use +in testing. + +The module contains the following public classes: + + - :class:`~VariantBuilder` -- A builder class that allows the + accumulation of variant records and access as a list and writing to file. + +Examples +~~~~~~~~ + +Typically, we have :class:`~pysam.VariantRecord` records obtained from reading from a VCF file. +The :class:`~VariantBuilder` class builds such records. + +Variants are added with the :func:`~VariantBuilder.add()` method, which +returns a `Variant`. + + >>> import pysam + >>> from fgpyo.vcf.builder import VariantBuilder + >>> builder: VariantBuilder = VariantBuilder() + >>> new_record_1: pysam.VariantRecord = builder.add() # uses the defaults + >>> new_record_2: pysam.VariantRecord = builder.add( + >>> contig="chr2", pos=1001, id="rs1234", ref="C", alts=["T"], + >>> qual=40, filter=["PASS"] + >>> ) + +VariantBuilder can create sites-only, single-sample, or multi-sample VCF files. If not producing a +sites-only VCF file, VariantBuilder must be created by passing a list of sample IDs + + >>> builder: VariantBuilder = VariantBuilder(sample_ids=["sample1", "sample2"]) + >>> new_record_1: pysam.VariantRecord = builder.add() # uses the defaults + >>> new_record_2: pysam.VariantRecord = builder.add( + >>> samples={"sample1": {"GT": "0|1"}, "sample2": {"GT": "0|0"}} + >>> ) + +The variants stored in the builder can be retrieved as a coordinate sorted VCF file via the +:func:`~VariantBuilder.to_path()` method: + + >>> from pathlib import Path + >>> path_to_vcf: Path = builder.to_path() + +The variants may also be retrieved in the order they were added via the +:func:`~VariantBuilder.to_unsorted_list()` method and in coordinate sorted +order via the :func:`~VariantBuilder.to_sorted_list()` method. + +""" +import os +import sys +from contextlib import contextmanager +from pathlib import Path +from typing import Generator +from typing import TextIO +from typing import Union + +from pysam import VariantFile +from pysam import VariantFile as VcfReader +from pysam import VariantFile as VcfWriter +from pysam import VariantHeader + +"""The valid base classes for opening a VCF file.""" +VcfPath = Union[Path, str, TextIO] + + +@contextmanager +def redirect_dev_null(file_num: int = sys.stderr.fileno()) -> Generator[None, None, None]: + """A context manager that redirects output of file handle to /dev/null + + Args: + file_num: number of filehandle to redirect. Uses stderr by default + """ + # open /dev/null for writing + f_devnull = os.open(os.devnull, os.O_RDWR) + # save old file descriptor and redirect stderr to /dev/null + save_stderr = os.dup(file_num) + os.dup2(f_devnull, file_num) + + yield + + # restore file descriptor and close devnull + os.dup2(save_stderr, file_num) + os.close(f_devnull) + + +@contextmanager +def reader(path: VcfPath) -> Generator[VcfReader, None, None]: + """Opens the given path for VCF reading + + Args: + path: the path to a VCF, or an open file handle + """ + with redirect_dev_null(): + # to avoid spamming log about index older than vcf, redirect stderr to /dev/null: only + # when first opening the file + _reader = VariantFile(path, mode="r") # type: ignore + # now stderr is back, so any later stderr messages will go through + yield _reader + _reader.close() + + +@contextmanager +def writer(path: VcfPath, header: VariantHeader) -> Generator[VcfWriter, None, None]: + """Opens the given path for VCF writing. + Args: + path: the path to a VCF, or an open filehandle + header: the source for the output VCF header. If you are modifying a VCF file that you are + reading from, you can pass reader.header + """ + # Convert Path to str such that pysam will autodetect to write as a gzipped file if provided + # with a .vcf.gz suffix. + if isinstance(path, Path): + path = str(path) + _writer = VariantFile(path, header=header, mode="w") + yield _writer + _writer.close() diff --git a/fgpyo/vcf/builder.py b/fgpyo/vcf/builder.py new file mode 100644 index 00000000..a1307096 --- /dev/null +++ b/fgpyo/vcf/builder.py @@ -0,0 +1,382 @@ +""" +Classes for generating VCF and records for testing +-------------------------------------------------- +""" + +from enum import Enum +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Any +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + +from pysam import VariantHeader +from pysam import VariantRecord + +from fgpyo.sam.builder import SamBuilder +from fgpyo.vcf import writer as PysamWriter + + +class VcfFieldType(Enum): + """Codes for VCF field types""" + + INTEGER = "Integer" + FLOAT = "Float" + FLAG = "Flag" + CHARACTER = "Character" + STRING = "String" + + +class VcfFieldNumber(Enum): + """Special codes for VCF field numbers""" + + NUM_ALT_ALLELES = "A" + NUM_ALLELES = "R" + NUM_GENOTYPES = "G" + UNKNOWN = "." + + +MissingRep = Union[None, Tuple[None, ...]] + + +class VariantBuilder: + """ + Builder for constructing one or more variant records (pysam.VariantRecord) for a VCF. The VCF + can be sites-only, single-sample, or multi-sample. + + Provides the ability to manufacture variants from minimal arguments, while generating + any remaining attributes to ensure a valid variant. + + A builder is constructed with a handful of defaults including the sample name and sequence + dictionary. If the VCF will not be sites-only, the list of sample IDS ("sample_ids") must be + provided to the VariantBuilder constructor. + + Variants are then added using the :func:`~fgpyo.vcf.VariantBuilder.add` method. + Once accumulated the variants can be accessed in the order in which they were created through + the :func:`~fgpyo.vcf.VariantBuilder.to_unsorted_list` function, or in a + list sorted by coordinate order via + :func:`~fgpyo.vcf.VariantBuilder.to_sorted_list`. Lastly, the records can + be written to a temporary file using :func:`~fgpyo.vcf.VariantBuilder.to_path`. + + Attributes: + sample_ids: the sample name(s) + sd: sequence dictionary, implemented as python dict from contig name to dictionary with + contig properties. At a minimum, each contig dict in sd must contain "ID" (the same as + contig_name) and "length", the contig length. Other values will be added to the VCF + header line for that contig. + seq_idx_lookup: dictionary mapping contig name to index of contig in sd + records: the list of variant records + header: the pysam header + """ + + sample_ids: List[str] + sd: Dict[str, Dict[str, Any]] + seq_idx_lookup: Dict[str, int] + records: List[VariantRecord] + header: VariantHeader + + def __init__( + self, + sample_ids: Optional[Iterable[str]] = None, + sd: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> None: + """Initializes a new VariantBuilder for generating variants and VCF files. + + Args: + sample_ids: the name of the sample(s) + sd: optional sequence dictionary + """ + self.sample_ids: List[str] = list(sample_ids) if sample_ids is not None else [] + self.sd: Dict[str, Dict[str, Any]] = sd if sd is not None else VariantBuilder.default_sd() + self.seq_idx_lookup: Dict[str, int] = {name: i for i, name in enumerate(self.sd.keys())} + self.records: List[VariantRecord] = [] + self.header = VariantHeader() + for line in VariantBuilder._build_header_string(sd=self.sd): + self.header.add_line(line) + if sample_ids is not None: + self.header.add_samples(sample_ids) + + @classmethod + def default_sd(cls) -> Dict[str, Dict[str, Any]]: + """Generates the sequence dictionary that is used by default by VariantBuilder. + Re-uses the dictionary from SamBuilder for consistency. + + Returns: + A new copy of the sequence dictionary as a map of contig name to dictionary, one per + contig. + """ + sd: Dict[str, Dict[str, Any]] = {} + for sequence in SamBuilder.default_sd(): + contig = sequence["SN"] + sd[contig] = {"ID": contig, "length": sequence["LN"]} + return sd + + @classmethod + def _build_header_string(cls, sd: Optional[Dict[str, Dict[str, Any]]] = None) -> Iterator[str]: + """Builds the VCF header with the given sample name(s) and sequence dictionary. + + Args: + sd: the sequence dictionary mapping the contig name to the key-value pairs for the + given contig. Must include "ID" and "length" for each contig. If no sequence + dictionary is given, will use the default dictionary. + """ + if sd is None: + sd = VariantBuilder.default_sd() + # add mandatory VCF format + yield "##fileformat=VCFv4.2" + # add GT + yield '##FORMAT=' + # add additional common INFO lines + yield '##INFO=' + yield ( + '##INFO=' + ) + yield '##INFO=' + # add additional common FORMAT lines + yield ( + '##FORMAT=' + ) + yield '##FORMAT=' + yield '##FORMAT=' + + for d in sd.values(): + if "ID" not in d or "length" not in d: + raise ValueError( + "Sequence dictionary must include 'ID' and 'length' for each contig." + ) + contig_id = d["ID"] + contig_length = d["length"] + contig_header = f"##contig= int: + return len(self.sample_ids) + + def add( + self, + contig: Optional[str] = None, + pos: int = 1000, + id: str = ".", + ref: str = "A", + alts: Union[None, str, Iterable[str]] = (".",), + qual: int = 60, + filter: Union[None, str, Iterable[str]] = None, + info: Optional[Dict[str, Any]] = None, + samples: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> VariantRecord: + """Generates a new variant and adds it to the internal collection. + + Notes: + * Very little validation is done with respect to INFO and FORMAT keys being defined in the + header. + * VCFs are 1-based, but pysam is (mostly) 0-based. We define the function in terms of the + VCF property "pos", which is 1-based. pysam will also report "pos" as 1-based, so that is + the property that should be accessed when using the records produced by this function (not + "start"). + + Args: + contig: the chromosome name. If None, will use the first contig in the sequence + dictionary. + pos: the 1-based position of the variant + id: the variant id + ref: the reference allele + alts: the list of alternate alleles, None if no alternates. If a single string is + passed, that will be used as the only alt. + qual: the variant quality + filter: the list of filters, None if no filters (ex. PASS). If a single string is + passed, that will be used as the only filter. + info: the dictionary of INFO key-value pairs + samples: the dictionary from sample name to FORMAT key-value pairs. + if a sample property is supplied for any sample but omitted in some, it will + be set to missing (".") for samples that don't have that property explicitly + assigned. If a sample in the VCF is omitted, all its properties will be set to + missing. + """ + if contig is None: + contig = next(iter(self.sd.keys())) + + if contig not in self.sd: + raise ValueError(f"Chromosome `{contig}` not in the sequence dictionary.") + # because there are a lot of slightly different objects related to samples or called + # "samples" in this function, we alias samples to sample_formats + # we still want to keep the API labeled "samples" because that keeps the naming scheme the + # same as the pysam API + sample_formats = samples + if sample_formats is not None: + unknown_samples = set(sample_formats.keys()).difference(self.sample_ids) + if len(unknown_samples) > 0: + raise ValueError("Unknown sample(s) given: " + ", ".join(unknown_samples)) + + if isinstance(alts, str): + alts = (alts,) + alleles = (ref,) if alts is None else (ref, *alts) + if isinstance(filter, str): + filter = (filter,) + + # pysam expects a list of format dicts provided in the same order as the samples in the + # header (self.sample_ids). (This is despite the fact that it will internally represent the + # values as a map from sample ID to format values, as we do in this function.) + # Convert to that form and rename to record_samples; to a) disambiguate from the input + # values, and b) prevent mypy from complaining about the type changing from dict to list. + if self.num_samples == 0: + # this is a sites-only VCF + record_samples = None + elif sample_formats is None or len(sample_formats) == 0: + # not a sites-only VCF, but no FORMAT values were passed. set FORMAT to missing (with + # no fields) + record_samples = None + else: + # convert to list form that pysam expects, in order pysam expects + # note: the copy {**format_dict} below is present because pysam actually alters the + # input values, which would be an unintended side-effect (in fact without this, tests + # fail because the expected input values are changed) + record_samples = [ + {**sample_formats.get(sample_id, {})} for sample_id in self.sample_ids + ] + + # pysam is zero based, half-open [start, stop) + start = pos - 1 # pysam "start" is zero-based + stop = start + len(ref) + variant = self.header.new_record( + contig=contig, + start=start, + stop=stop, + id=id, + alleles=alleles, + qual=qual, + filter=filter, + info=info, + samples=record_samples, + ) + + self.records.append(variant) + return variant + + def to_path(self, path: Optional[Path] = None) -> Path: + """Returns a path to a VCF for variants added to this builder. + Args: + path: optional path to the VCF + """ + # update the path + path = self._to_vcf_path(path) + + # Create a writer and write to it + with PysamWriter(path, header=self.header) as writer: + for variant in self.to_sorted_list(): + writer.write(variant) + + return path + + @staticmethod + def _to_vcf_path(path: Optional[Path]) -> Path: + """Gets the path to a VCF file. If path is a directory, a temporary VCF will be created in + that directory. If path is `None`, then a temporary VCF will be created. Otherwise, the + given path is simply returned. + + Args: + path: optionally the path to the VCF, or a directory to create a temporary VCF. + """ + if path is None: + with NamedTemporaryFile(suffix=".vcf", delete=False) as fp: + path = Path(fp.name) + assert path.is_file() + return path + + def to_unsorted_list(self) -> List[VariantRecord]: + """Returns the accumulated records in the order they were created.""" + return list(self.records) + + def to_sorted_list(self) -> List[VariantRecord]: + """Returns the accumulated records in coordinate order.""" + return sorted(self.records, key=self._sort_key) + + def _sort_key(self, variant: VariantRecord) -> Tuple[int, int, int]: + return self.seq_idx_lookup[variant.contig], variant.start, variant.stop + + def add_header_line(self, line: str) -> None: + """Adds a header line to the header""" + self.header.add_line(line) + + def add_info_header( + self, + name: str, + field_type: VcfFieldType, + number: Union[int, VcfFieldNumber] = 1, + description: Optional[str] = None, + source: Optional[str] = None, + version: Optional[str] = None, + ) -> None: + """Add an INFO header field to the VCF header. + + Args: + name: the name of the field + field_type: the field_type of the field + number: the number of the field + description: the description of the field + source: the source of the field + version: the version of the field + """ + if field_type == VcfFieldType.FLAG: + number = 0 # FLAGs always have number = 0 + header_line = f"##INFO= None: + """ + Add a FORMAT header field to the VCF header. + + Args: + name: the name of the field + field_type: the field_type of the field + number: the number of the field + description: the description of the field + """ + header_line = f"##FORMAT= None: + """ + Add a FILTER header field to the VCF header. + + Args: + name: the name of the field + description: the description of the field + """ + header_line = f"##FILTER= Path: + return tmp_path_factory.mktemp("test_vcf") + + +@pytest.fixture(scope="function") +def random_generator(seed: int = 42) -> random.Random: + return random.Random(seed) + + +@pytest.fixture(scope="function") +def sequence_dict() -> Dict[str, Dict[str, Any]]: + return VariantBuilder.default_sd() + + +def _get_random_contig( + random_generator: random.Random, sequence_dict: Dict[str, Dict[str, Any]] +) -> (str, int): + """Randomly select a contig from the sequence dictionary and return its name and length.""" + contig = random_generator.choice(list(sequence_dict.values())) + return contig["ID"], contig["length"] + + +_ALL_FILTERS = frozenset({"MAYBE", "FAIL", "SOMETHING"}) +_INFO_FIELD_TYPES = MappingProxyType( + { + "TEST_INT": VcfFieldType.INTEGER, + "TEST_STR": VcfFieldType.STRING, + "TEST_FLOAT": VcfFieldType.FLOAT, + } +) + + +def _get_random_variant_inputs( + random_generator: random.Random, + sequence_dict: Dict[str, Dict[str, Any]], +) -> Mapping[str, Any]: + """ + Randomly generate inputs that should produce a valid Variant. Don't include format fields. + """ + contig, contig_len = _get_random_contig(random_generator, sequence_dict) + variant_reference_len = random_generator.choice([0, 1, 5, 100]) + variant_read_len = random_generator.choice( + [1, 5, 100] if variant_reference_len == 0 else [0, 1, 5, 100] + ) + num_filters = random_generator.randint(0, 3) + filter = tuple(random_generator.sample(list(_ALL_FILTERS), k=num_filters)) + start = random_generator.randint(1, contig_len - variant_reference_len) + # stop is not directly passed by current API, but this is what its value would be: + # stop = start + variant_reference_len + ref = "".join(random_generator.choices("ATCG", k=variant_reference_len)) + alt = ref + while alt == ref: + alt = "".join(random_generator.choices("ATCG", k=variant_read_len)) + if variant_reference_len == 0 or variant_read_len == 0: + # represent ref/alt for insertions/deletions as starting with the last unaltered base. + random_start = random_generator.choices("ATCG")[0] + ref = random_start + ref + alt = random_start + alt + + info = { + key: ( + random_generator.randint(0, 100) + if value_type == VcfFieldType.INTEGER + else random_generator.uniform(0, 1) + if value_type == VcfFieldType.FLOAT + else random_generator.choice(["Up", "Down"]) + ) + for key, value_type in _INFO_FIELD_TYPES.items() + } + + return MappingProxyType( + { + "contig": contig, + "pos": start, + "ref": ref, + "alts": (alt,), + "filter": filter, + "info": info, + } + ) + + +@pytest.fixture(scope="function") +def zero_sample_record_inputs( + random_generator: random.Random, sequence_dict: Dict[str, Dict[str, Any]] +) -> Tuple[Mapping[str, Any]]: + """ + Fixture with inputs to create test Variant records for zero-sample VCFs (no genotypes). + Make them MappingProxyType so that they are immutable. + """ + return tuple(_get_random_variant_inputs(random_generator, sequence_dict) for _ in range(100)) + + +def _add_headers(variant_builder: VariantBuilder) -> None: + """Add needed headers to the VariantBuilder.""" + for filter in _ALL_FILTERS: + variant_builder.add_filter_header(filter) + for field_name, field_type in _INFO_FIELD_TYPES.items(): + variant_builder.add_info_header(field_name, field_type=field_type) + + +def _fix_value(value: Any) -> Any: + """Helper to convert pysam data types to basic python types for testing/comparison.""" + if isinstance(value, pysam.VariantRecord): + return { + "contig": value.contig, + "id": value.id, + "pos": value.pos, + "ref": value.ref, + "qual": value.qual, + "alts": _fix_value(value.alts), + "filter": _fix_value(value.filter), + "info": _fix_value(value.info), + "samples": _fix_value(value.samples), + } + elif isinstance(value, str): + # this has __iter__, so just get it out of the way early + return value + elif isinstance(value, float): + return round(value, 4) # only keep a few decimal places, VCF changes type, rounds, etc + elif isinstance(value, pysam.VariantRecordFilter): + return tuple(value.keys()) + elif hasattr(value, "items"): + return {_key: _fix_value(_value) for _key, _value in value.items()} + elif hasattr(value, "__iter__"): + return tuple(_fix_value(_value) for _value in value) + else: + return value + + +def _assert_equal(expected_value: Any, actual_value: Any) -> None: + """Helper to assert that two values are equal, handling pysam data types.""" + __tracebackhide__ = True + assert _fix_value(expected_value) == _fix_value(actual_value) + + +def test_minimal_inputs() -> None: + """Show that all inputs can be None and the builder will succeed.""" + variant_builder = VariantBuilder() + variant_builder.add() + variants = variant_builder.to_sorted_list() + assert len(variants) == 1 + assert isinstance(variants[0], pysam.VariantRecord) + assert variants[0].contig == "chr1" # 1st contig in the default sequence dictionary + + # now the same, but with a non-default sequence dictionary + non_standard_sequence_dict = {"contig1": {"ID": "contig1", "length": 10000}} + variant_builder = VariantBuilder(sd=non_standard_sequence_dict) + variant_builder.add() + variants = variant_builder.to_sorted_list() + assert len(variants) == 1 + assert isinstance(variants[0], pysam.VariantRecord) + assert variants[0].contig == "contig1" + + +def test_sort_order(random_generator: random.Random) -> None: + """Test if the VariantBuilder sorts the Variant records in the correct order.""" + sorted_inputs = [ + {"contig": "chr1", "pos": 100}, + {"contig": "chr1", "pos": 500}, + {"contig": "chr2", "pos": 1000}, + {"contig": "chr2", "pos": 10000}, + {"contig": "chr10", "pos": 10}, + {"contig": "chr10", "pos": 20}, + {"contig": "chr11", "pos": 5}, + ] + scrambled_inputs = random_generator.sample(sorted_inputs, k=len(sorted_inputs)) + assert scrambled_inputs != sorted_inputs # there should be something to actually sort + variant_builder = VariantBuilder() + for record_input in scrambled_inputs: + variant_builder.add(**record_input) + + for sorted_input, variant_record in zip(sorted_inputs, variant_builder.to_sorted_list()): + for key, value in sorted_input.items(): + _assert_equal(expected_value=value, actual_value=getattr(variant_record, key)) + + +def test_zero_sample_records_match_inputs( + zero_sample_record_inputs: Tuple[Mapping[str, Any]] +) -> None: + """Test if zero-sample VCF (no genotypes) records produced match the requested inputs.""" + variant_builder = VariantBuilder() + _add_headers(variant_builder) + for record_input in zero_sample_record_inputs: + variant_builder.add(**record_input) + + for record_input, variant_record in zip( + zero_sample_record_inputs, variant_builder.to_unsorted_list() + ): + for key, value in record_input.items(): + _assert_equal(expected_value=value, actual_value=getattr(variant_record, key)) + + +def _get_is_compressed(input_file: Path) -> bool: + """Returns True if the input file is gzip-compressed, False otherwise.""" + with gzip.open(f"{input_file}", "r") as f_in: + try: + f_in.read(1) + return True + except OSError: + return False + + +@pytest.mark.parametrize("compress", (True, False)) +def test_zero_sample_vcf_round_trip( + temp_path: Path, + zero_sample_record_inputs, + compress: bool, +) -> None: + """ + Test if zero-sample VCF (no genotypes) output records match the records read in from the + resulting VCF. + """ + vcf = temp_path / ("test.vcf.gz" if compress else "test.vcf") + variant_builder = VariantBuilder() + _add_headers(variant_builder) + for record_input in zero_sample_record_inputs: + variant_builder.add(**record_input) + + variant_builder.to_path(vcf) + + # this can fail if pysam.VariantFile is not invoked correctly with pathlib.Path objects + assert _get_is_compressed(vcf) == compress + + with vcf_reader(vcf) as reader: + for vcf_record, builder_record in zip(reader, variant_builder.to_sorted_list()): + _assert_equal(expected_value=builder_record, actual_value=vcf_record) + + +def _add_random_genotypes( + random_generator: random.Random, + record_input: Mapping[str, Any], + sample_ids: Iterable[str], +) -> Mapping[str, Any]: + """Add random genotypes to the record input.""" + genotypes = { + sample_id: { + "GT": random_generator.choice( + [ + (None,), + (0, 0), + (0, 1), + (1, 0), + (1, 1), + (None, 0), + (0, None), + (1, None), + ] + ) + } + for sample_id in sample_ids + } + return MappingProxyType({**record_input, "samples": genotypes}) + + +@pytest.mark.parametrize("num_samples", (1,)) +@pytest.mark.parametrize("add_genotypes_to_records", (True, False)) +def test_variant_sample_records_match_inputs( + random_generator: random.Random, + zero_sample_record_inputs: Tuple[Mapping[str, Any]], + num_samples: int, + add_genotypes_to_records: bool, +) -> None: + """ + Test if records with samples / genotypes match the requested inputs. + If add_genotypes is True, then add random genotypes to the record input, otherwise test that + the VariantBuilder will work even if genotypes are not supplied. + """ + sample_ids = [f"sample{i}" for i in range(num_samples)] + variant_builder = VariantBuilder(sample_ids=sample_ids) + _add_headers(variant_builder) + variant_sample_records = ( + tuple( + _add_random_genotypes( + random_generator=random_generator, record_input=record_input, sample_ids=sample_ids + ) + for record_input in zero_sample_record_inputs + ) + if add_genotypes_to_records + else zero_sample_record_inputs + ) + for record_input in variant_sample_records: + variant_builder.add(**record_input) + + for record_input, variant_record in zip( + variant_sample_records, variant_builder.to_unsorted_list() + ): + for key, input_value in record_input.items(): + _assert_equal(expected_value=input_value, actual_value=getattr(variant_record, key)) + + +@pytest.mark.parametrize("num_samples", (1, 5)) +@pytest.mark.parametrize("compress", (True, False)) +@pytest.mark.parametrize("add_genotypes_to_records", (True, False)) +def test_variant_sample_vcf_round_trip( + temp_path: Path, + random_generator: random.Random, + zero_sample_record_inputs: Tuple[Mapping[str, Any]], + num_samples: int, + compress: bool, + add_genotypes_to_records: bool, +) -> None: + """ + Test if 1 or multi-sample VCF output records match the records read in from the resulting VCF. + If add_genotypes is True, then add random genotypes to the record input, otherwise test that + the VariantBuilder will work even if genotypes are not supplied. + """ + sample_ids = [f"sample{i}" for i in range(num_samples)] + vcf = temp_path / ("test.vcf.gz" if compress else "test.vcf") + variant_builder = VariantBuilder(sample_ids=sample_ids) + _add_headers(variant_builder) + variant_sample_records = ( + tuple( + _add_random_genotypes( + random_generator=random_generator, record_input=record_input, sample_ids=sample_ids + ) + for record_input in zero_sample_record_inputs + ) + if add_genotypes_to_records + else zero_sample_record_inputs + ) + for record_input in variant_sample_records: + variant_builder.add(**record_input) + variant_builder.to_path(vcf) + + # this can fail if pysam.VariantFile is not invoked correctly with pathlib.Path objects + assert _get_is_compressed(vcf) == compress + + with vcf_reader(vcf) as reader: + for vcf_record, builder_record in zip(reader, variant_builder.to_sorted_list()): + _assert_equal(expected_value=builder_record, actual_value=vcf_record)