Source code for variant_extractor.VariantExtractor

# Copyright 2022 - Barcelona Supercomputing Center
# Author: Rodrigo Martin
# MIT License
from typing import List, Optional
import warnings
import pysam

from .private._utils import compare_contigs, permute_breakend_sv, convert_inv_to_breakend, convert_del_to_ins
from .private._parser import parse_breakend_sv, parse_shorthand_sv, parse_sgl_sv, parse_standard_record
from .private._PendingBreakends import PendingBreakends
from .variants import VariantType
from .variants import VariantRecord

# Schema of the DataFrames produced by this module: column name -> pandas dtype.
# The insertion order of this mapping is also the column order.
DATAFRAME_DTYPES = {
    'start_chrom': 'string',
    'start': 'uint64',
    'end_chrom': 'string',
    'end': 'uint64',
    'ref': 'string',
    'alt': 'string',
    'length': 'uint64',
    'brackets': 'string',
    'type_inferred': 'string',
    'variant_record_obj': object,
}
# Ordered column names, derived from the mapping so both always stay in sync.
DATAFRAME_COLUMNS = list(DATAFRAME_DTYPES)

def _use_lowest_type(series):
    """Cast a series to the narrowest unsigned integer dtype that holds its maximum.

    Parameters
    ----------
    series : pandas.Series
        Series of integers. Only the maximum is inspected, so the values are
        assumed to be non-negative (NOTE(review): confirm against callers).

    Returns
    -------
    pandas.Series
        The series cast to ``uint8``, ``uint16``, ``uint32`` or ``uint64``.
    """
    series_max = series.max()
    if series_max < 2 ** 8:
        series = series.astype('uint8')
    elif series_max < 2 ** 16:
        series = series.astype('uint16')
    elif series_max < 2 ** 32:
        series = series.astype('uint32')
    else:
        # Fallback for large values. An empty series also lands here, because
        # its maximum is NaN and every comparison above is False.
        series = series.astype('uint64')
    return series

