Source code for bert_embedding.bert

# coding=utf-8
"""BERT embedding."""
import logging
from typing import List

import mxnet as mx
from mxnet.gluon.data import DataLoader
import gluonnlp
from gluonnlp.data import BERTTokenizer, BERTSentenceTransform

from bert_embedding.dataset import BertEmbeddingDataset

__author__ = 'Gary Lai'

logger = logging.getLogger(__name__)


class BertEmbedding(object):
    """
    Encoding from BERT model.

    Parameters
    ----------
    ctx : Context
        the MXNet context (CPU or GPU device) to run BertEmbedding on.
    dtype : str
        data type to use for the model.
    model : str, default bert_12_768_12
        name of the pre-trained BERT model.
    dataset_name : str, default book_corpus_wiki_en_uncased
        dataset the pre-trained model was trained on.
    params_path : str, default None
        path to a parameters file to load instead of the pretrained model.
    max_seq_length : int, default 25
        max length of each sequence.
    batch_size : int, default 256
        batch size.
    """
    def __init__(self, ctx=mx.cpu(), dtype='float32', model='bert_12_768_12',
                 dataset_name='book_corpus_wiki_en_uncased', params_path=None,
                 max_seq_length=25, batch_size=256):
        """
        Encoding from BERT model.

        Parameters
        ----------
        ctx : Context
            the MXNet context (CPU or GPU device) to run BertEmbedding on.
        dtype : str
            data type to use for the model.
        model : str, default bert_12_768_12
            name of the pre-trained BERT model.
        dataset_name : str, default book_corpus_wiki_en_uncased
            dataset the pre-trained model was trained on.
        params_path : str, default None
            path to a parameters file to load instead of the pretrained model.
        max_seq_length : int, default 25
            max length of each sequence.
        batch_size : int, default 256
            batch size.
        """
        self.ctx = ctx
        self.dtype = dtype
        self.max_seq_length = max_seq_length
        self.batch_size = batch_size
        self.dataset_name = dataset_name
        if params_path is not None:
            # Don't download the pretrained models if we have a parameter path
            pretrained = False
        else:
            pretrained = True
        self.bert, self.vocab = gluonnlp.model.get_model(model,
                                                         dataset_name=self.dataset_name,
                                                         pretrained=pretrained,
                                                         ctx=self.ctx,
                                                         use_pooler=False,
                                                         use_decoder=False,
                                                         use_classifier=False)
        self.bert.cast(self.dtype)
        if params_path:
            logger.info('Loading params from %s', params_path)
            self.bert.load_parameters(params_path, ctx=ctx, ignore_extra=True)

        lower = 'uncased' in self.dataset_name
        self.tokenizer = BERTTokenizer(self.vocab, lower=lower)
        self.transform = BERTSentenceTransform(tokenizer=self.tokenizer,
                                               max_seq_length=self.max_seq_length,
                                               pair=False)
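
    # A minimal usage sketch, assuming the default pretrained
    # 'bert_12_768_12' / 'book_corpus_wiki_en_uncased' weights can be
    # downloaded; the sentence text is illustrative only:
    #
    #   bert_embedding = BertEmbedding(ctx=mx.cpu(), max_seq_length=25)
    #   tokens, token_vecs = bert_embedding(['Hello world'])[0]
    #   # tokens is a List[str]; token_vecs is a List[ndarray], one
    #   # 768-dimensional vector per token for bert_12_768_12.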

    def __call__(self, sentences, oov_way='avg'):
        """
        Get tokens, tokens embedding

        Parameters
        ----------
        sentences : List[str]
            sentences for encoding.
        oov_way : str, default avg
            use **avg**, **sum** or **last** to get token embedding for those
            out of vocabulary words

        Returns
        -------
        List[(List[str], List[ndarray])]
            List of tokens, and tokens embedding
        """
        # Delegate to embedding(), forwarding the requested oov_way instead of
        # always using 'avg'.
        return self.embedding(sentences, oov_way=oov_way)

    def embedding(self, sentences, oov_way='avg'):
        """
        Get tokens, tokens embedding

        Parameters
        ----------
        sentences : List[str]
            sentences for encoding.
        oov_way : str, default avg
            use **avg**, **sum** or **last** to get token embedding for those
            out of vocabulary words

        Returns
        -------
        List[(List[str], List[ndarray])]
            List of tokens, and tokens embedding
        """
        data_iter = self.data_loader(sentences=sentences)
        batches = []
        for token_ids, valid_length, token_types in data_iter:
            token_ids = token_ids.as_in_context(self.ctx)
            valid_length = valid_length.as_in_context(self.ctx)
            token_types = token_types.as_in_context(self.ctx)
            sequence_outputs = self.bert(token_ids, token_types,
                                         valid_length.astype(self.dtype))
            for token_id, sequence_output in zip(token_ids.asnumpy(),
                                                 sequence_outputs.asnumpy()):
                batches.append((token_id, sequence_output))
        return self.oov(batches, oov_way)
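
    # Sketch of the returned structure (shapes assume the default
    # bert_12_768_12 model, whose hidden size is 768; sentence is illustrative):
    #
    #   results = bert_embedding.embedding(['Hello world'], oov_way='avg')
    #   tokens, vecs = results[0]
    #   # tokens, e.g. ['hello', 'world'] with the uncased tokenizer
    #   # vecs[0].shape == (768,)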

    def data_loader(self, sentences, shuffle=False):
        dataset = BertEmbeddingDataset(sentences, self.transform)
        return DataLoader(dataset=dataset, batch_size=self.batch_size,
                          shuffle=shuffle)

    def oov(self, batches, oov_way='avg'):
        """
        How to handle oov. Also filter out [CLS], [SEP] tokens.

        Parameters
        ----------
        batches : List[(token_ids, sequence_outputs)]
            batch of token_ids (max_seq_length, ) and
            sequence_outputs (max_seq_length, dim, )
        oov_way : str
            use **avg**, **sum** or **last** to get token embedding for those
            out of vocabulary words

        Returns
        -------
        List[(List[str], List[ndarray])]
            List of tokens, and tokens embedding
        """
        sentences = []
        for token_ids, sequence_outputs in batches:
            tokens = []
            tensors = []
            oov_len = 1
            for token_id, sequence_output in zip(token_ids, sequence_outputs):
                if token_id == 1:
                    # [PAD] token, sequence is finished.
                    break
                if token_id in (2, 3):
                    # Skip [CLS] and [SEP] tokens.
                    continue
                token = self.vocab.idx_to_token[token_id]
                if token.startswith('##'):
                    # WordPiece continuation: merge the piece into the previous
                    # token and combine its embedding according to oov_way.
                    token = token[2:]
                    tokens[-1] += token
                    if oov_way == 'last':
                        tensors[-1] = sequence_output
                    else:
                        tensors[-1] += sequence_output
                    if oov_way == 'avg':
                        oov_len += 1
                else:
                    # In-vocabulary token: finish averaging the previous OOV
                    # word, if any, then start a new token.
                    if oov_len > 1:
                        tensors[-1] /= oov_len
                        oov_len = 1
                    tokens.append(token)
                    tensors.append(sequence_output)
            if oov_len > 1:
                # The sentence ends with an OOV word; finish averaging it.
                tensors[-1] /= oov_len
            sentences.append((tokens, tensors))
        return sentences
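

if __name__ == '__main__':
    # Minimal end-to-end sketch of the embedding and OOV handling above,
    # assuming the default pretrained 'bert_12_768_12' /
    # 'book_corpus_wiki_en_uncased' weights can be downloaded; the sentences
    # are illustrative only.
    bert_embedding = BertEmbedding(ctx=mx.cpu(), max_seq_length=25)
    example_sentences = ['gluonnlp is an NLP toolkit', 'embeddings are useful']
    for tokens, token_vecs in bert_embedding(example_sentences, oov_way='avg'):
        # A word missing from the vocabulary (e.g. 'gluonnlp') is split into
        # WordPiece pieces by the tokenizer; oov() merges the pieces back into
        # one token whose vector is the average of the piece vectors.
        print(tokens)
        print(len(token_vecs), token_vecs[0].shape)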