使用TensorFlow进行高效DNA嵌入
介绍
利用DNA或其他生物序列(如RNA或蛋白质序列)进行深度学习在过去几年里取得了很大进展。大多数从DNA序列中学习的模型使用one-hot编码方案,该方案使用四个通道来表示四个可能的核苷酸A、C、G和T。
长度为N的DNA序列的one-hot编码是(N×4)矩阵,其中列对应于字母A,C,G和T,并且每行恰好具有一个等于一的项,其他条目为零。
任何此类模型中的第一步是以one-hot方案对输入DNA序列进行编码,该方案可以由以下神经网络层处理,无论它们是全连接的,卷积的还是循环的。
通常,one-hot编码被视为在将数据馈送到模型之前发生的预处理步骤。例如,可以对整个训练数据集进行预编码,然后将其存储在磁盘上以用于训练。在其他情况下,模型包含在一些Python预处理层中,该层在Python-land中执行编码,而不是在TensorFlow图中执行。例如,Selene包是一个在PyTorch中处理生物序列的框架,它将编码作为预处理步骤实现,并在Cython中实现。
然而,纯粹从界面设计的角度来看,从DNA序列进行预测的模型应该接受序列作为字符串,而one-hot编码表示应该被视为一个实现细节,不需要对用户可见。这也使得将模型打包为tf.SavedModel并共享或部署更容易,因为该图本身接受DNA序列而无需用户执行他们自己的编码。
编码器只需要执行一个简单的工作:取一个序列,对于每个核苷酸,根据以下映射输出一个向量。
正常情况下,与运行模型本身的成本相比,DNA嵌入将是一项非常廉价的操作,但一个糟糕的嵌入实现可能最终成为一个严重的瓶颈。
事实证明,有很多方法可以使用本机TensorFlow操作来实现DNA one-hot编码,这里我将介绍三种方法。
我将首先介绍我的三个实现,然后我将对它们进行基准测试,看看是否存在相当大的差异。
使用lookup table
使用DNA序列学习在某些方面类似于自然语言处理,通常使用lookup table将字符串键从词汇表映射到整数id。
TensorFlow在tf.contrib.lookup模块中有一个lookup table类。该contrib模块将在即将推出的TensorFlow 2.0中弃用,但可能会有非常接近的替代品。TensorFlow有tf.one_hot函数,它可以将这些整数ID转换为单热嵌入的功能。
以下Python函数将字符串格式的DNA输入映射到整数ID。我现在将省略one-hot 编码步骤,我们可以稍后单独对这两个步骤进行基准测试。
由于我们需要首先将字符串拆分为单个字符,因此我们使用该tf.string_split函数,因此该函数稍微复杂一些。tf.string_split但是,由于返回稀疏张量,我们需要将其转换回密集向量(lookup table只接受密集向量)。
这个函数稍微有点复杂,因为我们需要先将字符串分割成单独的字符,然后使用tfstring_split函数处理这些字符。因为tfstring_split返回一个稀疏张量,所以我们需要将它转换回一个dense向量(lookup table只接受dense向量)。
最后,table.lookup(seq)将结果作为整数id的张量返回。
使用位操作来计算整数索引
使用只有四个键的lookup table似乎有点过火。但是,有什么更简单的方法可以将DNA字母表映射到tf.one_hot函数的索引呢?一种方法是使用基本的位函数直接计算索引。我们需要执行的映射如下:
我们所需要做的就是使用基本的位操作符&、|、^、~、<<和>>找到一个操作序列,将左边的值转换为右边的值。
为了简单起见,我首先在纯Python中实现了以下一个这样的转换:
对于序列中的每个字母,这个Python代码片段首先清除第5位和第7位最低有效位(从右边开始),然后向右移1位,除了G和T的值交换之外,几乎得到了正确的位模式。因此,剩下的唯一步骤是用2替换3,用3替换2。第三行使用表达式(nt & 1 << 1)作为掩码,只影响右边第二位的值,然后使用xor翻转最右边的位。下表显示了每个输入的逐步转换。
总而言之,此操作将DNA alphabete映射ACGT到索引0,1,2,3。
TensorFlow中此函数的Python实现如下:
def dna_encode_bit_manipulation(seq, name='dna_encode'): with tf.name_scope(name): bytes = tf.decode_raw(seq, tf.uint8) bytes = tf.bitwise.bitwise_and(bytes, ~((1 << 6) | (1 << 4)) bytes = tf.bitwise.right_shift(bytes, 1) mask = tf.bitwise.bitwise_and(bytes, 2) mask = tf.bitwise.right_shift(mask, 1) bytes = tf.bitwise.bitwise_xor(bytes, mask) return bytes
此函数可以替换上面的lookup table,然后可以使用tf.one_hot函数来获取最终编码。由于此
函数仅使用元素运算,因此它也可以在GPU上非常有效地运行。
使用嵌入表
自然语言处理的另一个概念是嵌入表,它将整数id作为输入并输出包含该id的嵌入的向量。在自然语言处理中,这些嵌入是随机初始化和训练的,但我们可以利用相同的工具将我们的序列映射到固定的one-hot编码。
在上面的两种方法中,我们将DNA字母映射到整数0,…,3,以便我们可以使用tf.one_hot函数。但是如果我们直接使用A、C、G和T的整数ASCII码作为索引呢?我们只需要把表做大一些,这样我们就可以使用tf.nn.embedding_lookup函数。
需要分配一个包含84行的嵌入表,因为这是ASCII码的T。
这种方法还有一些其他优点:它可以很容易地适应编码其他字母,如氨基酸序列,或者我们可以使用这种方法来解释IUPAC核苷酸通配符,例如定义R为嘌呤(A或G)或B定义为“除A之外的任何项”(C、G或T)。如果这是我们想要的,我们可以定义以下嵌入表:
基准
为了验证这些方法,我反复使用它们对人类基因组DMD (dystrophine)中最长的基因(约2.24 Mbp)进行one_hot编码。我使用twobitreader模块从hg38 2bit文件(http://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/)中提取序列,计算反补码,因为DMD在负链上:
基准测试使用的软件是Python 3.6.8(Anaconda),TensorFlow 1.12.0(conda包,支持GPU和MKL扩展)和CUDA 9.2。
我首先运行了dna_encode_lookup_table和dna_encode_bit_manipulation,没有使用tf.one_hot步骤,以便分别对这两个步骤进行基准测试。
我用timeit.repeat(number=10和repeats=10)。对于位操作方法,执行编码10次的10次重复中的最小值为14ms,对于lookup table方法,为2.17s。
dna_encode_embedding_table函数花了38毫秒直接计算最终的one-hot编码。
如果对前两个函数进行基准测试,然后对tf.one_hot应用程序进行基准测试,就会发现总时间主要由从索引中计算one-hot编码决定。位操作方法和one-hot编码的总时间为41 ms,lookup table方法和one-hot编码的总时间为2.18 s。
最后,下图显示了所有结果:
上图为运行时间用于DNA序列的one-hot编码的不同实现。任务是将DMD基因的完整序列(2.24 Mbp)嵌入 10次。蓝条仅显示计算整数索引的时间,橙条显示计算整数索引的总时间,然后是one-hot编码。
lookup table显然是一个糟糕的选择,非常低效。在计算整数索引时,位操作非常快,但如果我们包括计算one-hot编码所需的时间,则其大致与嵌入方法一样快。
总的来说,位操作方法具有一些简洁的魅力,但是当需要嵌入IUPAC通配符值时,或者当您使用不同的字母表(如蛋白质序列)时,它不能被使用。在大多数情况下,嵌入表可能是最实用的,因为它可以与任何序列字母表一起使用,并且运行速度与位操作方法一样快。
完整Python示例代码
import numpy as np import tensorflow as tf import twobitreader import timeit def tf_dna_encode_lookup_table(seq, name="dna_encode"): """Map DNA string inputs to integer ids using a lookup table.""" with tf.name_scope(name): # Defining the lookup table mapping_strings = tf.constant(["A", "C", "G", "T"]) table = tf.contrib.lookup.index_table_from_tensor( mapping=mapping_strings, num_oov_buckets=0, default_value=-1) # Splitting the string into single characters seq = tf.squeeze( tf.sparse.to_dense( tf.string_split([seq], delimiter=""), default_value=""), 0) return table.lookup(seq) def tf_dna_encode_bit_manipulation(seq, name='dna_encode'): with tf.name_scope(name): bytes = tf.decode_raw(seq, tf.uint8) bytes = tf.bitwise.bitwise_and(bytes, ~(1 << 6)) bytes = tf.bitwise.bitwise_and(bytes, ~(1 << 4)) bytes = tf.bitwise.right_shift(bytes, 1) mask = tf.bitwise.bitwise_and(bytes, 2) mask = tf.bitwise.right_shift(mask, 1) bytes = tf.bitwise.bitwise_xor(bytes, mask) return bytes #%% def tf_dna_encode_embedding_table(dna_input, name="dna_encode"): """Map DNA sequence to one-hot encoding using an embedding table.""" # Define the embedding table _embedding_values = np.zeros([89, 4], np.float32) _embedding_values[ord('A')] = np.array([1, 0, 0, 0]) _embedding_values[ord('C')] = np.array([0, 1, 0, 0]) _embedding_values[ord('G')] = np.array([0, 0, 1, 0]) _embedding_values[ord('T')] = np.array([0, 0, 0, 1]) _embedding_values[ord('W')] = np.array([.5, 0, 0, .5]) _embedding_values[ord('S')] = np.array([0, .5, .5, 0]) _embedding_values[ord('M')] = np.array([.5, .5, 0, 0]) _embedding_values[ord('K')] = np.array([0, 0, .5, .5]) _embedding_values[ord('R')] = np.array([.5, 0, .5, 0]) _embedding_values[ord('Y')] = np.array([0, .5, 0, .5]) _embedding_values[ord('B')] = np.array([0, 1. / 3, 1. / 3, 1. / 3]) _embedding_values[ord('D')] = np.array([1. / 3, 0, 1. / 3, 1. / 3]) _embedding_values[ord('H')] = np.array([1. / 3, 1. / 3, 0, 1. / 3]) _embedding_values[ord('V')] = np.array([1. / 3, 1. / 3, 1. / 3, 0]) _embedding_values[ord('N')] = np.array([.25, .25, .25, .25]) embedding_table = tf.get_variable( 'dna_lookup_table', _embedding_values.shape, initializer=tf.constant_initializer(_embedding_values), trainable=False) # Ensure that embedding table is not trained with tf.name_scope(name): dna_input = tf.decode_raw(dna_input, tf.uint8) # Interpret string as bytes dna_32 = tf.cast(dna_input, tf.int32) encoded_dna = tf.nn.embedding_lookup(embedding_table, dna_32) return encoded_dna #%% if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "genome_file", help="Location to genome 2bit file (hg38)") parser.add_argument( "-N", type=int, help="Number of iterations for each method") parser.add_argument("-r", type=int, help="Number of repeats") args = parser.parse_args() # Extract DMD sequence and compute reverse complement genome = twobitreader.TwoBitFile(args.genome_file) dmd_sequence = genome['chrX'][31097676:33339441].upper() def reverse_complement(seq): return "".join("TGCA"["ACGT".index(s)] for s in seq[::-1]) dmd_sequence_r = reverse_complement(dmd_sequence) # Set up TensorFlow graph seq_t = tf.constant(dmd_sequence_r, tf.string) seq_encoded_bit_manip_t = tf.one_hot(tf_dna_encode_bit_manipulation(seq_t), 4) seq_encoded_lookup_t = tf.one_hot(tf_dna_encode_lookup_table(seq_t), 4) seq_encoded_embedding_table_t = tf_dna_encode_embedding_table(seq_t) # TensorFlow boilerplate session = tf.Session() with session.as_default(): tf.tables_initializer().run() tf.global_variables_initializer().run() # Now benchmark each method print("### Benchmarking bit manipulation method ###") results = timeit.repeat(lambda: session.run(seq_encoded_bit_manip_t), number=args.N, repeat=args.r) print("""Bit manipulation method ({} iterations, {} repeats): Total time: {} Best time: {} """.format(args.N, args.r, sum(results), min(results))) print("### Benchmarking embedding table method ###") results = timeit.repeat(lambda: session.run(seq_encoded_embedding_table_t), number=args.N, repeat=args.r) print("""Embedding table method ({} iterations, {} repeats): Total time: {} Best time: {} """.format(args.N, args.r, sum(results), min(results))) print("### Benchmarking lookup table method ###") results = timeit.repeat(lambda: session.run(seq_encoded_lookup_t), number=args.N, repeat=args.r) print("""Lookup table method ({} iterations, {} repeats): Total time: {} Best time: {} """.format(args.N, args.r, sum(results), min(results)))