基于__torch_dispatch__机制的dump方法

基于__torch_dispatch__机制的dump方法

  • 1.参考链接
  • 2.原理
  • 3.代码
  • 4.效果

之前拦截torch和torch.Tensor的办法,在处理backward时,不能看到aten算子的细节.以下基于__torch_dispatch__机制的方案更节约代码,且能看到调用栈

1.参考链接

[原理] (https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557)

2.原理

在这里插入图片描述

3.代码

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import torch
from torch import nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import time
import os
import threading

device="cuda"
from torch.utils._python_dispatch import TorchDispatchMode
import inspect
import traceback
from dataclasses import dataclass
from typing import Any

@dataclass
class _ProfilerState:
    cls: Any
    object: Any = None

lock=threading.Lock()
gindex=0
def save_tensor(name,args,index=0):
    if isinstance(args,torch.Tensor):
        print(name,index,args.shape)
        global gindex
        lock.acquire()
        torch.save(args,"{}_{}_{}_{}.pt".format(device,gindex,name,index))
        gindex+=1
        lock.release()
    if isinstance(args,tuple):
        for idx,x in enumerate(args):
            save_tensor(name,x,index+idx)

class TorchDumpDispatchMode(TorchDispatchMode):
    def __init__(self,parent):
        super().__init__()
        self.parent=parent
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        func_packet = func._overloadpacket        
        if kwargs is None:
            kwargs = {}        
        enable_dump=False
        if func_packet.__name__ not in ["detach"]:
            enable_dump=True
            print(f"Profiling {func_packet.__name__}") 
            for idx,stack in enumerate(inspect.stack()):
                print(f'{"*"*idx}{stack.filename}{stack.lineno}')
        if enable_dump:     
            save_tensor(f"{func_packet.__name__}-input",args)
        ret= func(*args, **kwargs)
        if enable_dump:
            save_tensor(f"{func_packet.__name__}-output",ret)
        return ret

class TorchDumper:
    _CURRENT_Dumper = None
    def __init__(self,schedule: Any):
        self.p= _ProfilerState(schedule) 

    def __enter__(self):
        assert TorchDumper._CURRENT_Dumper is None
        TorchDumper._CURRENT_Dumper = self
        if self.p.object is None:
            o = self.p.cls(self)
            o.__enter__()
            self.p.object = o
        else:
            self.p.object.step()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TorchDumper._CURRENT_Dumper = None
        if self.p.object is not None:
            self.p.object.__exit__(exc_type, exc_val, exc_tb)

class Attention(nn.Module):
    def __init__(self,max_seq_len,head_dim,flash):
        super().__init__()
        self.flash = flash
        self.dropout=0
        self.attn_dropout = nn.Dropout(self.dropout)
        self.head_dim=head_dim
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            mask = torch.full((1, 1, max_seq_len, max_seq_len), float("-inf")).to(device)
            mask = torch.triu(mask, diagonal=1).half().to(device)
            self.register_buffer("mask", mask)		
    def forward(
            self,xq: torch.Tensor,xk: torch.Tensor,xv: torch.Tensor):
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv,
                                                                       attn_mask=None, 
                                                                       dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            _xk=xk.clone()
            t=_xk.transpose(2, 3)
            scores = torch.matmul(xq,t)
            scores = scores/math.sqrt(self.head_dim)
            a=self.mask[:, :, :seqlen, :seqlen]
            scores = scores+a
            scores = F.softmax(scores.float(), dim=-1)
            scores = scores.type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  
        return output

def main(flash,bs, n_local_heads, seqlen, head_dim):
    torch.random.manual_seed(1)

    q = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)
    k = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)
    v = torch.ones((bs, n_local_heads, seqlen, head_dim),dtype=torch.float32).half().to(device)

    q.data.normal_(0, 0.1)
    k.data.normal_(0, 0.1)
    v.data.normal_(0, 0.1)

    q=Variable(q, requires_grad=True).to(device)
    k=Variable(k, requires_grad=True).to(device)
    v=Variable(v, requires_grad=True).to(device)

    gt= torch.randint(0,head_dim,(bs*n_local_heads*seqlen,1)).reshape(-1).to(device)
    loss_func=nn.CrossEntropyLoss().to(device)

    model=Attention(seqlen,head_dim,flash).half().to(device)
    optim = torch.optim.SGD([q,k,v], lr=1.1)

    with TorchDumper(TorchDumpDispatchMode):
        for i in range(1):
            output = model(q,k,v)
            loss=loss_func(output.reshape(-1,head_dim),gt)
            loss.backward()  
            optim.step()
            print("{:.5f},{:.5f},{:.5f},{:.5f}".format(q.sum().item(),k.sum().item(),v.sum().item(),loss.item()))

