Keras Model模型方法_qq_41007606的博客-程序员宅基地_model.metrics_names

Model模型方法
compile

compile(self, optimizer, loss, metrics=None, loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None)

本函数编译模型以供训练,参数有
optimizer:优化器,为预定义优化器名或优化器对象,参考优化器
loss:损失函数,为预定义损失函数名或一个目标函数,参考损失函数
metrics:列表,包含评估模型在训练和测试时的性能的指标,典型用法是metrics=[‘accuracy’]如果要在多输出模型中为不同的输出指定不同的指标,可像该参数传递一个字典,例如metrics={‘ouput_a’: ‘accuracy’}
sample_weight_mode:如果你需要按时间步为样本赋权(2D权矩阵),将该值设为“temporal”。默认为“None”,代表按样本赋权(1D权)。如果模型有多个输出,可以向该参数传入指定sample_weight_mode的字典或列表。在下面fit函数的解释中有相关的参考内容。
weighted_metrics: metrics列表,在训练和测试过程中,这些metrics将由sample_weight或clss_weight计算并赋权
target_tensors: 默认情况下,Keras将为模型的目标创建一个占位符,该占位符在训练过程中将被目标数据代替。如果你想使用自己的目标张量(相应的,Keras将不会在训练时期望为这些目标张量载入外部的numpy数据),你可以通过该参数手动指定。目标张量可以是一个单独的张量(对应于单输出模型),也可以是一个张量列表,或者一个name->tensor的张量字典。
kwargs:使用TensorFlow作为后端请忽略该参数,若使用Theano/CNTK作为后端,kwargs的值将会传递给 K.function。如果使用TensorFlow为后端,这里的值会被传给tf.Session.run
当为参数传入非法值时会抛出异常
【Tips】如果你只是载入模型并利用其predict,可以不用进行compile。在Keras中,compile主要完成损失函数和优化器的一些配置,是为训练服务的。predict会在内部进行符号函数的编译工作(通过调用_make_predict_function生成函数)

fit

fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

本函数用以训练模型,参数有:
x:输入数据。如果模型只有一个输入,那么x的类型是numpy array,如果模型有多个输入,那么x的类型应当为list,list的元素是对应于各个输入的numpy array。如果模型的每个输入都有名字,则可以传入一个字典,将输入名与其输入数据对应起来。
y:标签,numpy array。如果模型有多个输出,可以传入一个numpy array的list。如果模型的输出拥有名字,则可以传入一个字典,将输出名与其标签对应起来。
batch_size:整数,指定进行梯度下降时每个batch包含的样本数。训练时一个batch的样本会被计算一次梯度下降,使目标函数优化一步。
epochs:整数,训练终止时的epoch值,训练将在达到该epoch值时停止,当没有设置initial_epoch时,它就是训练的总轮数,否则训练的总轮数为epochs - inital_epoch
verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录
callbacks:list,其中的元素是keras.callbacks.Callback的对象。这个list中的回调函数将会在训练过程中的适当时机被调用,参考回调函数
validation_split:0~1之间的浮点数,用来指定训练集的一定比例数据作为验证集。验证集将不参与训练,并在每个epoch结束后测试的模型的指标,如损失函数、精确度等。注意,validation_split的划分在shuffle之后,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split,否则可能会出现验证集样本不均匀。
validation_data:形式为(X,y)或(X,y,sample_weights)的tuple,是指定的验证集。此参数将覆盖validation_spilt。
shuffle:布尔值,表示是否在训练过程中每个epoch前随机打乱输入样本的顺序。
class_weight:字典,将不同的类别映射为不同的权值,该参数用来在训练过程中调整损失函数(只能用于训练)。该参数在处理非平衡的训练数据(某些类的训练样本数很少)时,可以使得损失函数对样本数不足的数据更加关注。
sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode=‘temporal’。
initial_epoch: 从该参数指定的epoch开始训练,在继续之前的训练时有用。
steps_per_epoch: 一个epoch包含的步数(每一步是一个batch的数据送入),当使用如TensorFlow数据Tensor之类的输入张量进行训练时,默认的None代表自动分割,即数据集样本数/batch样本数。
validation_steps: 仅当steps_per_epoch被指定时有用,在验证集上的step总数。
输入数据与规定数据不匹配时会抛出错误
fit函数返回一个History的对象,其History.history属性记录了损失函数和其他指标的数值随epoch变化的情况,如果有验证集的话,也包含了验证集的这些指标变化情况
evaluate
evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None)

本函数按batch计算在某些输入数据上模型的误差,其参数有:
x:输入数据,与fit一样,是numpy array或numpy array的list
y:标签,numpy array
batch_size:整数,含义同fit的同名参数
verbose:含义同fit的同名参数,但只能取0或1
sample_weight:numpy array,含义同fit的同名参数
本函数返回一个测试误差的标量值(如果模型没有其他评价指标),或一个标量的list(如果模型还有其他的评价指标)。model.metrics_names将给出list中各个值的含义。
如果没有特殊说明,以下函数的参数均保持与fit的同名参数相同的含义
如果没有特殊说明,以下函数的verbose参数(如果有)均只能取0或1

