-
Notifications
You must be signed in to change notification settings - Fork 6
/
step3e_validation_pairwise_decoding1854.m
148 lines (127 loc) · 5.26 KB
/
step3e_validation_pairwise_decoding1854.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
function step3e_validation_pairwise_decoding1854(bids_dir, toolbox_dir, participant, blocknr, n_blocks)
%% Function to run pairwise decoding analysis for the 1854 object classes
%
% @ Lina Teichmann, 2022
%
% Usage:
% step3d_validation_pairwise_decoding200(bids_dir,participant, ...)
%
% Inputs:
% bids_dir path to the bids root folder
% toolbox_dir path to toolbox folder containtining CoSMoMVPA
% participant participant number
% blocknr number of the chunk you want to run this analysis for
% n_blocks how many blocks you want to run this analysis in (this is to make it faster by running stuff in parallel)
%
% Returns:
% decoding_acc file that has the decoding accuracy for each decoding block ('PX_pairwise)decoding_1854_blockX.mat')
% decoding_pairs file that contains which pairwise comparisons were run so it can be stacked back together ('PX_pairwise)decoding_1854_blockX_pairs.mat')
%% folders
preprocdir = [bids_dir '/derivatives/preprocessed/'];
res_dir = [bids_dir '/derivatives/output/'];
addpath(genpath([toolbox_dir '/CoSMoMVPA']))
load([preprocdir '/P' num2str(participant) '_cosmofile.mat'],'ds');
% make a pairwise decoding folder if it does not exist
if ~exist([res_dir '/pairwise_decoding'], 'dir')
mkdir([res_dir '/pairwise_decoding'])
end
outfn = [res_dir '/pairwise_decoding/P' num2str(participant) '_pairwise_decoding_1854_block' num2str(blocknr) '.mat'];
outfn_pairs = [res_dir '/pairwise_decoding/P' num2str(participant) '_pairwise_decoding_1854_block' num2str(blocknr) '_pairs.mat'];
%% pairwise decoding
ds = cosmo_slice(ds,strcmp(ds.sa.trial_type,'exp'));
ds.sa.targets = ds.sa.things_category_nr;
ds.sa.chunks = ds.sa.session_nr;
all_combinations = combnk(unique(ds.sa.targets),2);
all_targets = unique(ds.sa.targets);
all_chunks = unique(ds.sa.chunks);
% split into blocks
step = ceil(length(all_combinations)/n_blocks);
s = 1:step:length(all_combinations);
blocks = cell(length(s),1);
for b = 1:length(s)
blocks{b} = all_combinations(s(b):min(s(b)+step-1,length(all_combinations)),:);
end
combs = blocks{blocknr};
save(outfn_pairs, 'combs')
nproc = cosmo_parallel_get_nproc_available;
%% create RDM
% find the items belonging to the exemplars
target_idx = cell(1,length(all_targets));
for j=1:length(all_targets)
target_idx{j} = find(ds.sa.targets==all_targets(j));
end
% for each chunk, find items belonging to the test set
test_chunk_idx = cell(1,length(all_chunks));
for j=1:length(all_chunks)
test_chunk_idx{j} = find(ds.sa.chunks==all_chunks(j));
end
%% make blocks for parfor loop
step = ceil(length(combs)/nproc);
s = 1:step:length(combs);
comb_blocks = cell(1,length(s));
for b = 1:nproc
comb_blocks{b} = combs(s(b):min(s(b)+step-1,length(combs)),:);
end
%arguments for searchlight and crossvalidation
ma = struct();
ma.classifier = @cosmo_classify_lda;
ma.output = 'accuracy';
ma.check_partitions = false;
ma.nproc = 1;
ma.progress = 0;
ma.partitions = struct();
% set options for each worker process
nh = cosmo_interval_neighborhood(ds,'time','radius',0);
worker_opt_cell = cell(1,nproc);
for procs=1:nproc
worker_opt=struct();
worker_opt.ds=ds;
worker_opt.ma = ma;
worker_opt.uc = all_chunks;
worker_opt.worker_id=procs;
worker_opt.nproc=nproc;
worker_opt.nh=nh;
worker_opt.combs = comb_blocks{procs};
worker_opt.target_idx = target_idx;
worker_opt.test_chunk_idx = test_chunk_idx;
worker_opt_cell{procs}=worker_opt;
end
%% run the workers
tic
result_map_cell=cosmo_parcellfun(nproc,@run_block_with_worker,worker_opt_cell,'UniformOutput',false);
toc
%% cat the results
res=cosmo_stack(result_map_cell);
%% save
fprintf('Saving...');tic
save(outfn,'res','-v7.3')
fprintf('Saving finished in %i seconds\n',ceil(toc))
end
function res_block = run_block_with_worker(worker_opt)
ds=worker_opt.ds;
nh=worker_opt.nh;
ma=worker_opt.ma;
uc=worker_opt.uc;
target_idx=worker_opt.target_idx;
test_chunk_idx=worker_opt.test_chunk_idx;
worker_id=worker_opt.worker_id;
nproc=worker_opt.nproc;
combs=worker_opt.combs;
res_cell = cell(1,length(combs));
cc=clock();mm='';
for i=1:length(combs)
idx_ex = [target_idx{combs(i,1)}; target_idx{combs(i,2)}];
[ma.partitions.train_indices,ma.partitions.test_indices] = deal(cell(1,length(uc)));
for j=1:length(uc)
ma.partitions.train_indices{j} = setdiff(idx_ex,test_chunk_idx{j});
ma.partitions.test_indices{j} = intersect(test_chunk_idx{j},idx_ex);
end
res_cell{i} = cosmo_searchlight(ds,nh,@cosmo_crossvalidation_measure,ma);
res_cell{i}.sa.target1 = combs(i,1);
res_cell{i}.sa.target2 = combs(i,2);
if ~mod(i,10)
mm=cosmo_show_progress(cc,i/length(combs),sprintf('%i/%i for worker %i/%i\n',i,length(combs),worker_id,nproc),mm);
end
end
res_block = cosmo_stack(res_cell);
end