Usage

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
# Identifier handed to both from_pretrained() calls below, so the tokenizer
# and the seq2seq model are loaded from the same checkpoint.
model_name = 'csdc-atl/doc2query'
# Module-level objects: create_queries() reads `tokenizer` and `model` as globals.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def create_queries(history, next_question):
    """Rewrite ``next_question`` into standalone queries using the chat history.

    The prompt is assembled as one flat list of token ids. Every entry of
    ``history`` contributes ``[32127] + ids(entry[0]) + [32126] + ids(entry[1])``;
    then ``[32127] + ids(next_question)`` and a final id ``1`` are appended.
    The module-level ``model`` samples from that prompt, and each decoded
    candidate is printed under a "Sampling Outputs:" header.

    Args:
        history: Iterable of earlier turns. Each item is indexed as
            ``item[0]`` (the question text) and ``item[1]`` (the answer text).
            May be empty, in which case only ``next_question`` is encoded.
        next_question: The follow-up question to rewrite.

    Returns:
        list[str]: The decoded candidates, in the same order they are printed.
    """
    # NOTE(review): these ids are kept exactly as in the original snippet.
    # They presumably are checkpoint-specific special tokens marking a
    # question / an answer, and 1 is presumably the EOS id -- confirm against
    # the tokenizer's vocabulary before reusing with another checkpoint.
    question_marker = 32127
    answer_marker = 32126
    end_marker = 1

    def encode(text):
        # Markers are inserted by hand, so the tokenizer must not add its own.
        return tokenizer.encode(text, add_special_tokens=False)

    prompt_ids = []
    for turn in history:
        prompt_ids += [
            question_marker, *encode(turn[0]),
            answer_marker, *encode(turn[1]),
        ]
    prompt_ids += [question_marker, *encode(next_question), end_marker]

    # Build the integer tensor directly (batch of one) instead of creating a
    # float tensor and casting it afterwards.
    input_batch = torch.tensor([prompt_ids], dtype=torch.long)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        sampling_outputs = model.generate(
            input_ids=input_batch,
            max_length=512,
            do_sample=True,
            top_p=0.95,
            top_k=10,
        )

    rewritten = [
        tokenizer.decode(sequence, skip_special_tokens=True)
        for sequence in sampling_outputs
    ]

    print("\nSampling Outputs:")
    for rank, query in enumerate(rewritten, start=1):
        print(f'{rank}: {query}')
    return rewritten
# Example conversation: each history entry is a [question, answer] pair
# (create_queries reads them as entry[0] and entry[1]).
history = [['loghub是什么', 'AWS 上的loghub解决方案可帮助组织在单个控制面板上收集、分析和显示 Amazon CloudWatch Logs。该解决方案可整合、管理和分析来自各种来源的日志文件,例如访问、配置更改和计费事件的审计日志。您也可以从多个账户和 AWS 区域收集 Amazon CloudWatch Logs。']]
# Follow-up question that refers back to the history ("What are its advantages?").
next_question = '它的优点是什么?'
create_queries(history, next_question)
# Example output (generation uses sampling, so the text may differ between runs):
# 1: loghub解决方案的优点是什么?   ("What are the advantages of the loghub solution?")
Downloads last month: 11
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.