predict

predict(self, x, batch_size=32, verbose=0)

本函数按batch获得输入数据对应的输出,其参数有:
函数的返回值是预测值的numpy array

train_on_batch

train_on_batch(self, x, y, class_weight=None, sample_weight=None)

本函数在一个batch的数据上进行一次参数更新
函数返回训练误差的标量值或标量值的list,与evaluate的情形相同。

test_on_batch

test_on_batch(self, x, y, sample_weight=None)

本函数在一个batch的样本上对模型进行评估
函数的返回与evaluate的情形相同

predict_on_batch

predict_on_batch(self, x)

本函数在一个batch的样本上对模型进行测试
函数返回模型在一个batch上的预测结果

fit_generator

fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)

利用Python的生成器,逐个生成数据的batch并进行训练。生成器与模型将并行执行以提高效率。例如,该函数允许我们在CPU上进行实时的数据提升,同时在GPU上进行模型训练
函数的参数是:
generator:生成器函数,生成器的输出应该为:
一个形如(inputs,targets)的tuple
一个形如(inputs, targets,sample_weight)的tuple。所有的返回值都应该包含相同数目的样本。生成器将无限在数据集上循环。每个epoch以经过模型的样本数达到samples_per_epoch时,记一个epoch结束。
steps_per_epoch:整数,当生成器返回steps_per_epoch次数据时计一个epoch结束,执行下一个epoch。
epochs:整数,数据迭代的轮数。
verbose:日志显示,0为不在标准输出流输出日志信息,1为输出进度条记录,2为每个epoch输出一行记录。
validation_data:具有以下三种形式之一
1>生成验证集的生成器
2>一个形如(inputs,targets)的tuple
3>一个形如(inputs,targets,sample_weights)的tuple
validation_steps: 当validation_data为生成器时,本参数指定验证集的生成器返回次数
class_weight:规定类别权重的字典,将类别映射为权重,常用于处理样本不均衡问题。
sample_weight:权值的numpy array,用于在训练时调整损失函数(仅用于训练)。可以传递一个1D的与样本等长的向量用于对样本进行1对1的加权,或者在面对时序数据时,传递一个的形式为(samples,sequence_length)的矩阵来为每个时间步上的样本赋不同的权。这种情况下请确定在编译模型时添加了sample_weight_mode=‘temporal’。
workers:最大进程数
max_q_size:生成器队列的最大容量
pickle_safe: 若为真,则使用基于进程的线程。由于该实现依赖多进程,不能传递non picklable(无法被pickle序列化)的参数到生成器中,因为无法轻易将它们传入子进程中。
initial_epoch: 从该参数指定的epoch开始训练,在继续之前的训练时有用。
函数返回一个History对象
例子

def generate_arrays_from_file(path):
    while 1:
    f = open(path)
    for line in f:
        # create numpy arrays of input data
        # and labels, from each line in the file
        x1, x2, y = process_line(line)
        yield ({'input_1': x1, 'input_2': x2}, {'output': y})
    f.close()

model.fit_generator(generate_arrays_from_file('/my_file.txt'),
        steps_per_epoch=10000, epochs=10)

evaluate_generator

evaluate_generator(self, generator, steps, max_q_size=10, workers=1, pickle_safe=False)

本函数使用一个生成器作为数据源,来评估模型,生成器应返回与test_on_batch的输入数据相同类型的数据。
函数的参数是:
generator:生成输入batch数据的生成器
val_samples:生成器应该返回的总样本数
steps:生成器要返回数据的轮数
max_q_size:生成器队列的最大容量
nb_worker:使用基于进程的多线程处理时的进程数
pickle_safe:若设置为True,则使用基于进程的线程。注意因为它的实现依赖于多进程处理,不可传递不可pickle的参数到生成器中,因为它们不能轻易的传递到子进程中。

predict_generator

predict_generator(self, generator, steps, max_queue_size=10, workers=1, use_multiprocessing=False, verbose=0)

从一个生成器上获取数据并进行预测,生成器应返回与predict_on_batch输入类似的数据
函数的参数是:
generator:生成输入batch数据的生成器
val_samples:生成器应该返回的总样本数
max_q_size:生成器队列的最大容量
nb_worker:使用基于进程的多线程处理时的进程数
pickle_safe:若设置为True,则使用基于进程的线程。注意因为它的实现依赖于多进程处理,不可传递不可pickle的参数到生成器中,因为它们不能轻易的传递到子进程中。

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_41007606/article/details/83183452

智能推荐

【FZU2150】Fire Game(两起点bfs)_糖炒栗之的博客-程序员宅基地

题目链接 Problem 2150 Fire Game Accept: 3827    Submit: 13039Time Limit: 1000 mSec    Memory Limit : 32768 KB Problem DescriptionFat brother and Maze are playing a kind of special (hentai) game o...

干货:NIST评测(SRE19)获胜团队声纹识别技术分析 | CSDN博文精选_AI科技大本营的博客-程序员宅基地

