The bottom part (a CNN or any other encoder) produces latent vectors, and the top part runs a sequence model such as a GRU or LSTM over them to capture the temporal relations. Later I only need to attach a classification layer for my own task on top of this and fine-tune it.
Below is my analysis of how another codebase implements CPC.
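As a minimal sketch of that idea (my own toy mock-up, not the repository's code; the layer choices, sizes and names here are illustrative assumptions):
import torch
import torch.nn as nn
toy_encoder = nn.Conv1d(12, 512, kernel_size=1)   # "bottom": any encoder that yields latent vectors
toy_ar = nn.GRU(512, 512, batch_first=True)       # "top": sequence model (GRU/LSTM) over the latents
toy_head = nn.Linear(512, 5)                      # attached later for my own classification task
x = torch.randn(4, 12, 1000)                      # bs, channels, time
z = toy_encoder(x).transpose(1, 2)                # bs, time, 512: latent sequence
c, _ = toy_ar(z)                                  # bs, time, 512: context vectors
logits = toy_head(c[:, -1])                       # bs, 5: fine-tuning head on the last context vector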
https://github.com/hhi-aml/ecg-selfsupervised/blob/main/clinical_ts/cpc.py
In [1]:
###############
#generic
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
import torch.nn.functional as F
import torchvision
import os
import argparse
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import copy
#################
#specific
from clinical_ts.timeseries_utils import *
from clinical_ts.ecg_utils import *
from functools import partial
from pathlib import Path
import pandas as pd
import numpy as np
from clinical_ts.xresnet1d import xresnet1d50,xresnet1d101
from clinical_ts.basic_conv1d import weight_init
from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
from clinical_ts.cpc import *
In [2]:
hparams = {
"accumulate": 1,
"batch_size": 32,
"bias": False,
"data": ['./data/cinc', './data/zheng', './data/ribeiro'],
"discriminative_lr_factor": 0.1,
"distributed_backend": None,
"dropout_head": 0.5,
"epochs": 1000,
"executable": "cpc",
"fc_encoder": True,
"finetune": False,
"finetune_dataset": "thew",
"gpus": 1,
"gru": False,
"input_channels": 12,
"input_size": 1000,
"lin_ftrs_head": [512],
"linear_eval": False,
"lr": 0.0001,
"lr_find": False,
"metadata": None,
"mlp": False,
"n_false_negatives": 128,
"n_hidden": 512,
"n_layers": 2,
"negatives_from_same_seq_only": True,
"no_bn_encoder": False,
"no_bn_head": False,
"normalize": True,
"num_nodes": 1,
"optimizer": "adam",
"output_path": "./runs/cpc/all",
"precision": 16,
"pretrained": None,
"resume": None,
"skip_encoder": False,
"steps_predicted": 12,
"train_head_only": False,
"weight_decay": 0.001,
}
class Struct:
def __init__(self, **entries):
self.__dict__.update(entries)
hparams = Struct(**hparams)
# configure dataset params
chunkify_train = False
chunk_length_train = hparams.input_size if chunkify_train else 0
stride_train = hparams.input_size
chunkify_valtest = True
chunk_length_valtest = hparams.input_size if chunkify_valtest else 0
stride_valtest = hparams.input_size//2
train_datasets = []
val_datasets = []
test_datasets = []
for i,target_folder in enumerate(hparams.data):
target_folder = Path(target_folder)
df_mapped, lbl_itos, mean, std = load_dataset(target_folder)
# always use PTB-XL stats
mean = np.array([-0.00184586, -0.00130277, 0.00017031, -0.00091313, -0.00148835, -0.00174687, -0.00077071, -0.00207407, 0.00054329, 0.00155546, -0.00114379, -0.00035649])
std = np.array([0.16401004, 0.1647168 , 0.23374124, 0.33767231, 0.33362807, 0.30583013, 0.2731171 , 0.27554379, 0.17128962, 0.14030828, 0.14606956, 0.14656108])
lbl_itos = lbl_itos
tfms_ptb_xl_cpc = ToTensor() if hparams.normalize is False else transforms.Compose([Normalize(mean,std),ToTensor()])
max_fold_id = df_mapped.strat_fold.max() #unfortunately 1-based for PTB-XL; sometimes 100 (Ribeiro)
df_train = df_mapped[df_mapped.strat_fold<(max_fold_id-1 if hparams.finetune else max_fold_id)]
df_val = df_mapped[df_mapped.strat_fold==(max_fold_id-1 if hparams.finetune else max_fold_id)]
train_datasets.append(TimeseriesDatasetCrops(df_train,hparams.input_size,num_classes=len(lbl_itos),data_folder=target_folder,chunk_length=chunk_length_train,min_chunk_length=hparams.input_size, stride=stride_train,transforms=tfms_ptb_xl_cpc,annotation=False,col_lbl ="label" if hparams.finetune else None,memmap_filename=target_folder/("memmap.npy")))
val_datasets.append(TimeseriesDatasetCrops(df_val,hparams.input_size,num_classes=len(lbl_itos),data_folder=target_folder,chunk_length=chunk_length_valtest,min_chunk_length=hparams.input_size, stride=stride_valtest,transforms=tfms_ptb_xl_cpc,annotation=False,col_lbl ="label" if hparams.finetune else None,memmap_filename=target_folder/("memmap.npy")))
if(hparams.finetune):
test_datasets.append(TimeseriesDatasetCrops(df_test,hparams.input_size,num_classes=len(lbl_itos),data_folder=target_folder,chunk_length=chunk_length_valtest,min_chunk_length=hparams.input_size, stride=stride_valtest,transforms=tfms_ptb_xl_cpc,annotation=False,col_lbl ="label",memmap_filename=target_folder/("memmap.npy")))
print("\n",target_folder)
print("train dataset:",len(train_datasets[-1]),"samples")
print("val dataset:",len(val_datasets[-1]),"samples")
if(hparams.finetune):
print("test dataset:",len(test_datasets[-1]),"samples")
if(len(train_datasets)>1): #multiple data folders
print("\nCombined:")
train_dataset = ConcatDataset(train_datasets)
val_dataset = ConcatDataset(val_datasets)
print("train dataset:",len(train_dataset),"samples")
print("val dataset:",len(val_dataset),"samples")
if(hparams.finetune):
test_dataset = ConcatDataset(test_datasets)
print("test dataset:",len(test_dataset),"samples")
else: #just a single data folder
train_dataset = train_datasets[0]
val_dataset = val_datasets[0]
if(hparams.finetune):
test_dataset = test_datasets[0]
data/cinc
train dataset: 38710 samples
val dataset: 9158 samples
data/zheng
train dataset: 9582 samples
val dataset: 1064 samples
data/ribeiro
train dataset: 261 samples
val dataset: 26 samples
Combined:
train dataset: 48553 samples
val dataset: 10248 samples
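Note the crop settings above: chunk_length_train is 0 (no chunking for training), while val/test are chunkified into windows of input_size=1000 with stride input_size//2=500, i.e. 50% overlap. A tiny sliding-window illustration of that overlap (my own toy numbers, only meant to show the window starts, not the actual TimeseriesDatasetCrops internals):
record_len, chunk_length, stride = 2500, 1000, 500   # toy record length; chunk/stride as in the val/test config
starts = list(range(0, record_len - chunk_length + 1, stride))
print(starts)    # [0, 500, 1000, 1500] -> four half-overlapping 1000-sample crops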
In [3]:
signal, label = train_dataset.__getitem__(7)
In [4]:
signal.size()
Out[4]:
torch.Size([12, 1000])
In [5]:
label
Out[5]:
0.0
In [6]:
data_loader = DataLoader(train_dataset, batch_size=4, num_workers=4, shuffle=True, drop_last = True)
In [7]:
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/17_cpc.ipynb (unless otherwise specified).
__all__ = ['CPCEncoder', 'CPCModel']
# Cell
import torch
import torch.nn.functional as F
import torch.nn as nn
from clinical_ts.basic_conv1d import _conv1d
import numpy as np
# from .basic_conv1d import listify, bn_drop_lin
# Cell
class CPCEncoder(nn.Sequential):
'CPC Encoder'
# strides = [1, 1, 1, 1]
# kss = [1, 1, 1, 1]
# features = [512, 512, 512, 512]
def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=[512,512,512,512],bn=False):
assert(len(strides)==len(kss) and len(strides)==len(features))
lst = []
for i,(s,k,f) in enumerate(zip(strides,kss,features)):
lst.append(_conv1d(input_channels if i==0 else features[i-1],f,kernel_size=k,stride=s,bn=bn))
super().__init__(*lst)
self.downsampling_factor = np.prod(strides)
self.output_dim = features[-1]
# output: bs, output_dim, seq//downsampling_factor
def encode(self, input):
#bs = input.size()[0]
#ch = input.size()[1]
#seq = input.size()[2]
#segments = seq//self.downsampling_factor
#input_encoded = self.forward(input[:,:,:segments*self.downsampling_factor]).transpose(1,2) #bs, seq//downsampling, encoder_output_dim (standard ordering for batch_first RNNs)
print(f"{input.size()=}")
input_encoded = self.forward(input)
print(f"{input_encoded.size()=}")
input_encoded = input_encoded.transpose(1,2)
print(f"{input_encoded.size()=}")
return input_encoded
# Cell
class CPCModel(nn.Module):
"CPC model"
def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=[512,512,512,512],bn_encoder=False, n_hidden=512,n_layers=2,mlp=False,lstm=True,bias_proj=False, num_classes=None, concat_pooling=True, ps_head=0.5,lin_ftrs_head=[512],bn_head=True,skip_encoder=False):
super().__init__()
assert(skip_encoder is False or num_classes is not None)#pretraining only with encoder
self.encoder = CPCEncoder(input_channels,strides=strides,kss=kss,features=features,bn=bn_encoder) if skip_encoder is False else None
self.encoder_output_dim = self.encoder.output_dim if skip_encoder is False else None
self.encoder_downsampling_factor = self.encoder.downsampling_factor if skip_encoder is False else None
self.n_hidden = n_hidden
self.n_layers = n_layers
self.mlp = mlp
self.num_classes = num_classes
self.concat_pooling = concat_pooling
self.rnn = nn.LSTM(self.encoder_output_dim if skip_encoder is False else input_channels,n_hidden,num_layers=n_layers,batch_first=True) if lstm is True else nn.GRU(self.encoder.output_dim,n_hidden,num_layers=n_layers,batch_first=True)
if(num_classes is None): #pretraining
if(mlp):# additional hidden layer as in simclr
self.proj = nn.Sequential(nn.Linear(n_hidden, n_hidden),nn.ReLU(inplace=True),nn.Linear(n_hidden, self.encoder_output_dim,bias=bias_proj))
else:
self.proj = nn.Linear(n_hidden, self.encoder_output_dim,bias=bias_proj)
def forward(self, input):
# input shape bs,ch,seq
if(self.encoder is not None):
input_encoded = self.encoder.encode(input)
else:
input_encoded = input.transpose(1,2) #bs, seq, channels
output_rnn, _ = self.rnn(input_encoded) #output_rnn: bs, seq, n_hidden
if(self.num_classes is None):#pretraining
print(f"{input_encoded.size()=}")
print(f"{self.rnn=}")
print(f"{output_rnn.size()=}")
print(f"{self.proj=}")
print(f"{self.proj(output_rnn).size()=}")
return input_encoded, self.proj(output_rnn)
else:#classifier
output = output_rnn.transpose(1,2)#bs,n_hidden,seq (i.e. standard CNN channel ordering)
if(self.concat_pooling is False):
output = output[:,:,-1]
return self.head(output)
def get_layer_groups(self):
return (self.encoder,self.rnn,self.head)
def get_output_layer(self):
return self.head[-1]
def set_output_layer(self,x):
self.head[-1] = x
# step_predicted=12
def cpc_loss(self,input, target=None, steps_predicted=5, n_false_negatives=9, negatives_from_same_seq_only=False, eval_acc=False):
assert(self.num_classes is None)
input_encoded, output = self.forward(input) #input_encoded: bs, seq, features; output: bs,seq,features
input_encoded_flat = input_encoded.reshape(-1,input_encoded.size(2)) #for negatives below: -1, features
print(f"{input_encoded_flat.size()=}")
bs = input_encoded.size()[0]
seq = input_encoded.size()[1]
loss = torch.tensor(0,dtype=torch.float32).to(input.device)
tp_cnt = torch.tensor(0,dtype=torch.int64).to(input.device)
# steps_predicted = 12
print(f"{(input_encoded.size()[1]-steps_predicted)=}")
for i in range(input_encoded.size()[1]-steps_predicted):
# input_encoded.size() = (bs, seq, channels(vectors))
print(f"{input_encoded[:,i+steps_predicted].size()=}")
positives = input_encoded[:,i+steps_predicted].unsqueeze(1) #bs,1,encoder_output_dim
print(f"{positives.size()=}")
print(f"{negatives_from_same_seq_only=}")
if(negatives_from_same_seq_only): # True
# seq = 1000
# n_false_negatives = 128
idxs = torch.randint(0,(seq-1),(bs*n_false_negatives,)).to(input.device)
else:#negative from everywhere
idxs = torch.randint(0,bs*(seq-1),(bs*n_false_negatives,)).to(input.device)
# print(f"{idxs=}")
# print(f"{idxs.size()=}")
idxs_seq = torch.remainder(idxs,seq-1) #bs*false_neg
# print(f"{idxs_seq=}")
idxs_seq2 = idxs_seq * (idxs_seq<(i+steps_predicted)).long() +(idxs_seq+1)*(idxs_seq>=(i+steps_predicted)).long()#bs*false_neg
# print(f"{idxs_seq2=}")
if(negatives_from_same_seq_only):
idxs_batch = torch.arange(0,bs).repeat_interleave(n_false_negatives).to(input.device)
# print(f"{idxs_batch=}")
# print(f"{idxs_batch.size()=}")
else:
idxs_batch = idxs//(seq-1)
idxs2_flat = idxs_batch*seq+idxs_seq2 #for negatives from everywhere: this skips step i+steps_predicted from the other sequences as well for simplicity
print(f"{idxs2_flat=}")
print(f"{idxs2_flat.size()=}")
negatives = input_encoded_flat[idxs2_flat].view(bs,n_false_negatives,-1) #bs*false_neg, encoder_output_dim
print(f"{negatives.size()=}")
candidates = torch.cat([positives,negatives],dim=1)#bs,1+false_neg,encoder_output_dim
print(f"{candidates.size()=}")
print(f"{output[:,i].size()=}") # (bs, features)
print(f"{output[:,i].unsqueeze(1).size()=}")# (bs, 1, features)
print(f"{(output[:,i].unsqueeze(1)*candidates).size()=}") # (bs, (1+false_neg), features)
preds=torch.sum(output[:,i].unsqueeze(1)*candidates,dim=-1) #bs,(1+false_neg) # the sum over the feature dim is a dot product between the context projection and each candidate, i.e. the InfoNCE similarity score (not a pooling op)
print(f"{preds=}")
print(f"{preds.size()=}")
targs = torch.zeros(bs, dtype=torch.int64).to(input.device) # the positive is always at index 0 of candidates, so the target class is 0
print(f"{targs=}")
print(f"{targs.size()=}")
if(eval_acc):
preds_argmax = torch.argmax(preds,dim=-1)
tp_cnt += torch.sum(preds_argmax == targs)
loss += F.cross_entropy(preds,targs)
print(f"{loss=}")
if(eval_acc):
return loss, tp_cnt.float()/bs/(input_encoded.size()[1]-steps_predicted)
else:
return loss
#copied from RNN1d
class AdaptiveConcatPoolRNN(nn.Module):
def __init__(self, bidirectional=False):
super().__init__()
self.bidirectional = bidirectional
def forward(self,x):
#input shape bs, ch, ts
t1 = nn.AdaptiveAvgPool1d(1)(x)
t2 = nn.AdaptiveMaxPool1d(1)(x)
if(self.bidirectional is False):
t3 = x[:,:,-1]
else:
channels = x.size()[1]
t3 = torch.cat([x[:,:channels,-1],x[:,channels:,0]],1)
out=torch.cat([t1.squeeze(-1),t2.squeeze(-1),t3],1) #output shape bs, 3*ch
return out
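The index arithmetic with idxs_seq2 exists so that a randomly drawn negative position can never coincide with the positive position i+steps_predicted: draws come from [0, seq-1), and anything at or above the positive index is shifted up by one. A standalone check of that shift with toy numbers (same expressions as in cpc_loss):
import torch
seq, pos = 10, 4                     # toy sequence length and positive position (i + steps_predicted)
idxs_seq = torch.arange(seq - 1)     # every possible draw in [0, seq-1)
idxs_seq2 = idxs_seq * (idxs_seq < pos).long() + (idxs_seq + 1) * (idxs_seq >= pos).long()
print(idxs_seq2.tolist())            # [0, 1, 2, 3, 5, 6, 7, 8, 9] -> the positive index 4 is never sampled
The dot products in preds then score the true future latent (index 0 of candidates) against the 128 negatives, and cross_entropy with target 0 is exactly the InfoNCE loss.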
In [8]:
strides=[1]*4
kss = [1]*4
features = [512]*4
model_cpc = CPCModel(
input_channels=hparams.input_channels, # 12
strides=strides, # [1,1,1,1]
kss=kss, # [1,1,1,1]
features=features, # [512]*4
n_hidden=hparams.n_hidden, # 512
n_layers=hparams.n_layers, # 2
mlp=hparams.mlp, #
lstm=not(hparams.gru),
bias_proj=hparams.bias,
num_classes=None,
skip_encoder=hparams.skip_encoder,
bn_encoder=not(hparams.no_bn_encoder),
# lin_ftrs_head=[] if hparams.linear_eval else eval(hparams.lin_ftrs_head),
lin_ftrs_head=[] if hparams.linear_eval else [512],
ps_head=0 if hparams.linear_eval else hparams.dropout_head,
bn_head=False if hparams.linear_eval else not(hparams.no_bn_head)
)
In [9]:
# torch.load('runs/cpc/all/version_5/best_model.ckpt').keys()
In [10]:
# model_cpc = model_cpc.load_state_dict(torch.load('runs/cpc/all/version_5/best_model.ckpt')['state_dict'])
In [11]:
data_batch = next(iter(data_loader))
print(len(data_batch))
2
In [12]:
data_batch[0].size()
Out[12]:
torch.Size([4, 12, 1000])
In [13]:
data_batch[1]
Out[13]:
tensor([0., 0., 0., 0.], dtype=torch.float64)
In [14]:
loss, acc = model_cpc.cpc_loss(data_batch[0],steps_predicted=997,n_false_negatives=hparams.n_false_negatives, negatives_from_same_seq_only=hparams.negatives_from_same_seq_only, eval_acc=True)
input.size()=torch.Size([4, 12, 1000])
input_encoded.size()=torch.Size([4, 512, 1000])
input_encoded.size()=torch.Size([4, 1000, 512])
input_encoded.size()=torch.Size([4, 1000, 512])
self.rnn=LSTM(512, 512, num_layers=2, batch_first=True)
output_rnn.size()=torch.Size([4, 1000, 512])
self.proj=Linear(in_features=512, out_features=512, bias=False)
self.proj(output_rnn).size()=torch.Size([4, 1000, 512])
input_encoded_flat.size()=torch.Size([4000, 512])
(input_encoded.size()[1]-steps_predicted)=3
input_encoded[:,i+steps_predicted].size()=torch.Size([4, 512])
positives.size()=torch.Size([4, 1, 512])
negatives_from_same_seq_only=True
idxs2_flat=tensor([ 799, 554, 251, 445, 652, 289, 14, 691, 946, 976, 147, 477,
159, 462, 307, 286, 567, 884, 519, 760, 271, 180, 83, 182,
825, 266, 1, 408, 115, 602, 994, 555, 482, 401, 226, 671,
326, 206, 291, 951, 425, 776, 517, 78, 415, 131, 203, 226,
29, 425, 39, 750, 806, 262, 990, 739, 747, 149, 172, 957,
297, 364, 633, 974, 465, 516, 758, 11, 818, 197, 188, 774,
648, 550, 998, 222, 828, 450, 655, 819, 101, 549, 401, 75,
611, 193, 745, 277, 61, 323, 936, 437, 670, 922, 23, 298,
729, 277, 953, 42, 793, 395, 151, 794, 448, 744, 327, 423,
779, 292, 185, 783, 437, 579, 593, 27, 255, 316, 589, 452,
70, 51, 672, 101, 302, 240, 925, 732, 1955, 1335, 1236, 1248,
1664, 1319, 1566, 1411, 1195, 1685, 1070, 1763, 1470, 1943, 1013, 1065,
1466, 1845, 1198, 1991, 1325, 1311, 1439, 1630, 1741, 1750, 1725, 1324,
1983, 1198, 1926, 1328, 1987, 1521, 1176, 1770, 1287, 1466, 1216, 1264,
1038, 1771, 1083, 1062, 1516, 1064, 1810, 1465, 1803, 1911, 1366, 1729,
1477, 1239, 1008, 1029, 1700, 1812, 1225, 1103, 1061, 1363, 1224, 1059,
1636, 1907, 1320, 1877, 1194, 1277, 1394, 1062, 1034, 1795, 1183, 1635,
1288, 1236, 1397, 1131, 1400, 1249, 1970, 1153, 1620, 1307, 1424, 1982,
1209, 1800, 1341, 1218, 1691, 1046, 1408, 1683, 1237, 1612, 1561, 1001,
1919, 1168, 1516, 1543, 1245, 1515, 1479, 1917, 1001, 1421, 1835, 1404,
1759, 1327, 1371, 1397, 1468, 1566, 1874, 1256, 1226, 1154, 1839, 1895,
1826, 1192, 1562, 1600, 2194, 2599, 2613, 2471, 2536, 2234, 2745, 2489,
2649, 2250, 2455, 2412, 2072, 2873, 2599, 2875, 2089, 2349, 2322, 2414,
2650, 2825, 2005, 2379, 2952, 2695, 2665, 2431, 2163, 2408, 2016, 2105,
2615, 2809, 2706, 2652, 2519, 2497, 2157, 2676, 2982, 2795, 2796, 2684,
2998, 2975, 2940, 2941, 2344, 2324, 2059, 2927, 2665, 2747, 2357, 2723,
2970, 2670, 2539, 2914, 2074, 2840, 2167, 2226, 2794, 2086, 2548, 2555,
2910, 2082, 2319, 2088, 2911, 2481, 2530, 2014, 2503, 2729, 2558, 2072,
2360, 2031, 2928, 2001, 2221, 2805, 2998, 2571, 2158, 2304, 2063, 2667,
2160, 2129, 2896, 2002, 2872, 2228, 2656, 2143, 2770, 2984, 2061, 2656,
2751, 2988, 2083, 2367, 2586, 2273, 2026, 2116, 2507, 2122, 2550, 2078,
2975, 2249, 2446, 2788, 2924, 2628, 2821, 2039, 2537, 2146, 2740, 2739,
3812, 3666, 3193, 3595, 3483, 3269, 3689, 3242, 3803, 3739, 3739, 3840,
3860, 3492, 3940, 3025, 3388, 3921, 3910, 3208, 3334, 3616, 3413, 3293,
3034, 3012, 3658, 3457, 3632, 3015, 3517, 3219, 3988, 3932, 3142, 3636,
3391, 3605, 3343, 3627, 3777, 3692, 3117, 3710, 3515, 3846, 3395, 3126,
3624, 3220, 3314, 3801, 3133, 3270, 3212, 3642, 3368, 3567, 3067, 3968,
3176, 3824, 3075, 3152, 3043, 3054, 3576, 3169, 3295, 3639, 3038, 3051,
3785, 3091, 3652, 3336, 3368, 3174, 3702, 3899, 3290, 3220, 3939, 3293,
3886, 3742, 3062, 3138, 3860, 3909, 3092, 3259, 3085, 3833, 3333, 3752,
3780, 3439, 3710, 3514, 3797, 3020, 3003, 3839, 3101, 3654, 3527, 3473,
3322, 3261, 3099, 3864, 3921, 3519, 3681, 3270, 3523, 3180, 3964, 3317,
3735, 3612, 3380, 3233, 3912, 3597, 3149, 3906])
idxs2_flat.size()=torch.Size([512])
negatives.size()=torch.Size([4, 128, 512])
candidates.size()=torch.Size([4, 129, 512])
output[:,i].size()=torch.Size([4, 512])
output[:,i].unsqueeze(1).size()=torch.Size([4, 1, 512])
(output[:,i].unsqueeze(1)*candidates).size()=torch.Size([4, 129, 512])
preds=tensor([[-3.6536e-03, 1.4450e-01, 3.2580e-01, -1.2079e-02, -5.1958e-02,
-3.4181e-02, -1.0701e-01, -1.2421e-01, -7.5189e-02, -4.0776e-02,
-3.7895e-02, -4.8495e-02, 2.9928e-01, -2.5149e-02, -4.5315e-02,
-7.6321e-02, -5.6055e-02, -4.2787e-02, -3.3802e-02, -3.2608e-02,
-1.4660e-01, -4.4604e-02, -4.8901e-02, -7.5359e-02, -6.4136e-02,
-7.6762e-02, -2.5051e-02, -5.1359e-02, -5.0701e-02, 9.1746e-02,
-5.4879e-02, -2.3431e-03, -9.2638e-03, 2.0610e-02, -7.0630e-02,
-4.4089e-02, -3.1877e-02, -1.4184e-01, -4.9470e-02, -6.9476e-02,
-5.3928e-02, 3.3361e-02, 9.0234e-02, -6.0236e-02, -9.3135e-02,
-6.7241e-02, -5.8956e-02, -5.2465e-02, -4.4089e-02, -9.4225e-02,
3.3361e-02, -3.8123e-02, -1.3814e-01, -4.7204e-02, 3.3017e-01,
1.3148e-01, -9.0427e-02, -1.3685e-01, -7.9107e-02, -1.1649e-01,
-2.5576e-02, -8.3871e-02, -5.1497e-02, -5.4495e-02, -3.6224e-02,
-3.2053e-02, -4.0254e-02, -1.2547e-01, -1.1670e-01, -7.1497e-02,
-5.4785e-02, 2.7132e-01, 1.4350e-01, -5.1522e-02, -3.6274e-02,
3.8014e-03, -4.0404e-02, -4.3578e-02, -5.3691e-02, -5.6413e-02,
-5.3787e-02, -2.7803e-02, -5.1253e-02, -7.0630e-02, -3.9933e-02,
-2.0780e-02, -5.9637e-02, -1.3156e-01, -6.8665e-02, -8.8468e-02,
-1.4388e-01, 2.7811e-02, -7.6181e-02, -2.5556e-02, -5.1981e-02,
-7.5509e-02, -8.6344e-02, -1.3728e-01, -6.8665e-02, -5.1848e-02,
-5.9756e-02, 1.0752e-01, -1.1970e-01, -6.6722e-02, 1.7014e-01,
-4.5494e-02, -1.2248e-01, -1.4083e-01, -1.9173e-02, 8.0392e-02,
-9.3972e-02, 3.5300e-02, 3.8736e-02, -7.6181e-02, -1.3065e-02,
-3.4709e-02, -9.0272e-02, -3.7138e-02, -4.8101e-02, -3.8957e-02,
-5.4228e-02, 3.6140e-03, -6.8414e-02, -2.3627e-02, -2.7803e-02,
-5.1850e-02, -6.3502e-02, -5.9725e-02, -9.9340e-02],
[ 1.7624e-01, -1.3059e-01, -8.7348e-02, -4.7866e-02, -2.4590e-02,
1.0858e-02, 2.3725e-02, 5.3881e-02, 8.5260e-02, -7.1399e-02,
-3.5656e-02, 1.2744e-01, 2.8234e-03, -4.7538e-02, -2.0207e-01,
-2.8529e-01, -2.2159e-01, -2.4506e-01, 2.9820e-02, -7.5603e-02,
1.7624e-01, -3.7772e-02, -1.2123e-01, -5.3504e-02, 5.1153e-02,
-1.2117e-01, -5.9872e-02, -2.7079e-03, -4.4610e-02, -1.4857e-01,
-7.5603e-02, 1.1125e-02, -5.4885e-02, -2.4311e-01, -7.4841e-04,
-1.2108e-02, -2.5994e-02, 2.6492e-02, -2.4506e-01, -7.9725e-02,
-4.4852e-02, 1.3568e-01, -5.3335e-02, -1.4822e-01, -2.1545e-01,
2.0198e-04, -2.2730e-01, 1.7395e-02, 1.9353e-02, -1.0326e-01,
-1.0364e-01, -2.9768e-02, 9.3020e-02, 2.0693e-03, 2.5560e-03,
-3.1599e-01, 2.6013e-02, 2.3654e-02, 8.7851e-03, -1.2383e-01,
-1.5334e-01, -2.2765e-01, -6.6796e-02, -9.1513e-02, -2.7024e-01,
4.1978e-02, -1.4797e-01, 1.6646e-02, -6.4321e-02, -7.5163e-02,
-8.8880e-02, -7.6017e-02, -2.1545e-01, -4.4064e-02, -7.4472e-02,
-6.7997e-02, 4.6951e-02, 3.8682e-02, -4.7866e-02, -9.0491e-02,
-1.2591e-01, -1.4598e-02, -6.2480e-02, -1.7548e-03, -1.7305e-01,
1.1648e-01, -1.0278e-01, -1.2791e-01, -1.4794e-01, 4.7801e-02,
-7.2656e-02, -5.3352e-02, -1.2165e-01, -6.7716e-03, -3.0831e-01,
1.1530e-02, -4.0602e-02, 1.8301e-02, 2.4349e-02, 1.3448e-02,
-7.5951e-02, 6.2158e-02, 1.5752e-03, 2.0198e-04, 1.9729e-02,
-4.5772e-02, 2.1445e-02, -2.0317e-02, -1.8548e-02, -7.5951e-02,
-1.0423e-01, -9.8751e-02, -4.9004e-02, -7.0480e-02, -1.9040e-02,
2.0697e-02, -9.0491e-02, 3.5147e-03, 5.3881e-02, -5.2825e-02,
-6.2842e-02, -1.2910e-01, -1.9723e-01, 3.1518e-02, -1.7122e-01,
-3.0749e-02, -7.2336e-02, 2.3687e-02, -6.1004e-03],
[-6.8570e-02, -6.1383e-02, -1.1235e-01, -1.1991e-01, -1.3994e-01,
-9.0237e-02, -7.0384e-02, 1.3279e-01, -1.2270e-01, -7.3408e-02,
-6.3124e-02, -1.1114e-01, -7.4025e-02, -9.5269e-02, -8.8338e-02,
-1.1235e-01, -8.7805e-02, -3.3563e-02, -7.9831e-02, -4.5363e-02,
-6.4634e-02, -7.3719e-02, -7.6706e-02, -7.2522e-02, -9.6519e-02,
-1.1483e-01, -8.0694e-02, -7.8023e-01, -7.6768e-02, -9.6147e-02,
-8.4328e-02, 3.0979e-02, -8.3418e-02, -6.1058e-02, -7.3485e-02,
-3.5823e-02, -9.6997e-02, -5.4127e-01, -8.1689e-02, -5.3362e-01,
-8.6622e-02, -1.0318e-01, -3.5736e-02, -3.3684e-02, -1.1802e-01,
-6.8570e-02, -8.5535e-02, -1.0944e-01, -1.1066e-01, -9.1730e-02,
-8.9758e-02, -5.5508e-02, -9.6292e-02, -7.8023e-01, -7.9765e-02,
-5.8191e-02, -6.4801e-02, 3.3194e-01, 2.1964e-01, -9.2463e-02,
-9.2304e-02, -6.1545e-02, -8.6978e-02, -9.5697e-02, -1.4012e-01,
-5.0366e-02, 5.1657e-01, -6.7711e-02, -6.9339e-02, -1.5025e-01,
-8.2274e-02, -7.7392e-02, 1.5289e-01, -1.4626e-01, -8.9064e-02,
-1.2899e-01, 6.9455e-01, -1.0927e-01, -9.5855e-02, -4.5140e-02,
-9.5269e-02, -1.2919e-01, -1.0944e-01, -9.1273e-02, -6.1939e-02,
-8.8301e-02, -5.8994e-02, -6.8570e-02, -9.8734e-02, 3.5827e-01,
4.2291e-02, -4.7908e-02, -6.2591e-01, 1.2087e-01, -9.4031e-02,
-6.0979e-02, -6.0127e-02, -3.2055e-02, -8.2287e-01, -1.1063e-01,
-1.2433e-01, -9.0731e-02, -5.9011e-02, -5.7718e-02, -1.1063e-01,
-6.8581e-02, -5.6325e-02, -9.4389e-01, -7.4881e-02, -8.9221e-02,
-7.1893e-02, -8.7986e-02, -5.4521e-02, -1.1674e-01, -4.6947e-02,
-8.7879e-02, -3.6110e-02, -8.5535e-02, -6.3261e-02, -5.5483e-01,
-6.6856e-02, -8.7085e-02, -5.4298e-02, -6.7173e-02, -1.3790e-01,
-9.6285e-02, -9.0006e-02, -8.4264e-01, -2.9675e-01],
[-2.4108e-02, -7.9086e-02, -3.6460e-02, -6.3839e-02, -3.0928e-02,
-8.9724e-03, -2.6177e-02, -3.1920e-02, -9.4825e-02, 6.2595e-02,
2.9915e-02, 2.9915e-02, -1.1907e-01, -3.6858e-02, -5.3699e-02,
-8.8650e-02, 1.4783e-04, -9.5215e-02, -3.6492e-02, -8.6401e-02,
-4.8725e-02, -4.4830e-02, -7.1997e-02, 3.6180e-02, -8.1361e-02,
-6.6032e-02, -1.2119e-01, -3.5716e-02, -9.7922e-02, -1.2326e-01,
-3.4647e-01, -1.4996e-01, -1.4452e-01, -2.7238e-02, 2.3847e-02,
-5.4275e-02, -8.2306e-02, -6.4857e-02, -2.5216e-02, 4.2553e-02,
-1.4775e-02, -1.5899e-01, -6.6600e-01, -1.2882e-01, -1.4031e-01,
-1.4704e-01, -9.7343e-02, -2.6937e-02, -3.0223e-02, -4.6484e-01,
-1.5042e-01, -1.2248e-01, -1.5356e-02, -4.3862e-02, -4.7250e-02,
9.4548e-02, -1.1353e-01, -1.0860e-01, -1.2842e-01, -1.2798e-01,
-1.6791e-01, -9.7986e-02, -5.7001e-02, 1.2336e-01, -7.3119e-02,
1.2989e-03, -3.9045e-02, -1.0525e-01, -1.4321e-01, -6.6856e-01,
-9.6104e-02, -2.0837e-01, -2.8935e-02, -7.6138e-02, -1.1124e-01,
-9.9079e-02, -6.0649e-02, -1.0860e-01, -8.5095e-02, -1.2208e-01,
-1.1213e-01, -1.1649e-01, -1.5042e-01, -1.1434e-01, -8.1361e-02,
9.2121e-02, 1.4233e-02, 1.5692e-02, -5.8243e-02, -3.6858e-02,
-1.0359e-01, -3.3828e-01, -5.6688e-02, -6.9825e-02, -1.0482e-01,
-3.1501e-02, -7.7254e-02, -1.5735e-01, -8.4784e-02, -1.4031e-01,
-7.1494e-02, -5.5661e-02, -1.1651e-01, -4.2644e-02, -9.8933e-02,
-1.0171e-01, -8.2601e-02, -6.4309e-02, -5.2385e-02, -7.2599e-02,
-3.9481e-02, -3.3239e-02, -1.6509e-02, -3.6492e-02, -1.5517e-01,
-5.1274e-02, -4.7250e-02, -8.9897e-02, -1.3966e-01, -9.4581e-02,
-1.5523e-01, -4.0755e-02, -3.2069e-02, -1.3164e-01, -1.4856e-01,
-7.2113e-02, -5.6699e-02, -1.2483e-02, -1.5517e-01]],
grad_fn=<SumBackward1>)
preds.size()=torch.Size([4, 129])
targs=tensor([0, 0, 0, 0])
targs.size()=torch.Size([4])
loss=tensor(4.7787, grad_fn=<AddBackward0>)
input_encoded[:,i+steps_predicted].size()=torch.Size([4, 512])
positives.size()=torch.Size([4, 1, 512])
negatives_from_same_seq_only=True
idxs2_flat=tensor([ 104, 928, 75, 693, 85, 804, 796, 914, 487, 342, 478, 360,
509, 681, 544, 305, 7, 757, 860, 995, 201, 726, 459, 707,
56, 2, 686, 639, 906, 608, 978, 216, 202, 682, 145, 77,
910, 677, 159, 741, 599, 339, 769, 566, 870, 896, 371, 108,
319, 271, 508, 32, 816, 524, 245, 279, 146, 704, 436, 305,
356, 408, 591, 825, 756, 730, 800, 474, 38, 907, 136, 666,
492, 235, 727, 577, 814, 241, 368, 574, 527, 85, 226, 957,
390, 115, 730, 767, 885, 78, 369, 630, 382, 754, 2, 138,
38, 755, 330, 787, 874, 210, 129, 785, 252, 789, 268, 775,
229, 814, 88, 157, 188, 343, 83, 12, 488, 922, 326, 600,
97, 410, 16, 732, 222, 460, 593, 746, 1404, 1847, 1185, 1991,
1765, 1698, 1256, 1778, 1980, 1269, 1275, 1557, 1487, 1586, 1446, 1534,
1667, 1583, 1865, 1913, 1453, 1341, 1384, 1573, 1617, 1816, 1013, 1366,
1201, 1477, 1561, 1354, 1722, 1541, 1396, 1526, 1457, 1997, 1751, 1372,
1996, 1534, 1902, 1579, 1862, 1835, 1795, 1510, 1002, 1258, 1888, 1009,
1154, 1959, 1441, 1911, 1906, 1788, 1641, 1234, 1534, 1768, 1883, 1544,
1898, 1092, 1820, 1923, 1204, 1999, 1715, 1743, 1814, 1075, 1675, 1750,
1451, 1571, 1443, 1041, 1855, 1676, 1736, 1585, 1069, 1865, 1720, 1268,
1512, 1909, 1915, 1760, 1940, 1641, 1283, 1477, 1688, 1513, 1411, 1000,
1264, 1396, 1196, 1392, 1507, 1675, 1626, 1936, 1843, 1647, 1215, 1780,
1988, 1818, 1683, 1168, 1489, 1926, 1771, 1195, 1380, 1035, 1011, 1820,
1077, 1830, 1004, 1514, 2876, 2702, 2291, 2009, 2086, 2947, 2797, 2836,
2291, 2157, 2300, 2477, 2301, 2288, 2094, 2150, 2624, 2979, 2625, 2753,
2746, 2046, 2411, 2014, 2413, 2465, 2757, 2934, 2572, 2023, 2380, 2061,
2353, 2012, 2649, 2481, 2691, 2118, 2062, 2633, 2283, 2370, 2560, 2575,
2443, 2459, 2408, 2417, 2606, 2198, 2256, 2896, 2726, 2455, 2898, 2949,
2721, 2131, 2806, 2636, 2077, 2084, 2469, 2477, 2498, 2423, 2328, 2600,
2215, 2406, 2194, 2290, 2694, 2753, 2872, 2818, 2867, 2474, 2957, 2613,
2344, 2578, 2477, 2450, 2448, 2495, 2347, 2070, 2792, 2397, 2356, 2946,
2708, 2697, 2063, 2214, 2444, 2553, 2617, 2499, 2114, 2843, 2018, 2458,
2235, 2115, 2696, 2502, 2742, 2979, 2158, 2109, 2399, 2162, 2886, 2028,
2518, 2513, 2986, 2056, 2405, 2320, 2218, 2109, 2784, 2285, 2315, 2745,
3967, 3032, 3468, 3544, 3377, 3343, 3857, 3279, 3372, 3022, 3346, 3644,
3369, 3806, 3484, 3861, 3437, 3303, 3202, 3080, 3406, 3247, 3992, 3809,
3816, 3185, 3830, 3882, 3107, 3345, 3921, 3525, 3443, 3815, 3018, 3578,
3163, 3725, 3079, 3235, 3631, 3756, 3665, 3301, 3650, 3810, 3420, 3132,
3231, 3435, 3874, 3565, 3136, 3298, 3916, 3620, 3535, 3562, 3408, 3206,
3500, 3462, 3896, 3825, 3203, 3172, 3860, 3292, 3721, 3132, 3335, 3508,
3196, 3098, 3878, 3510, 3978, 3920, 3210, 3521, 3940, 3339, 3394, 3326,
3236, 3181, 3475, 3346, 3528, 3036, 3464, 3416, 3363, 3386, 3727, 3981,
3509, 3539, 3881, 3439, 3671, 3553, 3331, 3473, 3725, 3534, 3845, 3611,
3845, 3448, 3880, 3112, 3904, 3529, 3481, 3992, 3039, 3657, 3339, 3338,
3188, 3139, 3442, 3368, 3046, 3717, 3462, 3184])
idxs2_flat.size()=torch.Size([512])
negatives.size()=torch.Size([4, 128, 512])
candidates.size()=torch.Size([4, 129, 512])
output[:,i].size()=torch.Size([4, 512])
output[:,i].unsqueeze(1).size()=torch.Size([4, 1, 512])
(output[:,i].unsqueeze(1)*candidates).size()=torch.Size([4, 129, 512])
preds=tensor([[-1.1513e-02, -8.6167e-02, -1.0890e-01, -7.5679e-02, -1.0873e-01,
-1.4730e-01, -1.0625e-01, 2.0727e-01, -1.0376e-01, -1.3580e-02,
-3.5122e-02, 1.4486e-01, -7.2247e-02, -5.5162e-02, -1.0776e-01,
-6.7901e-02, -9.9624e-02, -1.5217e-01, -2.1239e-01, -1.6302e-01,
1.2403e-02, -4.1251e-02, -2.5332e-01, -8.8969e-02, -9.6065e-02,
-1.2475e-01, -8.1485e-02, -1.2415e-01, -1.2516e-01, -1.1837e-01,
2.8465e-02, -5.5470e-02, -1.1176e-01, -2.2623e-02, -1.4183e-01,
-8.8291e-02, -1.9290e-01, -1.1867e-01, -1.7056e-02, -5.1928e-02,
-2.9614e-01, -4.9800e-02, -3.5910e-02, -4.4615e-01, -8.4373e-02,
-1.1926e-01, -7.7696e-02, -9.3673e-02, -1.1397e-01, -7.5079e-02,
-6.2685e-02, -1.1016e-01, -1.0275e-01, -1.3273e-01, -9.1142e-02,
-1.4539e-01, -3.8847e-02, -8.6682e-02, -8.2226e-02, -1.1099e-01,
-9.9624e-02, -8.8786e-03, -7.9937e-02, -6.8090e-02, -1.3568e-01,
-2.3339e-01, -2.8208e-01, -1.6530e-02, -4.3057e-02, -8.7352e-02,
-1.1089e-01, -4.6403e-02, -6.7427e-02, -7.4669e-03, -1.1471e-01,
-2.7847e-01, -1.9692e-01, -7.9592e-02, -1.4196e-01, -1.3985e-01,
-5.0735e-02, -6.1251e-02, -1.4730e-01, -6.6720e-02, -4.5161e-02,
-1.1985e-01, -1.2345e-01, -2.8208e-01, -3.8452e-01, -5.7506e-02,
-1.6227e-01, -1.5198e-01, -6.9328e-02, -6.8196e-02, -4.5485e-02,
-8.1485e-02, -4.3068e-03, -8.7352e-02, -1.2140e-01, -5.1733e-02,
5.0057e-02, -8.6353e-02, 1.4219e-02, -8.9982e-02, 7.9097e-02,
-4.5997e-02, -1.6027e-02, -3.8515e-02, 1.7340e-01, -5.5125e-02,
-7.9592e-02, -1.4466e-01, -1.0149e-01, 1.1252e-01, -3.6295e-02,
-1.5760e-01, -1.6731e-01, -1.4552e-02, -7.7709e-02, -2.0301e-01,
-6.7301e-02, -5.1425e-02, -6.9566e-02, -1.2575e-01, -2.2129e-01,
-6.4946e-02, -9.5417e-02, -6.7476e-02, -3.3347e-01],
[ 3.0364e-01, -8.3328e-02, -3.2884e-02, -1.2567e-01, 3.0364e-01,
-7.4426e-02, 2.3880e-02, -1.4378e-01, -1.4338e-01, -2.0768e-01,
-1.1419e-01, -1.4276e-01, 4.1219e-02, 1.6319e-01, 3.0400e-02,
1.5333e-01, -4.9437e-03, 9.0514e-03, 1.5836e-01, -1.3630e-01,
-2.5144e-01, -1.0259e-01, -1.4091e-01, -5.3779e-02, 2.3185e-02,
2.2901e-01, -1.5385e-01, -5.0554e-01, -6.9228e-02, -1.1563e-01,
-6.7641e-03, 1.9915e-02, -1.4508e-01, 5.9981e-03, 2.4717e-02,
-1.2248e-01, 9.7136e-03, -6.5705e-02, 3.0364e-01, -1.6048e-01,
-1.7254e-02, 3.0364e-01, -4.9437e-03, -5.0349e-02, 7.4416e-02,
-1.6681e-01, -1.7062e-01, -1.2926e-01, -1.2686e-02, 2.4447e-01,
-1.3188e-01, -3.7998e-02, -5.2237e-01, -2.8906e-01, -2.3531e-01,
-9.9910e-02, -1.7333e-01, -1.9266e-01, -1.4443e-01, 6.8102e-02,
-1.6504e-01, -4.9437e-03, 2.6161e-02, -1.5675e-01, 5.3302e-02,
-3.1986e-01, -3.1421e-01, -2.3627e-01, 1.7002e-01, -7.3490e-02,
3.0364e-01, -6.2855e-02, 1.1785e-01, 1.3357e-02, 1.3462e-01,
-1.1021e-01, -1.6862e-01, 1.2906e-01, 2.1103e-02, -5.5553e-02,
-3.1262e-01, -8.9996e-02, -9.1304e-02, -7.6217e-02, 1.1610e-01,
1.0058e-02, -1.3630e-01, 4.8382e-02, -1.1007e-01, -9.8527e-03,
-2.2991e-01, -1.6476e-01, -7.4514e-02, -9.9630e-02, 6.8102e-02,
-1.4631e-01, -6.7641e-03, 6.8031e-02, 9.0185e-03, 1.7174e-01,
-6.6848e-01, -1.1663e-01, -1.2248e-01, -1.0850e-01, -1.3938e-01,
-1.6724e-02, -1.1021e-01, 3.7363e-02, 2.1164e-01, 1.4548e-01,
2.7142e-02, -1.3304e-01, -1.4454e-01, 5.1777e-01, -1.8120e-01,
-6.3477e-02, 7.5929e-02, -3.0569e-02, 2.8187e-04, -7.0425e-02,
-1.1106e-01, -2.5584e-01, -4.6909e-02, -4.5689e-01, -2.3627e-01,
-2.8487e-01, -1.6632e-01, -3.3376e-01, 3.0196e-02],
[-1.4563e-01, -1.3328e-01, -7.9832e-02, -8.2984e-02, -1.5971e-01,
7.7268e-01, -1.6332e-01, -1.1319e-01, -2.7653e-01, -8.2984e-02,
-1.0522e+00, -1.5258e+00, -8.5227e-02, -3.8117e-01, -1.2596e-01,
-1.5002e-01, -1.3972e-01, -7.7500e-02, -9.6660e-02, -8.7879e-02,
-1.5197e-01, -1.2733e-01, -9.2829e-02, -1.2295e-01, 1.0882e+00,
-8.2784e-02, -1.8570e-01, -1.8491e-01, -1.8620e-01, -1.2508e-01,
-1.7059e-01, -1.7224e-01, -9.5731e-02, -1.2236e-01, -1.2767e+00,
-1.0647e-01, -1.4538e-01, -1.8832e-01, -1.3047e-01, -8.7684e-02,
-1.3320e-01, -9.0221e-02, -3.3658e-01, -1.1335e-01, -7.0239e-02,
-2.2329e-01, -1.7096e-01, -1.3934e-01, -8.8694e-02, -1.9286e-01,
-1.0168e-01, -1.0379e-01, -9.9497e-02, -9.1686e-02, -2.3279e-01,
-1.5380e-01, -1.3666e-01, -8.4521e-02, -1.4144e-01, -1.1598e-01,
-1.0114e-01, -1.0245e-01, -1.6665e+00, -2.1897e-01, -8.5227e-02,
-1.2714e-01, -1.4373e-01, -7.3407e-02, -2.0999e-01, -1.9110e-01,
-1.0770e-01, -1.0171e-01, -9.6309e-02, -1.6769e-01, -1.5197e-01,
-4.2810e-02, 6.7731e-01, -1.1568e-01, -8.1843e-02, -1.5739e-01,
-2.8294e-01, -1.6024e-01, -1.5438e-01, -8.5227e-02, -1.2014e-01,
7.2693e-01, -1.2947e-01, -1.2644e-01, -4.3717e-02, -9.7719e-02,
-1.9199e-01, -9.9647e-02, -1.7200e-01, -7.2197e-02, -9.3881e-02,
-8.3424e-02, -1.2937e-01, -1.8169e+00, -1.1100e-01, -2.5297e-01,
-8.5439e-02, -9.3579e-02, -1.4879e-01, -1.6707e-01, -1.8280e-01,
-1.5483e-01, -1.4358e-01, -1.3556e-01, -1.3164e-01, -4.5298e-01,
-9.6660e-02, 4.9850e-01, -1.4431e-01, -1.7820e-01, -1.6360e-01,
-8.2378e-02, -2.0769e-01, -1.6943e+00, -1.4628e-01, -1.9276e-01,
-1.1956e-01, -8.2449e-02, -1.5597e-01, -1.1991e-01, -1.4431e-01,
-9.1144e-02, -1.9830e-01, -1.9115e-01, 2.1598e-01],
[-3.0150e-02, -3.2955e-01, 1.1721e-01, -1.1204e-01, 9.0476e-02,
-2.5168e-01, 8.6707e-02, -5.3538e-02, 6.9870e-02, -1.9042e-01,
-1.4320e-01, 2.7253e-02, -2.1802e-01, -1.8362e-01, 7.7121e-02,
-4.0283e-02, -4.7240e-02, -1.9819e-01, -1.6145e-01, -1.5018e-01,
1.3165e-01, -5.3044e-02, -2.7971e-01, -3.0150e-02, -7.1346e-02,
-2.3805e-01, -3.4356e-01, -1.6526e-01, -3.2173e-01, -1.6974e-01,
7.0052e-02, -6.8471e-02, -1.6461e-01, -2.2570e-01, -2.4285e-01,
-3.1711e-02, -2.5415e-01, -1.0251e+00, -9.6156e-02, 1.3345e-01,
-2.4770e-01, -2.0523e-01, -1.0316e+00, -1.0354e-01, -2.2429e-01,
-2.7321e-01, -6.9733e-02, -1.9763e-01, -9.8368e-02, -6.5333e-02,
-2.2895e-01, -1.2631e-01, -2.4207e-01, -7.7607e-02, 5.1611e-03,
-6.6462e-02, -1.4629e-01, -9.8427e-02, -4.6104e-01, -4.9567e-02,
-1.7031e-01, -1.5471e-01, -9.4588e-02, -2.1527e-01, -1.7846e-01,
-1.1544e-01, -2.1606e-01, -6.1321e-02, -7.4325e-02, -1.3243e-01,
-9.8368e-02, -8.3602e-02, -1.4902e-01, -9.3539e-02, -7.1700e-02,
-1.1638e-01, -1.8658e-01, -5.4350e-02, -1.0060e-01, 6.3692e-02,
-2.7830e-01, -1.4618e-01, -7.2696e-02, -6.2492e-02, -1.0544e-01,
-2.4614e-01, -3.1423e-01, -1.1988e-02, 2.7253e-02, -9.3662e-02,
-2.8226e-02, -8.7144e-02, -4.4606e-02, 2.8895e-01, -2.5961e-01,
-8.0950e-02, -7.1112e-02, -1.7114e-01, -8.0503e-02, -3.3785e-02,
-1.7140e-01, -5.8731e-02, -1.8776e-01, -7.8908e-02, -8.5219e-02,
-9.6156e-02, -9.3952e-02, -2.8185e-01, -5.3794e-02, -2.8185e-01,
-2.6171e-01, -1.1572e-01, -1.2928e-01, -2.5731e-01, -1.0181e-01,
2.8058e-02, -3.0150e-02, -3.4986e-01, -8.4546e-02, -7.2696e-02,
-8.2613e-02, -2.2695e-01, -1.0884e-01, -2.0743e-01, -2.0333e-01,
-3.9925e-02, -2.1923e-01, -9.4588e-02, -3.7132e-01]],
grad_fn=<SumBackward1>)
preds.size()=torch.Size([4, 129])
targs=tensor([0, 0, 0, 0])
targs.size()=torch.Size([4])
loss=tensor(9.5172, grad_fn=<AddBackward0>)
input_encoded[:,i+steps_predicted].size()=torch.Size([4, 512])
positives.size()=torch.Size([4, 1, 512])
negatives_from_same_seq_only=True
idxs2_flat=tensor([ 332, 864, 343, 199, 428, 747, 619, 131, 400, 586, 976, 417,
400, 539, 786, 351, 286, 71, 776, 846, 70, 838, 877, 463,
754, 667, 678, 364, 514, 482, 815, 490, 935, 368, 397, 849,
260, 728, 65, 66, 174, 513, 598, 70, 343, 208, 765, 698,
187, 177, 344, 169, 656, 806, 655, 293, 895, 869, 269, 38,
961, 404, 673, 40, 974, 87, 974, 479, 429, 350, 905, 843,
757, 789, 357, 515, 615, 512, 948, 875, 379, 649, 24, 72,
412, 988, 519, 817, 805, 725, 89, 393, 625, 559, 94, 841,
497, 302, 88, 137, 343, 105, 297, 483, 758, 834, 476, 703,
142, 141, 752, 360, 331, 634, 616, 972, 44, 114, 318, 875,
938, 329, 596, 132, 584, 838, 644, 613, 1445, 1711, 1918, 1039,
1723, 1428, 1982, 1685, 1469, 1734, 1217, 1360, 1343, 1588, 1556, 1719,
1888, 1703, 1181, 1745, 1740, 1927, 1322, 1854, 1430, 1897, 1248, 1912,
1527, 1208, 1786, 1970, 1797, 1174, 1660, 1555, 1189, 1115, 1378, 1838,
1603, 1786, 1096, 1756, 1637, 1568, 1049, 1888, 1935, 1467, 1577, 1148,
1413, 1979, 1894, 1253, 1663, 1916, 1951, 1949, 1899, 1589, 1106, 1622,
1057, 1964, 1750, 1431, 1653, 1450, 1326, 1531, 1751, 1014, 1517, 1110,
1120, 1089, 1788, 1005, 1198, 1360, 1046, 1756, 1723, 1667, 1613, 1129,
1701, 1271, 1274, 1745, 1027, 1479, 1516, 1484, 1844, 1027, 1718, 1710,
1060, 1605, 1077, 1177, 1921, 1448, 1216, 1707, 1597, 1128, 1595, 1628,
1167, 1546, 1159, 1368, 1457, 1566, 1421, 1876, 1467, 1422, 1943, 1464,
1445, 1848, 1406, 1179, 2172, 2088, 2894, 2695, 2492, 2516, 2396, 2398,
2269, 2541, 2462, 2402, 2648, 2743, 2507, 2836, 2751, 2115, 2211, 2345,
2556, 2651, 2588, 2936, 2143, 2611, 2922, 2869, 2121, 2995, 2753, 2983,
2857, 2704, 2881, 2578, 2234, 2997, 2253, 2091, 2884, 2596, 2919, 2323,
2362, 2330, 2928, 2727, 2183, 2934, 2705, 2731, 2503, 2123, 2515, 2900,
2513, 2477, 2623, 2376, 2415, 2497, 2623, 2118, 2115, 2949, 2238, 2813,
2975, 2773, 2057, 2672, 2621, 2056, 2277, 2148, 2467, 2998, 2102, 2044,
2150, 2974, 2929, 2077, 2455, 2402, 2842, 2552, 2517, 2637, 2034, 2950,
2010, 2104, 2349, 2270, 2464, 2871, 2818, 2715, 2243, 2010, 2648, 2497,
2841, 2165, 2045, 2383, 2138, 2900, 2346, 2180, 2773, 2487, 2702, 2826,
2753, 2217, 2977, 2356, 2034, 2449, 2770, 2685, 2585, 2851, 2306, 2315,
3225, 3503, 3852, 3388, 3396, 3664, 3638, 3135, 3639, 3656, 3772, 3106,
3894, 3143, 3768, 3732, 3502, 3633, 3505, 3548, 3613, 3851, 3862, 3507,
3361, 3004, 3443, 3485, 3567, 3328, 3387, 3278, 3622, 3663, 3386, 3121,
3556, 3104, 3205, 3862, 3048, 3554, 3301, 3429, 3066, 3982, 3951, 3898,
3881, 3840, 3330, 3296, 3805, 3851, 3353, 3231, 3478, 3648, 3705, 3922,
3396, 3645, 3581, 3508, 3014, 3472, 3046, 3432, 3144, 3969, 3190, 3691,
3127, 3669, 3343, 3556, 3312, 3704, 3772, 3551, 3323, 3353, 3541, 3314,
3934, 3536, 3626, 3754, 3228, 3664, 3255, 3712, 3517, 3610, 3965, 3529,
3156, 3924, 3754, 3667, 3622, 3124, 3121, 3273, 3569, 3896, 3097, 3135,
3734, 3633, 3110, 3608, 3675, 3938, 3221, 3790, 3354, 3898, 3113, 3326,
3717, 3062, 3256, 3296, 3259, 3353, 3865, 3176])
idxs2_flat.size()=torch.Size([512])
negatives.size()=torch.Size([4, 128, 512])
candidates.size()=torch.Size([4, 129, 512])
output[:,i].size()=torch.Size([4, 512])
output[:,i].unsqueeze(1).size()=torch.Size([4, 1, 512])
(output[:,i].unsqueeze(1)*candidates).size()=torch.Size([4, 129, 512])
preds=tensor([[-1.0552e-02, -1.5840e-01, -1.3136e-01, -3.1857e-02, -1.2053e-01,
-3.3754e-01, -4.6784e-01, -1.3208e-01, -1.1423e-01, -1.7025e-01,
-1.0325e-01, -1.0039e-01, -9.5827e-02, -1.7025e-01, -7.6416e-02,
3.5367e-02, -3.7562e-02, -1.5303e-01, -6.4000e-02, 1.2499e-01,
1.6475e-01, -9.5313e-02, -1.1112e-01, -1.2842e-01, -1.2734e-01,
-9.1730e-02, -6.8705e-02, 9.9640e-03, -1.4748e-01, -9.0829e-02,
-2.1478e-02, -1.3511e-01, -3.5305e-02, -3.0083e-02, -1.6378e-01,
-1.4172e-01, -8.9982e-02, -7.4905e-02, -3.5752e-01, 5.6087e-03,
3.9132e-02, -1.2712e-01, -9.3274e-02, -1.3187e-01, -9.5313e-02,
-3.1857e-02, -7.0321e-02, -5.7903e-01, 2.5004e-02, -5.0105e-02,
-1.0209e-01, -3.3778e-02, -1.6755e-01, -8.1964e-02, -1.2748e-01,
-1.2313e-01, -1.6774e-01, -1.0817e-01, -1.2191e-01, -5.0001e-02,
-8.8910e-02, -8.9596e-02, -8.1565e-02, -8.3702e-02, -9.4197e-02,
-1.2178e-01, -1.7780e-01, -1.2178e-01, -1.9442e-01, -4.0424e-01,
9.9631e-04, -1.5724e-01, 4.5504e-01, -3.0891e-01, -5.0742e-02,
-2.3497e-02, -8.8144e-02, -1.1154e-01, -8.5451e-02, -7.9171e-02,
-1.2456e-01, -1.1937e-01, -1.4265e-01, -1.0304e-01, -7.8468e-02,
-5.7888e-02, -2.1738e-01, -7.3160e-02, -1.8317e-01, -1.4422e-01,
-4.2038e-01, -1.8054e-01, -1.2432e-01, 1.7417e-01, -4.5759e-02,
-1.4227e-01, -2.5704e-02, -1.0848e-01, -7.4679e-02, -1.9166e-01,
-4.1202e-02, -3.1857e-02, -1.5834e-01, -1.6926e-01, 1.0062e-02,
-3.5874e-01, -1.1299e-01, -1.1978e-01, -9.0256e-02, -1.1855e-01,
-8.8107e-02, -2.6040e-01, -1.0704e-01, -4.9900e-01, -1.0485e-01,
-1.2694e-01, -1.5587e-01, 8.5378e-02, -4.4137e-01, -9.3385e-02,
-1.2456e-01, -1.4650e-01, -1.5184e-01, -1.0268e-01, -1.2723e-01,
-7.1851e-02, -1.1112e-01, -7.9374e-02, -1.1716e-01],
[-2.6666e-01, 2.4091e-01, -7.6162e-03, -6.5027e-02, -1.4427e-01,
-5.3186e-02, -9.1817e-01, -3.1107e-01, -1.2833e-01, -9.0936e-02,
1.0595e-01, -4.0094e-01, -1.5961e-01, -2.3634e-01, -1.0205e-01,
-3.2538e-02, -5.8687e-02, -5.3516e-02, -1.0387e-01, -1.9667e-01,
-4.4039e-01, -1.7277e-01, -1.7113e-01, -1.1543e-01, -1.4215e-01,
-2.5927e-02, -2.7665e-01, -1.6980e-02, -2.9475e-01, -7.1099e-02,
2.3546e-01, -2.1533e-01, -1.2813e-01, -1.5419e-01, -1.3814e-01,
-3.2814e-02, -3.1765e-02, -2.2404e-01, -3.1153e-01, -1.6467e-01,
-1.0699e-01, -7.4868e-02, -2.1533e-01, -3.7502e-01, -1.8692e-01,
-1.4227e-02, -1.1294e-02, -1.0064e-01, -5.3516e-02, 2.7307e-01,
2.3519e-01, 3.6587e-02, -5.1851e-01, 2.0092e-01, -2.6192e-01,
-3.0084e-01, -1.6547e-01, -2.0958e-01, -2.4749e-01, -1.2559e-01,
-1.7435e-02, -4.8185e-01, -7.5818e-02, -3.5181e-01, 1.2276e-01,
-4.7514e-01, -1.8789e-01, -2.1584e-01, -1.7890e-01, 4.4840e-01,
3.0576e-01, -8.0726e-02, -3.4730e-02, -1.7476e-01, -5.2701e-01,
-1.5629e-02, -3.1000e-01, -2.3776e-01, -3.5448e-01, -2.0681e-01,
-5.7091e-01, -1.5471e-01, -1.5961e-01, -5.4969e-01, -1.8692e-01,
-5.3186e-02, 1.0348e-01, 2.1710e-01, -2.1205e-01, -4.5291e-02,
-1.3391e-01, -1.6587e-01, -4.4039e-01, -1.4758e-01, -8.0090e-02,
1.4018e-02, 4.4276e-01, 1.6447e-01, -1.4758e-01, 8.2431e-03,
-6.6660e-02, -5.2126e-01, -6.0018e-02, -3.4418e-01, -4.3153e-02,
1.1534e-01, 4.7844e-01, -2.0447e-01, -9.9088e-03, -1.7922e-01,
-2.3023e-01, 1.1266e-01, 7.0687e-03, 1.1055e-01, -4.1183e-02,
-2.4990e-01, 7.9754e-02, -4.8913e-02, 3.3480e-02, -2.4127e-01,
-2.3549e-01, 2.3519e-01, -1.8295e-01, -3.3132e-01, 1.2754e-01,
2.4091e-01, -8.0665e-02, -5.0527e-03, -1.2389e-01],
[-2.1623e-01, -2.9882e-01, 2.6445e-01, 6.4528e-01, -1.6453e-01,
-2.2739e-01, -2.6903e-01, -2.6664e-01, -3.0309e-01, -8.2458e-02,
-2.8802e-01, -2.9246e-01, -9.1200e-02, -8.3108e-02, 1.2861e+00,
-2.2869e-01, -3.8131e-01, -1.5496e-01, -1.7263e-01, -1.7492e-01,
-2.0208e-01, -1.4275e-01, -1.3119e-01, 5.7836e-02, -2.3965e-01,
-2.3308e-01, -4.4014e-01, -1.0145e-01, -6.0919e-02, -9.9740e-02,
-2.1623e-01, -2.0040e-01, -1.7178e-01, -5.9649e-02, -8.9506e-02,
-8.6894e-02, -1.8071e-01, -1.5806e-01, -2.1623e-01, -2.6064e-01,
-2.0473e-01, -6.4504e-02, 3.7771e-03, -2.1369e-01, -2.2899e-01,
-1.7232e-01, -2.3319e-01, -1.7042e-01, -2.3424e-01, -2.8796e-01,
-2.2533e-01, -7.8901e-02, -1.8397e-01, -2.2328e-01, -6.3133e-02,
-4.0996e-02, -1.7258e-01, -1.7079e-01, -9.4583e-02, -8.2082e-02,
1.4631e-01, -1.4054e-01, -1.5893e-01, -8.2082e-02, -1.4801e-01,
-1.7263e-01, -1.6117e-01, -1.8751e-01, -3.3367e-01, -1.8883e-01,
-5.3047e-02, -1.1038e-01, -2.1679e-01, -1.4648e-01, -1.4547e-01,
-1.0161e-01, -1.8566e-01, -3.0342e-01, -2.1623e-01, -3.1519e-01,
-1.3477e-01, -1.6758e-01, -1.7162e-01, -1.6158e-01, -1.1916e-01,
-3.2131e-01, -9.1200e-02, -2.5053e-01, -1.4521e-01, -2.3642e+00,
-1.4416e-01, -2.4283e-01, -2.0224e-01, -1.0561e+00, -3.4772e-01,
-1.6462e-01, -1.0990e-01, -3.2207e-01, 1.9450e-02, 7.8673e-01,
-1.0573e-01, -2.9972e-01, -1.0561e+00, -8.3108e-02, -1.5893e-01,
-2.7888e-01, -2.0354e-01, -1.3901e-01, -2.6176e-01, -9.8829e-02,
-1.7258e-01, -1.5438e-01, -2.6914e-01, -5.3047e-02, -1.3680e-01,
-9.4991e-02, -1.3877e-01, -2.0040e-01, -1.3723e-01, -1.2024e-01,
-1.0886e-01, -2.4283e-01, 1.2661e-01, -1.5936e-01, -3.4422e-01,
-1.6518e-01, -1.1249e-01, -2.2563e-01, -2.6092e-01],
[-2.9977e-02, -8.5338e-02, -2.1015e-01, -1.6073e-01, -2.5322e-01,
-1.0669e-01, -1.0661e-01, -2.8151e-01, -1.1239e-01, -2.7470e-01,
-1.5149e-01, -2.6141e-01, -2.1919e-01, -2.8415e-01, -1.2268e-01,
-2.3905e-01, -1.0034e-01, -2.1259e-01, -2.7185e-01, -2.7472e-01,
-6.4762e-02, -9.2788e-02, -1.8118e-01, -1.0414e-01, -2.2932e-01,
-1.3125e+00, -2.3819e-01, -3.0672e-01, -1.0738e-01, -3.1271e-01,
-9.6309e-02, -3.0116e-01, 1.7409e-01, -1.3128e-01, -1.0168e-01,
-3.4764e-01, -7.2735e-02, -1.6504e-01, -1.8690e-01, -2.4164e-01,
-1.0414e-01, -2.2159e-01, -2.3014e-01, -2.8679e-01, -1.1934e+00,
-1.9532e-01, -1.2188e-01, -3.0133e-01, -3.1416e-01, -5.1713e-02,
-3.0756e-01, -9.1921e-02, -1.0240e+00, 1.6925e-01, -1.8118e-01,
-2.5480e-01, -3.7221e-02, 1.4241e-01, -4.2661e-01, -2.5585e-01,
-1.0874e-01, -1.0669e-01, -2.9115e-01, -4.4820e-01, -2.0454e-01,
-9.5814e-01, -9.6619e-02, -6.7237e-02, -3.8883e-02, -6.4254e-02,
-4.4077e-01, -2.1441e-01, -1.1529e+00, -8.0692e-02, -6.6252e-02,
1.2276e-01, -1.6504e-01, -3.3431e-01, -2.9687e-01, -2.6141e-01,
-2.4292e-01, -1.6662e-01, -2.5480e-01, -6.5617e-02, -3.4656e-01,
-8.3057e-03, -1.5493e-01, -1.0365e+00, -3.7096e-02, -1.1783e+00,
-1.0661e-01, -1.5559e-01, -3.6025e-01, -4.0849e-01, 1.0871e-01,
-2.7379e-01, -1.3305e-01, -1.8497e-01, -1.0076e-01, -3.7096e-02,
-8.6594e-02, -1.3128e-01, -1.4765e-01, -7.2735e-02, -5.2743e-02,
-2.9676e-01, -2.9619e-01, -2.6412e-02, -1.1239e-01, -7.3396e-02,
-2.7185e-01, -1.5300e-01, 1.7072e-01, 1.9137e-01, -1.9447e-01,
-3.2132e-01, -1.4413e-01, -1.6037e-01, -3.1416e-01, -1.7647e-01,
-1.4249e-01, -2.9429e-01, 3.4662e-02, -2.0672e-01, -1.0240e+00,
-1.5271e-01, -2.5480e-01, 2.0199e-03, -2.6212e-01]],
grad_fn=<SumBackward1>)
preds.size()=torch.Size([4, 129])
targs=tensor([0, 0, 0, 0])
targs.size()=torch.Size([4])
loss=tensor(14.3763, grad_fn=<AddBackward0>)
In [15]:
# loss, acc = model_cpc.cpc_loss(data_batch[0],steps_predicted=hparams.steps_predicted,n_false_negatives=hparams.n_false_negatives, negatives_from_same_seq_only=hparams.negatives_from_same_seq_only, eval_acc=True)
In [16]:
loss
Out[16]:
tensor(14.3763, grad_fn=<AddBackward0>)
In [17]:
acc
Out[17]:
tensor(0.0833)
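With steps_predicted=997 and 1000 encoded steps there are only 1000-997 = 3 prediction steps, so the loop accumulates 3 cross-entropy terms (hence the cumulative loss 4.78 -> 9.52 -> 14.38 above) and the accuracy is tp_cnt/(bs*3) = tp_cnt/12. The 0.0833 therefore means exactly one of the 12 contrastive predictions ranked the positive first; with 129 candidates and a randomly initialized model, chance level would be about 1/129 ≈ 0.008.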
The part below is the fine-tuning stage.
In [1]:
###############
#generic
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
import torch.nn.functional as F
import torchvision
import os
import argparse
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import copy
#################
#specific
from clinical_ts.timeseries_utils import *
from clinical_ts.ecg_utils import *
from functools import partial
from pathlib import Path
import pandas as pd
import numpy as np
from clinical_ts.xresnet1d import xresnet1d50,xresnet1d101
from clinical_ts.basic_conv1d import weight_init
from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
from clinical_ts.cpc import *
In [7]:
hparams = {
"accumulate": 1,
"batch_size": 128,
"bias": False,
"data": ['./data/ptb_xl_fs100'],
"discriminative_lr_factor": 0.1,
"distributed_backend": None,
"dropout_head": 0.5,
"epochs": 50,
"executable": "cpc",
"fc_encoder": True,
"finetune": True,
"finetune_dataset": "ptbxl_all",
"gpus": 1,
"gru": False,
"input_channels": 12,
"input_size": 250,
"lin_ftrs_head": [512],
"linear_eval": False,
"lr": 0.0001,
"lr_find": False,
"metadata": None,
"mlp": False,
"n_false_negatives": 128,
"n_hidden": 512,
"n_layers": 2,
"negatives_from_same_seq_only": False,
"no_bn_encoder": False,
"no_bn_head": False,
"normalize": True,
"num_nodes": 1,
"optimizer": "adam",
"output_path": "./runs/cpc/all_ptbxl",
"precision": 16,
"pretrained": './runs/cpc/all/version_5/best_model.ckpt',
"resume": None,
"skip_encoder": False,
"steps_predicted": 12,
"train_head_only": True,
"weight_decay": 0.001,
}
class Struct:
def __init__(self, **entries):
self.__dict__.update(entries)
hparams = Struct(**hparams)
# configure dataset params
chunkify_train = False
chunk_length_train = hparams.input_size if chunkify_train else 0
stride_train = hparams.input_size
chunkify_valtest = True
chunk_length_valtest = hparams.input_size if chunkify_valtest else 0
stride_valtest = hparams.input_size//2
train_datasets = []
val_datasets = []
test_datasets = []
for i,target_folder in enumerate(hparams.data):
target_folder = Path(target_folder)
df_mapped, lbl_itos, mean, std = load_dataset(target_folder)
# always use PTB-XL stats
mean = np.array([-0.00184586, -0.00130277, 0.00017031, -0.00091313, -0.00148835, -0.00174687, -0.00077071, -0.00207407, 0.00054329, 0.00155546, -0.00114379, -0.00035649])
std = np.array([0.16401004, 0.1647168 , 0.23374124, 0.33767231, 0.33362807, 0.30583013, 0.2731171 , 0.27554379, 0.17128962, 0.14030828, 0.14606956, 0.14656108])
#specific for PTB-XL
if(hparams.finetune and hparams.finetune_dataset.startswith("ptbxl")):
if(hparams.finetune_dataset=="ptbxl_super"):
ptb_xl_label = "label_diag_superclass"
elif(hparams.finetune_dataset=="ptbxl_all"):
ptb_xl_label = "label_all"
lbl_itos= np.array(lbl_itos[ptb_xl_label])
def multihot_encode(x, num_classes):
res = np.zeros(num_classes,dtype=np.float32)
for y in x:
res[y]=1
return res
df_mapped["label"]= df_mapped[ptb_xl_label+"_filtered_numeric"].apply(lambda x: multihot_encode(x,len(lbl_itos)))
lbl_itos = lbl_itos
tfms_ptb_xl_cpc = ToTensor() if hparams.normalize is False else transforms.Compose([Normalize(mean,std),ToTensor()])
max_fold_id = df_mapped.strat_fold.max() #unfortunately 1-based for PTB-XL; sometimes 100 (Ribeiro)
df_train = df_mapped[df_mapped.strat_fold<(max_fold_id-1 if hparams.finetune else max_fold_id)]
df_val = df_mapped[df_mapped.strat_fold==(max_fold_id-1 if hparams.finetune else max_fold_id)]
if(hparams.finetune):
df_test = df_mapped[df_mapped.strat_fold==max_fold_id]
train_datasets.append(TimeseriesDatasetCrops(df_train,hparams.input_size,num_classes=len(lbl_itos),data_folder=target_folder,chunk_length=chunk_length_train,min_chunk_length=hparams.input_size, stride=stride_train,transforms=tfms_ptb_xl_cpc,annotation=False,col_lbl ="label" if hparams.finetune else None,memmap_filename=target_folder/("memmap.npy")))
val_datasets.append(TimeseriesDatasetCrops(df_val,hparams.input_size,num_classes=len(lbl_itos),data_folder=target_folder,chunk_length=chunk_length_valtest,min_chunk_length=hparams.input_size, stride=stride_valtest,transforms=tfms_ptb_xl_cpc,annotation=False,col_lbl ="label" if hparams.finetune else None,memmap_filename=target_folder/("memmap.npy")))
if(hparams.finetune):
test_datasets.append(TimeseriesDatasetCrops(df_test,hparams.input_size,num_classes=len(lbl_itos),data_folder=target_folder,chunk_length=chunk_length_valtest,min_chunk_length=hparams.input_size, stride=stride_valtest,transforms=tfms_ptb_xl_cpc,annotation=False,col_lbl ="label",memmap_filename=target_folder/("memmap.npy")))
print("\n",target_folder)
print("train dataset:",len(train_datasets[-1]),"samples")
print("val dataset:",len(val_datasets[-1]),"samples")
if(hparams.finetune):
print("test dataset:",len(test_datasets[-1]),"samples")
if(len(train_datasets)>1): #multiple data folders
print("\nCombined:")
train_dataset = ConcatDataset(train_datasets)
val_dataset = ConcatDataset(val_datasets)
print("train dataset:",len(train_dataset),"samples")
print("val dataset:",len(val_dataset),"samples")
if(hparams.finetune):
test_dataset = ConcatDataset(test_datasets)
print("test dataset:",len(test_dataset),"samples")
else: #just a single data folder
train_dataset = train_datasets[0]
val_dataset = val_datasets[0]
if(hparams.finetune):
test_dataset = test_datasets[0]
data/ptb_xl_fs100
train dataset: 17441 samples
val dataset: 15351 samples
test dataset: 15421 samples
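multihot_encode above simply turns the list of numeric label indices into a fixed-length multi-hot vector; a quick toy call (made-up indices):
multihot_encode([2, 5], 8)    # -> array([0., 0., 1., 0., 0., 1., 0., 0.], dtype=float32)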
In [19]:
signal, label = train_dataset.__getitem__(7)
In [20]:
signal.size()
Out[20]:
torch.Size([12, 250])
In [21]:
label
Out[21]:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
In [24]:
label.size()
Out[24]:
torch.Size([71])
In [22]:
import matplotlib.pyplot as plt
# plot a full 12-lead ECG (optionally side by side with a mask) for visual inspection
def ecgvis(data, mask=None, title:str=None, save_path=None, imshow=True):
if mask is not None:
fig = plt.figure(figsize=(10, 10))
fig.suptitle(title, fontsize=16)
fig.subplots_adjust(top=0.93)
axs = fig.subplots(12, 2, sharex=True)
for i, (ax, _) in enumerate(zip(axs, data)):
ax[0].plot(data[i])
ax[0].set_ylim([-2,2])
ax[1].plot(mask[i])
ax[1].set_ylim([-2,2])
else:
fig = plt.figure(figsize=(5,10))
fig.suptitle(title, fontsize=16)
fig.subplots_adjust(top=0.93)
axs = fig.subplots(12, 1, sharex=True)
for i, (ax, _) in enumerate(zip(axs, data)):
ax.plot(data[i])
ax.set_ylim([-2,2])
if save_path:
print(f"{save_path} saved.")
plt.savefig(save_path)
if imshow:
plt.show()
In [23]:
ecgvis(signal)
In [25]:
data_loader = DataLoader(train_dataset, batch_size=4, num_workers=4, shuffle=True, drop_last = True)
In [88]:
from typing import Iterable
def listify(p=None, q=None):
"Make `p` listy and the same length as `q`."
if p is None: p=[]
elif isinstance(p, str): p = [p]
elif not isinstance(p, Iterable): p = [p]
#Rank 0 tensors in PyTorch are Iterable but don't have a length.
else:
try: a = len(p)
except: p = [p]
n = q if type(q)==int else len(p) if q is None else len(q)
if len(p)==1: p = p * n
assert len(p)==n, f'List len mismatch ({len(p)} vs {n})'
return list(p)
In [89]:
def bn_drop_lin(n_in, n_out, bn=True, p=0., actn=None):
"Sequence of batchnorm (if `bn`), dropout (with `p`) and linear (`n_in`,`n_out`) layers followed by `actn`."
layers = [nn.BatchNorm1d(n_in)] if bn else []
if p != 0: layers.append(nn.Dropout(p))
layers.append(nn.Linear(n_in, n_out))
if actn is not None: layers.append(actn)
return layers
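listify and bn_drop_lin are copied in here because the nbdev import is commented out; they are the building blocks of the classifier head further down. A quick look at what they return (toy calls):
listify(0.5)         # -> [0.5]
listify(0.5, 3)      # -> [0.5, 0.5, 0.5]
bn_drop_lin(1536, 512, bn=True, p=0.25, actn=nn.ReLU(inplace=True))
# -> [BatchNorm1d(1536, ...), Dropout(p=0.25), Linear(in_features=1536, out_features=512), ReLU(inplace=True)]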
In [172]:
# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/17_cpc.ipynb (unless otherwise specified).
__all__ = ['CPCEncoder', 'CPCModel']
# Cell
import torch
import torch.nn.functional as F
import torch.nn as nn
from clinical_ts.basic_conv1d import _conv1d
import numpy as np
# from .clinical_ts.basic_conv1d import listify, bn_drop_lin
# Cell
class CPCEncoder(nn.Sequential):
'CPC Encoder'
# strides = [1, 1, 1, 1]
# kss = [1, 1, 1, 1]
# features = [512, 512, 512, 512]
def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=[512,512,512,512],bn=False):
assert(len(strides)==len(kss) and len(strides)==len(features))
lst = []
for i,(s,k,f) in enumerate(zip(strides,kss,features)):
lst.append(_conv1d(input_channels if i==0 else features[i-1],f,kernel_size=k,stride=s,bn=bn))
super().__init__(*lst)
self.downsampling_factor = np.prod(strides)
self.output_dim = features[-1]
# output: bs, output_dim, seq//downsampling_factor
def encode(self, input):
#bs = input.size()[0]
#ch = input.size()[1]
#seq = input.size()[2]
#segments = seq//self.downsampling_factor
#input_encoded = self.forward(input[:,:,:segments*self.downsampling_factor]).transpose(1,2) #bs, seq//downsampling, encoder_output_dim (standard ordering for batch_first RNNs)
print(f"{input.size()=}")
input_encoded = self.forward(input)
print(f"{input_encoded.size()=}")
input_encoded = input_encoded.transpose(1,2)
print(f"{input_encoded.size()=}")
return input_encoded
# Cell
class CPCModel(nn.Module):
"CPC model"
def __init__(self, input_channels, strides=[5,4,2,2,2], kss=[10,8,4,4,4], features=[512,512,512,512],bn_encoder=False, n_hidden=512,n_layers=2,mlp=False,lstm=True,bias_proj=False, num_classes=None, concat_pooling=True, ps_head=0.5,lin_ftrs_head=[512],bn_head=True,skip_encoder=False):
super().__init__()
assert(skip_encoder is False or num_classes is not None)#pretraining only with encoder
self.encoder = CPCEncoder(input_channels,strides=strides,kss=kss,features=features,bn=bn_encoder) if skip_encoder is False else None
self.encoder_output_dim = self.encoder.output_dim if skip_encoder is False else None
self.encoder_downsampling_factor = self.encoder.downsampling_factor if skip_encoder is False else None
self.n_hidden = n_hidden
self.n_layers = n_layers
self.mlp = mlp
self.num_classes = num_classes
self.concat_pooling = concat_pooling
self.rnn = nn.LSTM(self.encoder_output_dim if skip_encoder is False else input_channels,n_hidden,num_layers=n_layers,batch_first=True) if lstm is True else nn.GRU(self.encoder.output_dim,n_hidden,num_layers=n_layers,batch_first=True)
if(num_classes is None): #pretraining
if(mlp):# additional hidden layer as in simclr
self.proj = nn.Sequential(nn.Linear(n_hidden, n_hidden),nn.ReLU(inplace=True),nn.Linear(n_hidden, self.encoder_output_dim,bias=bias_proj))
else:
self.proj = nn.Linear(n_hidden, self.encoder_output_dim,bias=bias_proj)
else: #classifier
# self.head=Sequential(
# (0): AdaptiveConcatPoolRNN()
# (1): BatchNorm1d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): Dropout(p=0.25, inplace=False)
# (3): Linear(in_features=1536, out_features=512, bias=True)
# (4): ReLU(inplace=True)
# (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (6): Dropout(p=0.5, inplace=False)
# (7): Linear(in_features=512, out_features=71, bias=True)
# )
#slightly adapted from RNN1d
layers_head =[]
if(self.concat_pooling):
layers_head.append(AdaptiveConcatPoolRNN())
#classifier
nf = 3*n_hidden if concat_pooling else n_hidden
lin_ftrs_head = [nf, num_classes] if lin_ftrs_head is None else [nf] + lin_ftrs_head + [num_classes]
ps_head = listify(ps_head)
if len(ps_head)==1:
ps_head = [ps_head[0]/2] * (len(lin_ftrs_head)-2) + ps_head
actns = [nn.ReLU(inplace=True)] * (len(lin_ftrs_head)-2) + [None]
for ni,no,p,actn in zip(lin_ftrs_head[:-1],lin_ftrs_head[1:],ps_head,actns):
layers_head+=bn_drop_lin(ni,no,bn_head,p,actn)
self.head=nn.Sequential(*layers_head)
def forward(self, input):
# input shape bs,ch,seq
if(self.encoder is not None):
input_encoded = self.encoder.encode(input)
else:
input_encoded = input.transpose(1,2) #bs, seq, channels
output_rnn, _ = self.rnn(input_encoded) #output_rnn: bs, seq, n_hidden
if(self.num_classes is None):#pretraining
print(f"{input_encoded.size()=}")
print(f"{self.rnn=}")
print(f"{output_rnn.size()=}")
print(f"{self.proj=}")
print(f"{self.proj(output_rnn).size()=}")
return input_encoded, self.proj(output_rnn)
else:#classifier
print(f"{input_encoded.size()=}")
print(f"{self.rnn=}")
print(f"{output_rnn.size()=}")
output = output_rnn.transpose(1,2)#bs,n_hidden,seq (i.e. standard CNN channel ordering)
print(f"{output.size()=}")
if(self.concat_pooling is False):
output = output[:,:,-1]
print(f"{output.size()=}")
print(f"{self.head=}")
return self.head(output)
def get_layer_groups(self):
return (self.encoder,self.rnn,self.head)
def get_output_layer(self):
return self.head[-1]
def set_output_layer(self,x):
self.head[-1] = x
#copied from RNN1d
class AdaptiveConcatPoolRNN(nn.Module):
def __init__(self, bidirectional=False):
super().__init__()
self.bidirectional = bidirectional
def forward(self,x):
#input shape bs, ch, ts
t1 = nn.AdaptiveAvgPool1d(1)(x)
t2 = nn.AdaptiveMaxPool1d(1)(x)
print(f"{x.size()=}") # (bs, 512, 250)
print(f"{t1.size()=}") # (bs, 512, 1)
if(self.bidirectional is False):
t3 = x[:,:,-1]
print(f"{t3.size()=}")
else:
channels = x.size()[1]
t3 = torch.cat([x[:,:channels,-1],x[:,channels:,0]],1)
out=torch.cat([t1.squeeze(-1),t2.squeeze(-1),t3],1) #output shape bs, 3*ch
print(f"{out.size()=}") # (bs, 512*3)
return out
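To make the head construction in CPCModel concrete, here is a minimal standalone sketch using the values from the instantiation below (n_hidden=512, lin_ftrs_head=[512], ps_head=0.5, num_classes=71, concat_pooling=True); the AdaptiveConcatPoolRNN in front is omitted and bn_drop_lin is inlined:
import torch.nn as nn

nf = 3 * 512                              # concat pooling triples n_hidden -> 1536
lin_ftrs = [nf, 512, 71]                  # [nf] + lin_ftrs_head + [num_classes]
ps = [0.5 / 2, 0.5]                       # a single dropout value is halved for hidden layers
actns = [nn.ReLU(inplace=True), None]

layers = []
for ni, no, p, actn in zip(lin_ftrs[:-1], lin_ftrs[1:], ps, actns):
    layers += [nn.BatchNorm1d(ni), nn.Dropout(p), nn.Linear(ni, no)]  # bn_drop_lin, inlined
    if actn is not None:
        layers.append(actn)
head = nn.Sequential(*layers)             # matches the Sequential printed in the In [179] output below
print(head)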
In [173]:
if(hparams.finetune):
criterion = F.cross_entropy if hparams.finetune_dataset == "thew" else F.binary_cross_entropy_with_logits
if(hparams.finetune_dataset == "thew"):
num_classes = 5
elif(hparams.finetune_dataset == "ptbxl_super"):
num_classes = 5
if(hparams.finetune_dataset == "ptbxl_all"):
num_classes = 71
else:
num_classes = None
strides=[1]*4
kss = [1]*4
features = [512]*4
model_cpc = CPCModel(
input_channels=hparams.input_channels, # 12
strides=strides, # [1,1,1,1]
kss=kss, # [1,1,1,1]
features=features, # [512]*4
n_hidden=hparams.n_hidden, # 512
n_layers=hparams.n_layers, # 2
mlp=hparams.mlp, # False
lstm=not(hparams.gru),
bias_proj=hparams.bias,
num_classes=71, # hard-coded here to build the 71-class classifier head (hparams.finetune is False, so num_classes above stays None)
skip_encoder=hparams.skip_encoder,
bn_encoder=not(hparams.no_bn_encoder),
# lin_ftrs_head=[] if hparams.linear_eval else eval(hparams.lin_ftrs_head),
lin_ftrs_head=[] if hparams.linear_eval else [512],
ps_head=0 if hparams.linear_eval else hparams.dropout_head,
bn_head=False if hparams.linear_eval else not(hparams.no_bn_head)
)
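As a quick sanity check on the instantiated model, one could count the parameters per layer group returned by get_layer_groups(); this is only a minimal sketch on top of the model built above:
# Minimal sketch: parameters per layer group (encoder, rnn, head)
for name, group in zip(["encoder", "rnn", "head"], model_cpc.get_layer_groups()):
    n_params = sum(p.numel() for p in group.parameters())
    print(name, n_params)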
In [174]:
# torch.load('runs/cpc/all/version_5/best_model.ckpt').keys()
In [175]:
# model_cpc = model_cpc.load_state_dict(torch.load('runs/cpc/all/version_5/best_model.ckpt')['state_dict'])
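If the checkpoint above were actually loaded, note that load_state_dict modifies the model in place and returns a (missing_keys, unexpected_keys) report rather than the model, so reassigning model_cpc to its return value would discard the model. A hedged sketch (the exact key layout inside the Lightning checkpoint is an assumption):
# Sketch: load pretrained weights in place; strict=False tolerates head/prefix mismatches
state = torch.load('runs/cpc/all/version_5/best_model.ckpt', map_location='cpu')
missing, unexpected = model_cpc.load_state_dict(state['state_dict'], strict=False)
print(missing, unexpected)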
In [176]:
data_batch = next(iter(data_loader))
print(len(data_batch))
2
In [177]:
data_batch[0].size()
Out[177]:
torch.Size([4, 12, 250])
In [178]:
data_batch[1]
Out[178]:
tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
In [179]:
criterion = F.binary_cross_entropy_with_logits
preds = model_cpc.forward(data_batch[0])
loss = criterion(preds,data_batch[1])
input.size()=torch.Size([4, 12, 250])
input_encoded.size()=torch.Size([4, 512, 250])
input_encoded.size()=torch.Size([4, 250, 512])
input_encoded.size()=torch.Size([4, 250, 512])
self.rnn=LSTM(512, 512, num_layers=2, batch_first=True)
output_rnn.size()=torch.Size([4, 250, 512])
output.size()=torch.Size([4, 512, 250])
output.size()=torch.Size([4, 512, 250])
self.head=Sequential(
(0): AdaptiveConcatPoolRNN()
(1): BatchNorm1d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): Dropout(p=0.25, inplace=False)
(3): Linear(in_features=1536, out_features=512, bias=True)
(4): ReLU(inplace=True)
(5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): Dropout(p=0.5, inplace=False)
(7): Linear(in_features=512, out_features=71, bias=True)
)
x.size()=torch.Size([4, 512, 250])
t1.size()=torch.Size([4, 512, 1])
t3.size()=torch.Size([4, 512])
out.size()=torch.Size([4, 1536])
In [180]:
# loss, acc = model_cpc.cpc_loss(data_batch[0],steps_predicted=hparams.steps_predicted,n_false_negatives=hparams.n_false_negatives, negatives_from_same_seq_only=hparams.negatives_from_same_seq_only, eval_acc=True)
In [181]:
preds
Out[181]:
tensor([[ 0.4904, 0.9663, -0.8967, 0.7091, -0.5224, -2.8415, -0.6364, 0.8394,
-0.0612, 0.3853, 0.9607, -0.2586, 0.0475, -1.1203, 1.0986, 0.2361,
-0.2130, -0.0164, 1.8457, 0.1226, -0.2282, -0.5134, 0.3835, 0.6062,
0.7211, -0.6962, 0.4977, -0.2050, 0.7662, -0.0516, -0.6528, -0.9284,
0.6641, -0.2279, 1.3668, 0.2698, -0.0427, -1.0269, 1.3962, 0.0154,
0.6494, -0.1973, -0.6923, 0.0087, 1.5831, -1.2592, -0.8713, -0.7730,
-0.2061, 0.2507, -0.9297, 1.0548, -0.4769, 0.5285, 0.2222, 0.0847,
-0.1046, -0.8203, 1.4155, -0.7033, 0.7081, -0.6253, 0.1231, -0.7729,
1.0558, -0.5198, 0.0273, 0.0210, -0.0174, 0.2799, -1.4736],
[ 1.8827, -0.5242, -0.1192, -0.2502, 1.4966, 0.5454, -0.1425, 0.3351,
-0.2048, -0.5085, -0.4204, -0.8195, 2.0861, -0.1588, -1.1620, 1.8986,
0.4400, -0.7598, -0.1281, -0.8977, 0.2911, -0.8124, -1.3636, 0.1560,
0.2761, -0.0915, 0.7039, 0.5319, -0.7401, -0.2648, -1.3579, 0.4114,
-0.1510, -1.1825, -0.4306, -0.7494, 0.2291, 0.5552, 0.1079, 1.2030,
-0.2725, -0.3373, -0.2139, -0.2609, 0.4887, -1.0080, 0.2039, -0.3551,
0.2238, -0.5231, 0.0542, -0.8187, -0.9941, -0.2048, 0.8903, 0.4601,
-0.6119, 0.4721, 0.1118, 0.4077, -0.4585, 1.5003, 0.4742, 1.2788,
-0.1211, -0.8732, 0.3471, -0.1384, 0.9570, 0.3622, 0.4238],
[ 0.2124, 1.1371, -1.4289, 0.5764, 0.1256, 0.7644, -0.0632, 0.2319,
-1.5474, -1.0827, -0.0089, -0.2720, -0.2680, 0.3866, -1.5928, -1.7298,
-1.1263, -0.1306, -0.7756, -1.8331, -0.7831, -1.7948, -0.2670, -1.0389,
-0.2255, 0.9311, 0.1483, 0.7334, 0.2286, 0.1374, 1.4007, 1.2882,
-1.0346, -0.1200, 0.5116, -0.7399, 1.0112, -0.1332, -1.0827, -1.7130,
-0.1061, 0.2382, 1.0377, 0.9740, 0.6781, 0.4616, 1.5393, 0.1711,
-0.7150, 1.0673, 0.0099, -1.3019, 0.0671, -1.1842, -1.1761, 0.1942,
0.9548, 0.8669, -0.5044, -0.6430, 0.5952, 0.2021, -1.1524, -0.5429,
-0.2770, -0.1876, -0.9680, -0.5728, -0.1015, 0.8704, 0.9444],
[-0.5953, -0.5858, 1.1886, 0.3539, -0.6052, 1.0329, -0.5317, -1.1553,
0.6659, -0.2896, -1.2310, 1.2269, -1.2983, 0.1395, 0.9812, -0.1193,
1.4990, -1.6813, 1.2867, 0.3875, 0.8207, -0.4037, -0.6387, 0.5325,
0.4646, -0.7448, -0.8019, -1.4744, -0.7349, 1.0952, 1.4583, -0.0345,
-0.5427, 0.2260, -1.4092, 0.8894, -0.7744, 0.1883, -0.0135, 0.5633,
0.2371, -1.4176, -1.6180, 0.6626, -1.1416, 1.4265, -0.5873, 0.3752,
-0.8690, 0.1606, 0.7334, 0.1294, 1.9597, -0.0328, 0.6499, 0.7449,
-0.0300, -0.1900, 0.8012, -1.1069, -1.0994, -0.5885, 0.4513, -0.4877,
-0.1794, -0.9845, -1.0642, 0.5420, -0.5530, -2.1108, 1.0529]],
grad_fn=<AddmmBackward>)
In [161]:
preds.size()
Out[161]:
torch.Size([4, 71])
In [149]:
data_batch[1].size()
Out[149]:
torch.Size([4, 71])
In [150]:
loss
Out[150]:
tensor(0.7736, grad_fn=<BinaryCrossEntropyWithLogitsBackward>)
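For reference, binary_cross_entropy_with_logits applies a sigmoid internally and averages the per-label binary cross-entropy over all 4x71 positions, so the same value can be recomputed manually:
# Manual recomputation of the loss above (sigmoid + plain BCE, mean reduction)
manual = F.binary_cross_entropy(torch.sigmoid(preds), data_batch[1])
print(manual)  # expected to match `loss` up to floating-point error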
In [ ]:
#copied from RNN1d
class AdaptiveConcatPoolRNN(nn.Module):
def __init__(self, bidirectional=False):
super().__init__()
self.bidirectional = bidirectional
def forward(self,x):
#input shape bs, ch, ts
t1 = nn.AdaptiveAvgPool1d(1)(x)
t2 = nn.AdaptiveMaxPool1d(1)(x)
if(self.bidirectional is False):
t3 = x[:,:,-1]
else:
channels = x.size()[1]
t3 = torch.cat([x[:,:channels,-1],x[:,channels:,0]],1)
out=torch.cat([t1.squeeze(-1),t2.squeeze(-1),t3],1) #output shape bs, 3*ch
return out
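For head-only fine-tuning (e.g. linear evaluation), a minimal sketch would freeze the encoder and RNN groups and optimize only the head; the optimizer settings simply reuse the hparams above and are illustrative:
# Sketch: freeze encoder and RNN, train only the classification head
for group in model_cpc.get_layer_groups()[:2]:   # (encoder, rnn)
    for p in group.parameters():
        p.requires_grad = False
optimizer = torch.optim.Adam(
    (p for p in model_cpc.parameters() if p.requires_grad),
    lr=hparams.lr, weight_decay=hparams.weight_decay)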