昇思25天学习打卡营第7天|函数式自动微分

函数式自动微分

  • 概念
  • 函数与计算图
  • 微分函数与梯度计算
  • 自定义神经网络梯度计算
  • 参考

概念

神经网络的训练主要使用反向传播算法,模型预测值(logits)与正确标签(label)送入损失函数(loss function)获得loss,然后进行反向传播计算,求得梯度(gradients),最终更新至模型参数(parameters)。自动微分能够计算可导函数在某点处的导数值,是反向传播算法的一般化。自动微分主要解决的问题是将一个复杂的数学运算分解为一系列简单的基本运算,该功能对用户屏蔽了大量的求导细节和过程,大大降低了框架的使用门槛。

函数与计算图

计算图是用图论语言表示数学函数的一种方式,也是深度学习框架表达神经网络模型的统一方法。
计算图
根据计算图构造计算函数和神经网络,x为输入,y为正确值,wb是需要优化的参数。
代码示例:

import numpy as np
import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor, Parameter

x = ops.ones(5, mindspore.float32)  # input tensor
y = ops.zeros(3, mindspore.float32)  # expected output
w = Parameter(Tensor(np.random.randn(5, 3), mindspore.float32), name='w') # weight
b = Parameter(Tensor(np.random.randn(3,), mindspore.float32), name='b') # bias

# 构造计算函数,其中 binary_cross_entropy_with_logits 是一个损失函数,计算预测值和目标值之间的二值交叉熵损失
def function(x, y, w, b):
    z = ops.matmul(x, w) + b
    loss = ops.binary_cross_entropy_with_logits(z, y, ops.ones_like(z), ops.ones_like(z))
    return loss

# 执行计算函数
loss = function(x, y, w, b)
print(loss)

# 运行结果:
# 1.3518304

微分函数与梯度计算

为了优化模型参数,需要求参数对loss的导数: ∂ loss ⁡ ∂ w \frac{\partial \operatorname{loss}}{\partial w} wloss ∂ loss ⁡ ∂ b \frac{\partial \operatorname{loss}}{\partial b} bloss,可以调用mindspore.grad函数,来获得function的微分函数。
代码示例:

# fn:待求导的函数
# grad_position:指定求导输入位置的索引
grad_fn = mindspore.grad(function, (2, 3))

# 执行微分函数
grads = grad_fn(x, y, w, b)
# 打印出w、b对应的梯度
print(grads)

# 运行结果:
'''
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01],
 [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01],
 [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01],
 [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01],
 [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01]))
'''

自定义神经网络梯度计算

可以通过Cell构造神经网络,然后利用函数式自动微分来实现反向传播。
代码示例:

# 定义神经网络模型
class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.w = w
        self.b = b

    def construct(self, x):
        z = ops.matmul(x, self.w) + self.b
        return z

# 实例化模型和损失函数
model = Network()
loss_fn = nn.BCEWithLogitsLoss()

# 定义前向计算函数
def forward_fn(x, y):
    z = model(x)
    loss = loss_fn(z, y)
    return loss

grad_fn = mindspore.value_and_grad(forward_fn, None, weights=model.trainable_params())
loss, grads = grad_fn(x, y)
print(grads)

# 运行结果:
'''
(Tensor(shape=[5, 3], dtype=Float32, value=
[[ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01],
 [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01],
 [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01],
 [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01],
 [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01]]), Tensor(shape=[3], dtype=Float32, value= [ 1.74830750e-01,  2.91909929e-02,  3.20021242e-01]))
'''

我们会发现,自定义神经网络与MindSpore内置梯度计算函数,所得到的梯度值结果一致。

参考

MindSpore教程

截图时间截图时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/765933.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

这几类热销品被Ozon限制销售,ozon还有什么产品好卖?

OZON是俄罗斯最大的B2C电商平台,占据俄罗斯电商市场份额的62%,日均订单量高达37万单,拥有超过1600万的活跃用户。ozon平台对中国卖家招商的产品品类涵盖了多个领域,但近日Ozon官方发布将对这三大类目实行销售限制,一起…

DNS访问百度

DNS,英文全称是 domain name system,域名解析系统,它的作用也很明确,就是域名和 IP 相互映射。 假设你要查询 baidu.com 的 IP 地址: 首先会查找浏览器的缓存,看看是否能找到 baidu.com 对应的IP地址,找到就直接返回&…

【热门会议|见刊快】2024年管理创新与教育国际会议 (ICMIE 2024)

2024年管理创新与教育国际会议 (ICMIE 2024) 2024 International Conference on Management Innovation and Education 【重要信息】 大会地点:洛阳 大会官网:http://www.icicmie.com 投稿邮箱:icicpsssub-conf.com 【注意:稿将稿…

工厂方法模式:概念与应用

目录 工厂方法模式工厂方法模式结构工厂方法适合的应用场景工厂方法模式的优缺点练手题目题目描述输入描述输出描述**提示信息**解题: 工厂方法模式 工厂方法模式是一种创建型设计模式, 其在父类中提供一个创建对象的方法, 允许子类决定实例…

苹果电脑废纸篓数据被清空了,有什么方法可以恢复吗?

使用电脑的用户都知道,被删除的文件一般都会经过回收站,想要恢复它直接点击“还原”就可以恢复到原始位置。mac电脑同理也是这样,但是“回收站”在mac电脑显示为“废纸篓”。 苹果电脑废纸篓数据被清空了,有什么方法可以恢复吗&am…

