使用tensorflow2自定义损失函数需要注意什么
发表于:2024-11-18 作者:千家信息网编辑
千家信息网最后更新 2024年11月18日,小编给大家分享一下使用tensorflow2自定义损失函数需要注意什么,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!Ker
千家信息网最后更新 2024年11月18日使用tensorflow2自定义损失函数需要注意什么
小编给大家分享一下使用tensorflow2自定义损失函数需要注意什么,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!
Keras的核心原则是逐步揭示复杂性,可以在保持相应的高级便利性的同时,对操作细节进行更多控制。当我们要自定义fit中的训练算法时,可以重写模型中的train_step方法,然后调用fit来训练模型。
这里以tensorflow2官网中的例子来说明:
import numpy as npimport tensorflow as tffrom tensorflow import kerasx = np.random.random((1000, 32))y = np.random.random((1000, 1))class CustomModel(keras.Model): tf.random.set_seed(100) def train_step(self, data): # Unpack the data. Its structure depends on your model and # on what you pass to `fit()`. x, y = data with tf.GradientTape() as tape: y_pred = self(x, training=True) # Forward pass # Compute the loss value # (the loss function is configured in `compile()`) loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses) # Compute gradients trainable_vars = self.trainable_variables gradients = tape.gradient(loss, trainable_vars) # Update weights self.optimizer.apply_gradients(zip(gradients, trainable_vars)) # Update metrics (includes the metric that tracks the loss) self.compiled_metrics.update_state(y, y_pred) # Return a dict mapping metric names to current value return {m.name: m.result() for m in self.metrics} # Construct and compile an instance of CustomModelinputs = keras.Input(shape=(32,))outputs = keras.layers.Dense(1)(inputs)model = CustomModel(inputs, outputs)model.compile(optimizer="adam", loss=tf.losses.MSE, metrics=["mae"])# Just use `fit` as usualmodel.fit(x, y, epochs=1, shuffle=False)32/32 [==============================] - 0s 1ms/step - loss: 0.2783 - mae: 0.4257
这里的loss是tensorflow库中实现了的损失函数,如果想自定义损失函数,然后将损失函数传入model.compile中,能正常按我们预想的work吗?
答案竟然是否定的,而且没有错误提示,只是loss计算不会符合我们的预期。
def custom_mse(y_true, y_pred): return tf.reduce_mean((y_true - y_pred)**2, axis=-1)a_true = tf.constant([1., 1.5, 1.2])a_pred = tf.constant([1., 2, 1.5])custom_mse(a_true, a_pred)tf.losses.MSE(a_true, a_pred)
以上结果证实了我们自定义loss的正确性,下面我们直接将自定义的loss置入compile中的loss参数中,看看会发生什么。
my_model = CustomModel(inputs, outputs)my_model.compile(optimizer="adam", loss=custom_mse, metrics=["mae"])my_model.fit(x, y, epochs=1, shuffle=False)32/32 [==============================] - 0s 820us/step - loss: 0.1628 - mae: 0.3257
我们看到,这里的loss与我们与标准的tf.losses.MSE明显不同。这说明我们自定义的loss以这种方式直接传递进model.compile中,是完全错误的操作。
正确运用自定义loss的姿势是什么呢?下面揭晓。
loss_tracker = keras.metrics.Mean(name="loss")mae_metric = keras.metrics.MeanAbsoluteError(name="mae")class MyCustomModel(keras.Model): tf.random.set_seed(100) def train_step(self, data): # Unpack the data. Its structure depends on your model and # on what you pass to `fit()`. x, y = data with tf.GradientTape() as tape: y_pred = self(x, training=True) # Forward pass # Compute the loss value # (the loss function is configured in `compile()`) loss = custom_mse(y, y_pred) # loss += self.losses # Compute gradients trainable_vars = self.trainable_variables gradients = tape.gradient(loss, trainable_vars) # Update weights self.optimizer.apply_gradients(zip(gradients, trainable_vars)) # Compute our own metrics loss_tracker.update_state(loss) mae_metric.update_state(y, y_pred) return {"loss": loss_tracker.result(), "mae": mae_metric.result()} @property def metrics(self): # We list our `Metric` objects here so that `reset_states()` can be # called automatically at the start of each epoch # or at the start of `evaluate()`. # If you don't implement this property, you have to call # `reset_states()` yourself at the time of your choosing. return [loss_tracker, mae_metric] # Construct and compile an instance of CustomModelinputs = keras.Input(shape=(32,))outputs = keras.layers.Dense(1)(inputs)my_model_beta = MyCustomModel(inputs, outputs)my_model_beta.compile(optimizer="adam")# Just use `fit` as usualmy_model_beta.fit(x, y, epochs=1, shuffle=False)32/32 [==============================] - 0s 960us/step - loss: 0.2783 - mae: 0.4257
终于,通过跳过在 compile() 中传递损失函数,而在 train_step 中手动完成所有计算内容,我们获得了与之前默认tf.losses.MSE完全一致的输出,这才是我们想要的结果。
以上是"使用tensorflow2自定义损失函数需要注意什么"这篇文章的所有内容,感谢各位的阅读!相信大家都有了一定的了解,希望分享的内容对大家有所帮助,如果还想学习更多知识,欢迎关注行业资讯频道!
函数
损失
内容
篇文章
更多
模型
结果
错误
训练
不同
复杂
明显
高级
一致
不怎么
例子
便利性
原则
参数
只是
数据库的安全要保护哪些东西
数据库安全各自的含义是什么
生产安全数据库录入
数据库的安全性及管理
数据库安全策略包含哪些
海淀数据库安全审计系统
建立农村房屋安全信息数据库
易用的数据库客户端支持安全管理
连接数据库失败ssl安全错误
数据库的锁怎样保障安全
数据库2005怎么用
往数据库中插入数据有没有返回值
天翼翰潮网络技术有限公司
怎么查看服务器的所有操作记录
国开数据库应用技术实验心得
T-SQL可用于什么数据库
软件开发收费计算公式
新产品新技术网络安全
bp网络安全是什么意思
数据库实体和关系模型
小白数据库手机续航排行
网络技术及其应用考试
网络安全的社会环境
首选的数据库管理系统是
奉贤区一站式数据库新报价
非结构化数据库外键
顺义区信息化网络技术服务平台
win10系统搭建dns服务器
给数据库的数据进行加密再解密
u8系统无法创建数据库文件
上海爱高网络技术有限公司
全国报纸数据库
数据库索引有哪些优点
齐齐哈尔公安局网络安全保卫大队
学习通服务器错误怎么弄
微博大v数据库
mysql是一个怎样的数据库
东南亚网络安全
黄浦区一站式数据库服务行业
网络安全有关的游戏