class VariantExtractor:
    """
    Reads and extracts variants from VCF files. This class is designed to be
    used in a pipeline, where the variants are ingested from VCF files and then
    used in downstream analysis.

    Instances are iterable (yielding variant records) and can be used as a
    context manager, which calls :code:`close()` on exit.
    """

    @staticmethod
    def empty_dataframe():
        """Returns an empty pandas DataFrame with the columns used by this class.
        """
        import pandas as pd
        df = pd.DataFrame(columns=DATAFRAME_COLUMNS)
        df = df.astype(DATAFRAME_DTYPES)
        return df

    def __init__(self, vcf_file: str, pass_only=False, ensure_pairs=True, fasta_ref: Optional[str] = None):
        """
        Parameters
        ----------
        vcf_file : str
            A VCF formatted file. The file is automatically opened.
        pass_only : bool, optional
            If :code:`True`, only records with PASS filter will be considered.
        ensure_pairs : bool, optional
            If :code:`True`, throws an exception if a breakend is missing a pair when all other were paired successfully.
        fasta_ref : str, optional
            A FASTA file with the reference genome. Must be indexed.
        """
        self.__ensure_pairs = ensure_pairs
        self.__pass_only = pass_only
        self.__pairs_found = 0
        self.__pending_breakends = PendingBreakends()
        self.__fasta_ref = None
        # Open FASTA file
        if fasta_ref is not None:
            self.__fasta_ref = pysam.FastaFile(fasta_ref)
        # Open VCF file
        vcf_handle = open(file=vcf_file, mode='r')
        save = pysam.set_verbosity(0)
        try:
            self.__variant_file = pysam.VariantFile(vcf_handle)
        except Exception:
            # Do not leak the handles opened above when the VCF cannot be parsed
            vcf_handle.close()
            if self.__fasta_ref is not None:
                self.__fasta_ref.close()
            raise
        finally:
            # Always restore the verbosity level that was active before
            pysam.set_verbosity(save)
        # Kept so that close() can release the underlying Python file object
        self.__vcf_handle = vcf_handle

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self):
        """Closes the VCF file, its underlying file handle and the FASTA reference (if any).
        """
        self.__variant_file.close()
        self.__vcf_handle.close()
        if self.__fasta_ref is not None:
            self.__fasta_ref.close()

    def __iter__(self):
        """Yields the variant records extracted from the VCF file.

        Breakends that are still unpaired once the file is exhausted are
        yielded individually when :code:`ensure_pairs` is :code:`False` or when
        no pair was found at all. Otherwise an :code:`Exception` listing them
        is raised.
        """
        # Read the next record from the VCF file
        for rec in self.__variant_file:
            yield from self.__handle_record(rec)
        # Remove non-PASS records from the pending breakends if pass_only is True
        if self.__pass_only:
            # Iterate over a copy, since entries are removed while looping
            for vcf_record in list(self.__pending_breakends.values()):
                if 'PASS' not in vcf_record.filter:
                    self.__pending_breakends.remove(vcf_record)
        # Only single-paired records or not ensuring pairs
        if not self.__ensure_pairs or self.__pairs_found == 0:
            for vcf_record in self.__pending_breakends.values():
                yield from self.__handle_breakend_individual_sv(vcf_record)
        # Found unpaired records
        elif len(self.__pending_breakends) > 0:
            exception_text = ''.join(str(vcf_record) + '\n' for vcf_record in self.__pending_breakends.values())
            raise Exception(
                (f'There are {len(self.__pending_breakends)} unpaired SV breakends. '
                 'Please, check the entires shown below in VCF file. '
                 f'Use ensure_pairs=False to ignore unpaired SV breakends.\n{exception_text}'))

    def __handle_record(self, rec: pysam.VariantRecord) -> List[VariantRecord]:
        """Dispatches a raw VCF record to the handler matching its notation.

        Returns the list of variant records produced (possibly empty). Raises
        :code:`ValueError` if the record has ALT alleles but no REF.
        """
        if not rec.alts:
            return []
        if not rec.ref:
            raise ValueError('Record does not have a REF field')
        # Handle multiallelic records
        if len(rec.alts) != 1:
            return self.__handle_multiallelic_record(rec)
        # Check if breakend SV record. The PASS filter for breakends is applied
        # later, once both mates are known
        vcf_record = parse_breakend_sv(rec)
        if vcf_record:
            return self.__handle_breakend_sv(vcf_record)
        # Check PASS filter
        if self.__pass_only and 'PASS' not in rec.filter:
            return []
        # Check if shorthand SV record
        vcf_record = parse_shorthand_sv(rec)
        if vcf_record:
            return self.__handle_shorthand_sv(vcf_record)
        # Check if single breakend SV record
        vcf_record = parse_sgl_sv(rec)
        if vcf_record:
            return [vcf_record]
        # Check if standard record
        vcf_record = parse_standard_record(rec)
        if vcf_record:
            return self.__handle_standard_record(vcf_record)
        warnings.warn(f'Skipping unrecognized record:\n{rec}')
        return []

    def __handle_standard_record(self, vcf_record: VariantRecord) -> List[VariantRecord]:
        """Classifies a REF/ALT record as SNV(s), deletion or insertion.

        Equal-length REF/ALT pairs are atomized into one SNV per mismatching
        position. Otherwise the record is tagged as DEL or INS in place.
        """
        record_list = []
        ref_length = len(vcf_record.ref)
        alt_length = len(vcf_record.alt)
        if ref_length == alt_length:
            for i, (ref_base, alt_base) in enumerate(zip(vcf_record.ref, vcf_record.alt)):
                # Atomize SNVs
                if ref_base == alt_base:
                    continue
                # NOTE(review): the id expression was lost in the extracted
                # source; `vcf_record.id` is reconstructed from the `id=`
                # keyword passed to _replace() below - confirm.
                if ref_length > 1 and vcf_record.id is not None:
                    variant_id = f'{vcf_record.id}_{i}'
                else:
                    variant_id = vcf_record.id
                new_vcf_record = vcf_record._replace(
                    ref=ref_base, pos=i+vcf_record.pos, end=i+vcf_record.pos,
                    length=1, alt=alt_base, id=variant_id, variant_type=VariantType.SNV)
                record_list.append(new_vcf_record)
        elif ref_length > alt_length:
            # Deletion
            vcf_record.variant_type = VariantType.DEL
            record_list.append(vcf_record)
        else:
            # Insertion
            vcf_record.variant_type = VariantType.INS
            record_list.append(vcf_record)
        return record_list

    def __handle_breakend_sv(self, vcf_record: VariantRecord) -> List[VariantRecord]:
        """Pairs a breakend with its pending mate, or stores it until the mate shows up."""
        # Check for pending breakends
        previous_record = self.__pending_breakends.pop(vcf_record)
        if previous_record is None:
            self.__pending_breakends.push(vcf_record)
            return []
        # Mate breakend found, handle it
        self.__pairs_found += 1
        return self.__handle_braked_paired_sv(previous_record, vcf_record)

    def __handle_braked_paired_sv(self, vcf_record_1: VariantRecord, vcf_record_2: VariantRecord) -> List[VariantRecord]:
        """Merges the filters of two mated breakends and keeps the lowest one.

        "Lowest" is decided by contig comparison first and position second.
        """
        # Check PASS filter
        if self.__pass_only and ('PASS' not in vcf_record_1.filter or 'PASS' not in vcf_record_2.filter):
            return []
        # Unify filters
        filters = set(vcf_record_1.filter) | set(vcf_record_2.filter)
        filters.discard('PASS')
        if filters:
            # Sorted so that the resulting filter order is deterministic
            vcf_record_1.filter = sorted(filters)
            vcf_record_2.filter = sorted(filters)
        contig_comparison = compare_contigs(vcf_record_1.contig, vcf_record_2.contig)
        if contig_comparison == 0:
            first_is_lowest = vcf_record_1.pos < vcf_record_2.pos
        else:
            first_is_lowest = contig_comparison == -1
        return self.__handle_breakend_individual_sv(vcf_record_1 if first_is_lowest else vcf_record_2)

    def __handle_breakend_individual_sv(self, vcf_record: VariantRecord) -> List[VariantRecord]:
        """Normalizes a single breakend record and returns it wrapped in a list."""
        assert vcf_record.alt_sv_breakend is not None
        contig_comparison = compare_contigs(vcf_record.contig, vcf_record.alt_sv_breakend.contig)
        # Transform REF/ALT to equivalent notation so that REF contains the lowest contig and position
        if contig_comparison == 1 or (contig_comparison == 0 and vcf_record.pos > vcf_record.alt_sv_breakend.pos):
            vcf_record = permute_breakend_sv(vcf_record, self.__fasta_ref)
        # Handle DEL notated variant as INS
        if vcf_record.length == 1 and vcf_record.variant_type == VariantType.DEL \
                and vcf_record.alt_sv_breakend is not None:
            vcf_record = convert_del_to_ins(vcf_record)
        # NOTE(review): the extracted source lost its indentation, so it is
        # unclear whether this zero-length check belongs inside the conversion
        # branch above - confirm against the upstream file.
        if vcf_record.length == 0:
            return []
        return [vcf_record]

    def __handle_shorthand_sv(self, vcf_record: VariantRecord) -> List[VariantRecord]:
        """Returns a shorthand SV as is, except INV which is split into two breakend records."""
        if vcf_record.variant_type == VariantType.INV:
            # Transform INV into breakend notation
            vcf_record_1, vcf_record_2 = convert_inv_to_breakend(vcf_record, self.__fasta_ref)
            return [vcf_record_1, vcf_record_2]
        return [vcf_record]

    def __handle_multiallelic_record(self, rec: pysam.VariantRecord) -> List[VariantRecord]:
        """Splits a multiallelic record into one single-ALT record per allele.

        Each allele is re-processed through the regular record handling, and
        the per-sample FORMAT values are reduced to the ones for that allele.
        """
        record_list = []
        fake_rec = rec.copy()
        assert fake_rec.alts is not None and len(fake_rec.alts) > 1
        alts = fake_rec.alts
        # Snapshot the sample values before the record copy is overwritten below
        samples = {sample_name: dict(rec.samples[sample_name].items()) for sample_name in rec.samples}
        # NOTE(review): the id accesses (`rec.id` / `fake_rec.id`) were lost in
        # the extracted source and are reconstructed here - confirm.
        original_id = rec.id
        for i, alt in enumerate(alts):
            # WARNING: This overrides the record
            fake_rec.alts = (alt,)
            if original_id:
                fake_rec.id = f'{original_id}_{i}'
            new_records = self.__handle_record(fake_rec)
            new_samples = dict()
            for sample_name, sample_values in samples.items():
                new_samples[sample_name] = dict()
                for key, value in sample_values.items():
                    if not hasattr(value, '__iter__') or len(value) == self.__variant_file.header.formats[key].number:
                        # Scalar, or already the length declared in the header
                        new_samples[sample_name][key] = value
                    elif key == 'GT':
                        new_samples[sample_name][key] = (0, value[1])
                    elif len(value) == len(alts) + 1:
                        # One value for REF plus one per ALT
                        new_samples[sample_name][key] = (value[0], value[i + 1])
                    elif len(value) % len(alts) == 0:
                        # Values interleaved per ALT
                        new_samples[sample_name][key] = value[i::len(alts)]
                    else:
                        new_samples[sample_name][key] = value
            for new_record in new_records:
                new_record.samples = new_samples
            record_list.extend(new_records)
        return record_list

    def to_dataframe(self):
        """Reads every variant and returns them as a pandas DataFrame.

        The DataFrame has the same columns as :code:`empty_dataframe()`. The
        :code:`start`, :code:`end` and :code:`length` columns are downcast to
        the smallest unsigned integer type that fits their values.
        """
        import pandas as pd
        variants = []
        for variant_record in self:
            start_chrom = variant_record.contig.replace('chr', '')
            start = variant_record.pos
            ref = variant_record.ref
            alt = variant_record.alt
            length = variant_record.length
            end = variant_record.end
            if variant_record.alt_sv_breakend:
                end_chrom = variant_record.alt_sv_breakend.contig.replace('chr', '')
                if start_chrom != end_chrom:
                    end = variant_record.alt_sv_breakend.pos
            else:
                end_chrom = start_chrom

            # Inferred type
            # NOTE(review): this expression and the four VariantType members
            # compared below were lost in the extracted source. They are
            # reconstructed from context (DUP and TRA are not visible anywhere
            # else in this file) - confirm against the upstream file.
            type_inferred = variant_record.variant_type.name
            # breakends
            breakends = ''
            if type_inferred == VariantType.DEL.name:
                breakends = 'N['
            elif type_inferred == VariantType.DUP.name:
                breakends = ']N'
            elif type_inferred == VariantType.INV.name:
                assert variant_record.alt_sv_breakend is not None
                prefix = 'N' if variant_record.alt_sv_breakend.prefix else ''
                suffix = 'N' if variant_record.alt_sv_breakend.suffix else ''
                breakends = prefix + variant_record.alt_sv_breakend.bracket + suffix
            elif type_inferred == VariantType.TRA.name:
                assert variant_record.alt_sv_breakend is not None
                prefix = 'N' if variant_record.alt_sv_breakend.prefix else ''
                suffix = 'N' if variant_record.alt_sv_breakend.suffix else ''
                breakends = prefix + variant_record.alt_sv_breakend.bracket + \
                    variant_record.alt_sv_breakend.bracket + suffix

            variants.append([start_chrom, start, end_chrom, end, ref, alt,
                             length, breakends, type_inferred, variant_record])

        df = pd.DataFrame(variants, columns=DATAFRAME_COLUMNS)
        # astype() returns a new DataFrame, so the result must be kept for the
        # declared schema to take effect
        df = df.astype(DATAFRAME_DTYPES)
        # Reduce memory usage by using the smallest possible data type for start, end and length
        df['start'] = _use_lowest_type(df['start'])
        df['end'] = _use_lowest_type(df['end'])
        df['length'] = _use_lowest_type(df['length'])
        return df