页面速度是如何影响SEO的?

搜索引擎使用复杂的算法来衡量您网站的重要方面,以决定是否向您发送流量。 搜索引擎使用您网站的小元素来确定您网站的质量和真实性,然后此操作将转化为您的网页在搜索引擎结果页面 中出现的位置。提高您在 SERP 中的排名的过程称为搜索引擎优化 (SEO)。…

在 Mac 上使用 本地 LLM 文本终结

我们可使用本地大型语言模型,如Mistral、Llama等,来给文本做总结,相比在线的 Kimi ,ChatGPT, 我们不用担心数据泄露,因为整个操作都是在本地电脑完成的。 我们用 ollama 举例 首先安装 ollama https://ol…

从零搭建Prometheus到Grafana告警推送

目录 一、Prometheus源码安装和动态更新配置 二、Prometheus操作面板和常见配置 三、Prometheus常用监控组件exporter配置 3.1 exporter是什么 3.2 有哪些exporter 3.3 exporter怎么用 3.4 实战 node_exporter ​3.5 其它exporter都怎么用 四、Promethus整合新版Sprin…

数据结构常见图算法

深度优先搜索 时间复杂度 领接矩阵表示 O( n2) 领接表表示 O(n+e) 空间复杂度 O(e) DFS与回溯法类似,一条路径走到底后需要返回上一步,搜索第二条路径。在树的遍历中,首先一直访问到最深的节点,然后回溯到它的父节点,遍历另一条路径,直到遍历完所有节点…

怎样在《语文世界》期刊上发表论文?

怎样在《语文世界》期刊上发表论文? 《语文世界》知网国家级 1.5-2版 2500字符左右 正常收25年4-6月版面 可加急24年内(初中,高中,中职,高职,大学均可,操作周期2个月左右) 《语文世…

【CH32V305FBP6】USBD HS 虚拟串口分析

文章目录 前言分析端点 0USBHS_UIS_TOKEN_OUT 端点 2USBHS_UIS_TOKEN_OUTUSBHS_UIS_TOKEN_IN 前言 虚拟串口,端口 3 单向上报,端口 2 双向收发。 分析 端点 0 USBHS_UIS_TOKEN_OUT 设置串口参数: 判断 USBHS_SetupReqCode CDC_SET_LIN…

解锁应用商店新玩法:Xinstall渠道包,让你的App推广效率飙升

在移动应用竞争日益激烈的今天,如何在众多应用商店中脱颖而出,实现精准推广与高效获客,成为每位App开发者与广告主的共同追求。幸运的是,Xinstall作为一款一站式App全渠道统计服务商,以其专业的渠道包解决方案&#xf…

Yi-1.5 9B Chat 上线Amazon SageMaker JumpStart

你是否对简单的API调用大模型感到不满足?是否因为无法亲自部署属于自己的大模型而烦恼? 好消息来了,Amazon SageMaker JumpStart 初体验 CloudLab实验上线啦! 本实验将以零一万物最新发布的中文基础模型 Yi-1.5 9B Chat 为例&am…

如何指定Microsoft Print To PDF的输出路径

在上一篇文章中,介绍了三种将文件转换为PDF的方式。默认情况下,在Microsoft Print To PDF的首选项里,是看不到输出路径的设置的。 需要一点小小的手段。 运行输入 control 打开控制面板,选择硬件和声音下的查看设备和打印机 找到…

在卷积神经网络(CNN)中为什么可以使用多个较小的卷积核替代一个较大的卷积核,以达到相同的感受野

在卷积神经网络(CNN)中为什么可以使用多个较小的卷积核替代一个较大的卷积核,以达到相同的感受野 flyfish 在卷积神经网络(CNN)中,可以使用多个较小的卷积核替代一个较大的卷积核,以达到相同的…

探索大型语言模型自动评估 LLM 输出长句准确性的方法

LLM现在能够自动评估较长文本中的事实真实性 源码地址:https://github.com/google-deepmind/long-form-factuality 论文地址:https://arxiv.org/pdf/2403.18802.pdf 这篇论文是关于谷歌DeepMind的,提出了新的数据集、评估方法和衡量标准&am…

一篇文章搞懂时间复杂度和空间复杂度

不知道小伙伴们有没有刷过力扣上的算法题,我在上研究生的时候,刷过了前40道题,上面的算法题,我觉得还挺难的,当你写完代码的时候,就可以提交自己写的代码到系统上,系统会给你写的代码计算时间复…

嵌入式c语言1——gcc以及linux嵌入式

GCC全名GNU Complier Collection,是一个开源的程序语言解释器,运行在linux系统中 对以程序名后缀结尾源代码文件,gcc可以做解释并生成可执行文件

uniapp做小程序内打开地图展示位置信息

使用场景&#xff1a;项目中需要通过位置信息打开地图查看当前位置信息在地图那个位置&#xff0c;每个酒店有自己的经纬度和详细地址&#xff0c;点击地图按钮打开内置地图如图 方法如下&#xff1a; <view class"dttu" click"openMap(info.locationY,info.…

解决Linux环境Qt报“cannot find -lgl“问题

今天&#xff0c;在Ubuntu 18.04.6环境下&#xff0c;安装Qt5.14.2之后&#xff0c;运行一个QWidget工程&#xff0c;发现Qt报"cannot find -lgl"错误。     出现这种现象的原因&#xff1a;Qt的Path路径没有配置&#xff0c;缺少libqt4-dev依赖包和一些必要的组件…