SEO:

libtorch 如何 OneHot ?

torch OneHot 源代码 ?

https://www.tensorflow.org/api_docs/python/tf/one_hot

最新的 1.3 版本中已经添加了该函数

#include <torch/torch.h>
#include <c10/util/StringUtil.h>
torch::Tensor one_hot(const torch::Tensor &self, int64_t num_classes) {
AT_CHECK(self.dtype() == torch::kLong, "one_hot is only applicable to index tensor.");
auto shape = self.sizes().vec(); // empty tensor could be converted to one hot representation,
// but shape inference is not possible.
if (self.numel() == 0) {
if (num_classes <= 0) {
AT_ERROR("Can not infer total number of classes from empty tensor.");
}
else {
shape.push_back(num_classes);
return at::empty(shape, self.options());
}
} // non-empty tensor
AT_CHECK(self.min().item().toLong() >= 0, "Class values must be non-negative.");
if (num_classes == -1) {
num_classes = self.max().item().toLong() + 1;
}
else {
AT_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes.");
} shape.push_back(num_classes);
torch::Tensor ret = at::zeros(shape, self.options());
ret.scatter_(-1, self.unsqueeze(-1), 1);
return ret;
}

使用示例

	torch::TensorOptions options(torch::kLong);
auto tensor = torch::tensor({ 0,1,2 }, options);
std::cout << tensor << std::endl; try
{
auto one_hot = torch::one_hot(tensor,4);
std::cout << one_hot << std::endl;
}
catch (const c10::Error& watch)
{
std::cout << watch.msg() << std::endl;
}

随机推荐

  1. powerdesigner数据库设计

    (1)创建物理数据模型 打开PowerDesigner,然后点击File-->New  Model然后选择如下图所示的物理数据模型(物理数据模型的名字自己起,然后选择自己所使用的数据库即可) ( ...

  2. selenium下拉菜单

    from selenium.webdriver.support.select import Selectdef select_value(self, css, value):    '''    选中 ...

  3. 【学习】011 JVM参数调优配置

    自动内存管理机制 Java虚拟机原理 所谓虚拟机,就是一台虚拟的机器.他是一款软件,用来执行一系列虚拟计算指令,大体上虚拟机可以分为 系统虚拟机和程序虚拟机, 大名鼎鼎的Visual Box.Vmar ...

  4. AIX系统软件安装问题

    一.安装软件时一定要cd到介质目录中 二.选择accept new licence 三.更新系统时避免使用updata_all,要手动选择出要更新的软件 四.oracle11G的rac还要用到open ...

  5. Java初步

    Java的核心优势:跨平台 Java SE:标准版Java EE:企业级Java ME:微型版 源文件(*.java)→编译器→字节码文件(*.class)→(类装载器→字节码校验器→解释器)[JRE ...

  6. 关于memset

    memset填充的是一个字节,比方下面的一段程序: #include <cstdio> #include <cstring> using namespace std; ]; i ...

  7. OPTIONS请求后台处理 跨域Filter

    import cn.hutool.http.Method; import org.springframework.web.filter.OncePerRequestFilter; import jav ...

  8. Web开发中的服务器跳转与客户端跳转

    两者比较如下: 跳转类型  客户端请求次数 服务端响应次数 URL变化 站外跳转 作用域 服务器跳转 1 1 无 否 pageContext.request.session.application 客 ...

  9. oppo面试题

    1.synchronized和Lock有什么区别?哪个可重入?哪个效率高? synchronized和Lock都用于线程同步的场景中. synchronized是jdk的关键字,用来构造同步代码块或者 ...

  10. vim安装bundle和使用

    一.准备工作 安装Git(因为下面我们选择的插件管理器需要使用到它)安装其他插件前首先需要选择一个Vim插件管理器,我这里选择的是Vundle,Vundle的工作过程中需要通过Git自动从远程创库同步 ...