作者|xjdier来源 |CSDN博文精选(*点击阅读原文,查看作者更多精彩文章)近日,NIST说话人识别技术评测 (Speaker Recognition Evaluation,...

从基本组件到结构创新,67页论文解读深度卷积神经网络架构_AI科技大本营的博客-程序员宅基地

作者 | 王广胜来源 | 我爱计算机视觉(ID:aicvml)【导读】近期一篇CNN综述文章《A Survey of the Recent Architectures of Deep Convolutional Neural Networks 》发布,受到了大家的关注,今天作者对论文中的内容做了中文的解读,帮助大家全面了解CNN架构进展。论文地址:https://arxiv.org/pdf/190...

量子计算+人工智能——这才是未来科技的最大热门!_AI科技大本营的博客-程序员宅基地

编译 | AI科技大本营参与 | shawn编辑 | 明明90年代初,当卫奇塔州立大学(Wichita State University)的物理学教授Elizabeth Behrman开始结合量子物理学和人工智能(主要是当时备受争议的神经网络技术)时,大多数人认为这两门学科就像油和水一样,根本没办法结合。“当时我连发表论文都很困难。神经网络学术期刊问我‘量子力学是什么’,而物理学期刊则会问‘神经网

NLG之语言模型_搬用工tyler的博客-程序员宅基地_nlg models

语言模型演化1.N-Gram概率语言模型需要做平滑处理,因为语料不能覆盖所有情况,否则概率都为0,无法生成句子(数据稀疏问题)2.基于NN(神经网络)与N-Gram模型很像是矩阵因子分解(Matrix Factorization)的进化相比N-Gram减少了参数量3.基于RNN(循环神经网络)可以依赖更长的信息减少了参数量4.Class-based Languag...

mRMRe进行特征选择_吹泡泡的宗介的博客-程序员宅基地_mrmr特征筛选

我会分享一下我使用的mRMRe包的整个代码过程,以及遇到的一些坑,有问题的欢迎一起讨论。首先,在Rgui上安装mRMRe包,选择顶部栏的程序包,先加载一下CRAN国内镜像,选择China中的任何一个就行,然后选择安装程序包,找到mRMRe包,会自行安装(耐心寻找,按字母排序,不知道为啥没发明个搜索框搜索的)ps:下面这种安装方式只能处理46340个特征安装完成后,输入library(mRMRe)没有报错即为安装成功。然后接下来就可以愉快的使用了这里代码也可参考:https://blog.

随便推点

UICollectionFlowLayOut_daiqiao_ios的博客-程序员宅基地

1.自定义UICollectionFlowLayOut 支持长按 拖动Cell 交换位置 支持水平和垂直两个方位的滚动 2.支持拖动Cell 到自定义 附加的View 来选择时复制Cell 还是删除Cell 附加的View可以自己定义 在相应的代理中实现即可 2.Demo中用使用的例子,可以根据自己的需求变化class ViewController: U...

Nginx 负载均衡 - fair_程序员35的博客-程序员宅基地_nginx负载均衡fair

学习在 Nginx 中使用 fair 模块(第三方)来实现负载均衡,fair 采用的不是内建负载均衡使用的轮换的均衡算法,而是可以根据页面大小、响应时间智能的进行负载均衡。1 准备工作nginx-upstream-fair 官方下载地址:https://github.com/gnosek/nginx-upstream-fair版本问题:如果使用的 Nginx 版本 >= 1.14....

Android ViewPager2获取当前fragment_飞不高的鱼的博客-程序员宅基地_viewpager2获取当前fragment

献上结论:index 是想获取的fragment的索引思路:第一时间肯定想在源码里找到相应的API发现没有还是要看源码,有没有在添加Fragment的时候设置Tag?

回家的票抢上了吗?聊聊12306为什么时不时要崩一下_AI科技大本营的博客-程序员宅基地

作者 | 半佛仙人来源|仙人JUMP(ID:xrtiaotiao)放假了吗?过年回家的火车票,你们买到了吗?我知道你们很多人都没有买到,我能感受到你们内心的绝望。前几天12306崩了...

Mac下Intellij IDea发布Java Web项目详解四 为所有Module配置Tomcat Deployment_cy20180101的博客-程序员宅基地

step5 为所有项目配置Deployment5.1 如图5.2 【+】【Artifact】5.3 将这里列出的所有内容选中后,点【OK】5.4 选完是这样,表示,这三个java ee 项目会在tomcat启动后,自动发布到Application context路径下。  5.5 为每个web项目设置Application context【WebWorkSpace1】-【/】=== 【http:l...

还在为女性成为程序员而质疑吗!女程序员数量激增,IT行业成为时代风向标_测试萌萌的博客-程序员宅基地

只要说到程序员,大家脑中的第一印象总是穿着格子衫、发量稀疏的男性,但随着互联网行业的不断发展,越多越多的女性程序员悄然进入,为编程界注入新的色彩。据《中国女性程序员职场力大数据报告》中显示,女性程序员的数量在近几年间增长70%,同时女性程序员还体现出了更高的学习消费热情和年轻化特征——女性程序员学习花费达男程序员1.5倍,00后女生则更愿意成为程序员。