我想在输入函数中执行基本的预处理和标记化 . 我的数据包含在谷歌 Cloud 存储桶位置(gs://)中的csv中,我无法修改 . 此外,我在ml-engine包中对输入文本执行任何修改,以便可以在服务时复制行为 .
我的输入函数遵循以下基本结构:
filename_queue = tf.train.string_input_producer(filenames)
reader = tf.TextLineReader()
_, rows = reader.read_up_to(filename_queue, num_records=batch_size)
text, label = tf.decode_csv(rows, record_defaults = [[""],[""]])
# add logic to filter special characters
# add logic to make all words lowercase
words = tf.string_split(text) # splits based on white space
是否有任何选项可以避免事先对整个数据集执行此预处理?这post表明tf.py_func()可用于进行这些转换,但是他们建议"The drawback is that as it is not saved in the graph, I cannot restore my saved model"所以我不相信这在服务时间会有用 . 如果我正在定义我自己的tf.py_func()来进行预处理,并且我在上传到 Cloud 的培训包中定义了我会遇到任何问题吗?有没有我不考虑的替代选择?
1 回答
最佳做法是编写一个函数,您可以从training / eval input_fn和服务input_fn调用该函数 .
例如:
然后,在input_fn中,通过调用add_engineered来包装返回的功能:
并且在您的serving_input fn中,确保通过调用add_engineered类似地包装返回的要素(而不是feature_placeholders):
你的模型会使用'单词' . 但是,您在预测时的JSON输入只需要包含“文本”即原始值 .
这是一个完整的工作示例:
https://github.com/GoogleCloudPlatform/training-data-analyst/blob/master/courses/machine_learning/feateng/taxifare/trainer/model.py#L107