bs, n_local_heads, seqlen, head_dim = 8, 8, 512, 64
main(False,bs, n_local_heads, seqlen, head_dim)

4.效果

Profiling clone
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py109
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
clone-input 0 torch.Size([8, 8, 512, 64])
clone-output 0 torch.Size([8, 8, 512, 64])
Profiling transpose
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py110
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
transpose-input 0 torch.Size([8, 8, 512, 64])
transpose-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 512, 64])
expand-output 0 torch.Size([8, 8, 512, 64])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 512, 64])
view-output 0 torch.Size([8, 8, 512, 64])
Profiling expand
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
expand-input 0 torch.Size([8, 8, 64, 512])
expand-output 0 torch.Size([8, 8, 64, 512])
Profiling view
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
view-input 0 torch.Size([8, 8, 64, 512])
view-output 0 torch.Size([8, 8, 64, 512])
Profiling bmm
/home/user/proj/attention/attention_torch_dispatch_dumper.py60
*/home/user/proj/attention/attention_torch_dispatch_dumper.py111
**/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1527
***/home/anaconda3/envs/nvidia_training/lib/python3.10/site-packages/torch/nn/modules/module.py1518
****/home/user/proj/attention/attention_torch_dispatch_dumper.py144
*****/home/user/proj/attention/attention_torch_dispatch_dumper.py151
bmm-input 0 torch.Size([64, 512, 64])
bmm-input 1 torch.Size([64, 64, 512])
bmm-output 0 torch.Size([64, 512, 64])
bmm-output 1 torch.Size([64, 64, 512])
Profiling _unsafe_view

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/578814.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

matlab学习005-利用matlab设计滤波器

目录 一,含有多个频率成分的三角信号 1,以采样频率fs20KHz对信号采样, 画出信号的波形; 1)前期基础 2)波形图 3)代码 2,选取合适的采样点数,利用DFT分析信号的…

FPGA 以太网通信UDP通信环回

1 实验任务 上位机通过网口调试助手发送数据给 FPGA , FPGA 通过 PL 端以太网接口接收数据并将接收到的数据发送给上位机,完成以太网 UDP 数据的环回。 2 系统设计 系统时钟经过PLL时钟模块后,生成了两种不同频率和相位的时钟信号&#…

基于SpringBoot+VueHome F家居系统的设计与实现

系统介绍 该Home F家居系统采用B/S架构、前后端分离以及MVC模型进行设计,并采用Java语言以及SpringBoot框架进行开发。本系统主要设计并完成了用户注册、登录,购买家具过程、个人信息修改等,商家添加家具信息、对家具进行发货,管理…

缓解程序员工作压力:从心理健康到社交网络

缓解程序员工作压力:从心理健康到社交网络 缓解程序员工作压力:从心理健康到社交网络摘要引言工作与休息的平衡制定有效的工作计划定时休息和放松 心理健康与自我关怀培养良好的生活习惯寻找心灵的慰藉 社交与网络建设加入专业社区和论坛建立良好的同事关…

【静态分析】静态分析笔记09 - 污点分析

参考: 【课程笔记】南大软件分析课程—16课时完整版 - 知乎 ------------------------------------------------------------------------------- 1. 信息流安全 访问控制:关注信息访问。 信息流安全:关注信息传播。 信息流&#xff1a…

自己搭建的大疆无人机RTMP流媒体服务延迟太大

流程:无人机摄像头->图传->遥控器->流媒体服务器->取流播放,延迟有10秒来的,大家有没有什么好的方案。

【介绍下有那些常见的ssh功能】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…

python作业 切片逆转

题目: (反转显示一个整数)编写下面的函数,反向显示一个整数。 列如:reserse(3456)。编写一个测试程序,提示用户输入一个整数,然后显示它的反向数。 第一步定义一个函数: def rev…

Linux进程概念(六):进程控制

