當前位置 主頁 > 網站技術 > 代碼類 > 最大化 縮小

    pytorch標簽轉onehot形式實例

    欄目:代碼類 時間:2020-01-02 12:04

    代碼:

    import torch
    
    class_num = 10
    batch_size = 4
    label = torch.LongTensor(batch_size, 1).random_() % class_num
    print(label.size())
    
    one_hot = torch.zeros(batch_size, class_num).scatter_(1, label, 1)
    print(one_hot)
    

    輸出:

    torch.Size([4, 1])
    tensor([[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
    [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

    注意:

    label的形狀必須是[n,1]的,也就是必須是二維的,且第二個維度長度為1,如果是一維度的,則需要升維度,代碼如下:

    import torch
    
    class_num = 10
    batch_size = 4
    label = torch.LongTensor(batch_size).random_() % class_num
    print(label.size())
    label = torch.unsqueeze(label,dim=1)
    print(label.size())
    

    以上這篇pytorch標簽轉onehot形式實例就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持IIS7站長之家。

    下一篇:沒有了
青海十一选五开奖数据