目录 进程创建 fork函数 进程终止 终止时干了什么 进程终止的三种情况 main函数的返回值 打印默认退出码 自定义退出码 总结 进程终止 exit函数 _exit函数 exit和_exit的区别 进程等待 什么是进程等待 为什么要有进程等待 wait函数 waitpid函数 阻塞等待与…

【前端开发基础知识快速入门】

前端开发基础知识&快速入门 一、VSCode 使用1.1 安装常用插件1.2 创建项目1.3 创建网页1.4 运行效果二、ES62.1 简介2.2 什么是 ECMAScript2.3 ES6 新特性2.3.1 let 声明变量2.3.2 const 声明常量(只读变量)2.3.3 解构表达式2.3.4 字符串扩展2.3.5 函数优化2.3.6 对象优化…

开发日志(20240422):一次以为是跨域但并不是跨域的问题排查记录

1. 日志 在前后端联调的时候,遇到了报错,如下图所示(现在再看感觉非常简单了),发现前一个请求通过了,但是第二个请求报错,然后看到 strict-origin-when-cross-origin 条件反射的认为是跨域配置…

流量网关与服务网关的区别:(面试题,掌握)

流量网关:(如Nignx,OpenResty,Kong)是指提供全局性的、与后端业务应用无关的策略,例如 HTTPS证书认证、Web防火墙、全局流量监控,黑白名单等。 服务网关:(如Spring Clou…

初步认识Java

Java之父 Java 语言源于 1991 年 4 月,Sun 公司 James Gosling博士 领导的绿色计划(Green Project) 开始启动,此计划最初的目标是开发一种能够在各种消费性电子产品(如机顶盒、冰箱、收音机等)上运行的程序架构。这个就是Java的前身: Oak (得…

【Node.js工程师养成计划】之打造自己的脚手架工具

一、创建全局的自定义命令 1、打开一个空文件夹,新建一个bin文件夹,在bin文件夹下新建cli.js文件,js文件可以命名为cli.js(您随意) 2、在cli.js文件中的开头(!!)写下面这…

系统服务(22年国赛)—— 磁盘管理(压缩去重)

前言:原文在我的博客网站中,持续更新数通、系统方面的知识,欢迎来访! 系统服务(22年国赛)—— 磁盘管理(压缩&&去重)https://myweb.myskillstree.cn/90.html 目录 StorageSrv 安装并创建vdo 将…

MIT 6.172 笔记 现代硬件算法案例分析

本文是https://en.algorithmica.org/hpc/和MIT 6.172的课后题解析 课程地址: 文章目录 HW2 Profiling Serial Merge Sort测试DEBUG和非DEBUG区别测试inline和非inline区别Coarsening HW3 向量化为什么用负偏移量测量向量化跨步向量化 HW4 Reducer Hyperobjects比较o…

vue echarts 柱状图 堆叠柱状图

echarts堆叠柱状图&#xff08;效果图在文章末尾&#xff09; 1、默认只显示 月度的 数据&#xff0c;手动点击 legend 季度的 数据才会显示&#xff1b; 2、监听左侧菜单栏的宽度变化&#xff0c;图表宽度自适应展示 <template><div><div id"barChart&q…

【MySQL】A01、性能优化-参数监控分析

1、参数监控 1.1、MySQL command 查看 mysql>SHOW STATUS; &#xff08;服务器状态变量&#xff0c;运行服务器的统计和状态指标&#xff09; mysql> SHOW VARIABLES;&#xff08;服务器系统变量&#xff0c;实际上使用的变量的值&#xff09; mysql> SHOW STATUS …

VTK----VTK数据结构详解1(几何篇)

在讲VTK的数据结构之前&#xff0c;我们先了解可视化数据的两个特征&#xff1a;离散性、有规则或无规则。 离散性。当我们使用计算机去表示我们的数据时&#xff0c;一般都是基于有限数量的点做信息的采样&#xff08;或插值&#xff09;&#xff0c;因此可视化的数据是以一种…

C++笔试强训day8

目录 1.求最小公倍数 2.数组中的最⻓连续⼦序列 3.字母收集 1.求最小公倍数 链接 这就是一道普通的数学题。 最大公倍数 A * B / A 与 B之间的最大公约数。 最大公约数求法&#xff1a;辗转相除法(或者可以用<numeric>头文件中的gcd) #include <iostream> us…
最新文章