% From 2e94cf037b4a88375b096e9b50be81458c71cafa Mon Sep 17 00:00:00 2001
% From: ajbrockmeier
% Date: Wed, 12 Dec 2018 18:51:05 -0500
% Subject: [PATCH 1/4] Create CAML.m
%
% Centered alignment metric learning (CAML) uses mini-batch case for large
% datasets with vectors in Euclidean space.
% ---
%  CAML.m | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++
%  1 file changed, 53 insertions(+)
%  create mode 100644 CAML.m
%
% diff --git a/CAML.m b/CAML.m
% new file mode 100644
% index 0000000..2e50ed4
% --- /dev/null
% +++ b/CAML.m
% @@ -0,0 +1,53 @@
function eta = CAML(logeta_init,DD,labels,Q)
% Copyright 2015 Austin J. Brockmeier
% ajbrockmeier at the domain of gmail.com
%
% CAML  Learn per-dimension kernel weights by maximising centered alignment.
%
%   eta = CAML(logeta_init,DD,labels,Q)
%
%   Inputs
%     logeta_init - scalar, initial value of log10(eta) for every weight.
%     DD          - (N^2)-by-d_x matrix; each row is one sample pair and each
%                   column holds that pair's distance along one dimension
%                   (an N-by-N distance matrix per dimension, reshaped).
%     labels      - N class labels (row or column vector).
%     Q           - number of kernels in the sum kernel (columns of eta).
%
%   Output
%     eta         - d_x-by-Q matrix of non-negative weights, 10.^logeta.
%
%   The objective (see my_objfungrad) is minimised with minFunc, which must
%   be on the MATLAB path.
%
%   Errors
%     CAML:notSquare          - size(DD,1) is not a perfect square.
%     CAML:labelSizeMismatch  - numel(labels) differs from N.

logbase=10;
[NN,d_x]=size(DD);

% Round before squaring so the squareness test is exact integer arithmetic;
% comparing sqrt(NN)^2 with NN in floating point can accept non-squares.
N=round(sqrt(NN));
if N^2~=NN
    error('CAML:notSquare','reshaped distance matrix is not square');
end

% Force a column so that labels' below is always a row, and make sure there
% is exactly one label per sample.
labels=labels(:);
if numel(labels)~=N
    error('CAML:labelSizeMismatch', ...
        'expected %d labels (sqrt of size(DD,1)) but got %d',N,numel(labels));
end

% Ideal (label) kernel: 1 where two samples share a label, scaled to unit trace.
L=bsxfun(@eq,labels,labels');
L=L/trace(L);

% Double-center L without forming the centering matrix: Lc = H*L*H.
oN=ones(N,1);
avec=oN/N;
nu=L*avec;
Lc=bsxfun(@minus,bsxfun(@minus,L,nu),nu')+avec'*nu;

objfungrad=@(X) my_objfungrad(X,DD,N,d_x,logbase,Lc);
options=[];
options.Display='off';

% For Q>1 the start point is jittered so the Q kernels do not start identical.
if Q>1
    x0=logeta_init*(ones(d_x,Q)+(rand(d_x,Q)-.5)*.5);
else
    x0=logeta_init*(ones(d_x,Q));
end

logeta=minFunc(objfungrad,x0(:),options);
eta=reshape(logbase.^logeta,d_x,[]);
end

function [f, gradf]= my_objfungrad(logeta,DD,N,d_x,logbase,Lc)
% Objective and gradient for CAML, as a function of the log-weights.
%
%   logeta  - (d_x*Q)-by-1 vector of log_logbase weights.
%   DD      - (N^2)-by-d_x pairwise per-dimension distances.
%   Lc      - N-by-N centered label kernel (H*L*H).
%
%   f       = -( log(trKL) - log(trKK)/2 ), with
%             trKL = sum(sum(Kc.*Lc)) and trKK = sum(sum(Kc.^2)).
%   gradf   - gradient of f with respect to logeta, same length as logeta.
%
%   If the kernel contains NaN, f is NaN and the gradient is all zeros.

eta=reshape(logbase.^logeta,d_x,[]);
Ks=exp(-DD*eta);%this is by far the slowest part (N^2 D)
K=reshape(sum(Ks,2),N,N);% sum of the Q kernels

% Double-center K: Kc = H*K*H.
oN=ones(N,1);
avec=oN/N;
mu=K*avec;
Kc=bsxfun(@minus,bsxfun(@minus,K,mu),mu')+avec'*mu;

if any(isnan(K(:)))
    f=NaN;
    gradf=0*logeta;
else
    trKL=oN'*(Kc.*Lc)*oN;% trKL=trace(K*H*L*H);%where H is centering matrix
    trKK=oN'*(Kc.^2)*oN;% trKK=trace(K*H*K*H);
    f=-real(log(trKL)-log(trKK)/2);
    Grad_lk=Lc/trKL;%Lc=H*L*H
    Grad_k=Kc/trKK;%Kc=H*K*H
    Grad = Grad_lk-Grad_k;
    % Weight each pair's kernel value by Grad for that pair.
    P = bsxfun(@times,Grad(:),Ks);
    % Chain rule through eta = logbase.^logeta gives the
    % log(logbase)*logbase.^logeta factor.
    gradf=reshape(real(P'*DD)',[],1)*log(logbase).*logbase.^logeta;
end

end

% From a2caf213b59922225ac4d86e665cc7f199dab59a Mon Sep 17 00:00:00 2001
% From: ajbrockmeier
% Date: Wed, 12 Dec 2018 18:52:11
% -0500
% Subject: [PATCH 2/4] Create CAML_approx.m
% ---
%  CAML_approx.m | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++
%  1 file changed, 61 insertions(+)
%  create mode 100644 CAML_approx.m
%
% diff --git a/CAML_approx.m b/CAML_approx.m
% new file mode 100644
% index 0000000..80ea977
% --- /dev/null
% +++ b/CAML_approx.m
% @@ -0,0 +1,61 @@
function eta = CAML_approx(X,labels,Q,logeta_init,Npos,niter)
% Copyright 2015 Austin J. Brockmeier
% ajbrockmeier at the domain of gmail.com
%
% CAML_approx  Mini-batch gradient descent version of CAML.
%
%   eta = CAML_approx(X,labels,Q,logeta_init,Npos,niter)
%
%   Inputs
%     X           - n-by-d_x data matrix, one sample per row.
%     labels      - n class labels (row or column vector).
%     Q           - number of kernels in the sum kernel (columns of eta).
%     logeta_init - scalar, initial value of log10(eta).
%     Npos        - half the mini-batch size: each batch holds the anchor
%                   plus Npos-1 same-class samples and Npos other-class
%                   samples (2*Npos rows in total).
%     niter       - number of gradient steps (fixed step size 0.01).
%
%   Output
%     eta         - d_x-by-Q matrix of weights, 10.^theta.
%
%   Samples are drawn with replacement. If the anchor is the only member of
%   its class, the anchor itself is repeated to fill the same-class slots.
%
%   Errors
%     CAML_approx:singleClass - the drawn anchor has no other-class samples.

logbase=10;
labels=labels(:);% a column is required by the index mask below
nlab=numel(labels);
Nbatch=Npos*2;
H=eye(Nbatch)-1/Nbatch;% centering matrix for a batch
d_x=size(X,2);

% First kernel starts exactly at logeta_init; the other Q-1 are jittered.
theta0=logeta_init*(ones(d_x,Q)+cat(2,zeros(d_x,1),(rand(d_x,Q-1)-.5)*.1));
for ii=1:niter
    m=randi(nlab,1);% anchor sample
    idx_pos=find(labels==labels(m) & (1:nlab)'~=m );
    idx_neg=find(labels~=labels(m));
    if isempty(idx_neg)
        error('CAML_approx:singleClass', ...
            'no sample has a label different from the anchor; need at least two classes');
    end
    if ~isempty(idx_pos)
        m_pos=idx_pos(randi(numel(idx_pos),Npos-1,1));
    else
        % Singleton class: repeat the anchor so the batch still has exactly
        % Nbatch rows (a single index here made H*L*H fail on dimensions).
        m_pos=repmat(m,Npos-1,1);
    end
    m_neg=idx_neg(randi(numel(idx_neg),Npos,1));
    batch=[m;m_pos;m_neg];
    XX=X(batch,:);
    labs=labels(batch);
    L=double(bsxfun(@eq,labs,labs'));
    [~,gradf]=sum_grad(theta0,XX,Nbatch,d_x,logbase,H*L*H);
    theta0=theta0-.01*gradf;
end
eta=reshape(logbase.^theta0,d_x,[]);

end


function [f, gradf]= sum_grad(logeta,X,N,d_x,logbase,Lc)
% Negative centered alignment and its gradient for one mini-batch.
%
%   logeta - d_x-by-Q matrix of log_logbase weights.
%   X      - N-by-d_x batch of samples.
%   Lc     - N-by-N centered label kernel (H*L*H).
%
%   f      = -trKL/sqrt(trKK*trLL)
%   gradf  - d_x-by-Q gradient of f with respect to logeta.
%
%   If the kernel contains NaN, f is NaN and the gradient is all zeros.

% Squared per-dimension differences for every ordered pair, N^2-by-d_x.
Xt1=kron(ones(N,1),X);
Xt2=kron(X,ones(N,1));
DD=bsxfun(@minus,Xt1,Xt2).^2;
eta=reshape(logbase.^logeta,d_x,[]);
Ks=exp(-DD*eta);
K=reshape(sum(Ks,2),N,N);

% Double-center K: Kc = H*K*H.
oN=ones(N,1);
avec=oN/N;
mu=K*avec;
Kc=bsxfun(@minus,bsxfun(@minus,K,mu),mu')+avec'*mu;

if any(isnan(K(:)))
    fprintf('whoa')
    f=NaN;
    gradf=0*logeta;
else
    trKL=oN'*(Kc.*Lc)*oN;% trKL=trace(K*H*L*H);%where H is centering matrix
    trKK=oN'*(Kc.^2)*oN;% trKK=trace(K*H*K*H);
    trLL=oN'*(Lc.^2)*oN;% trLL=trace(L*H*L*H);

    % Derivative of trKL/sqrt(trKK*trLL) with respect to K.
    Grad = (Lc*sqrt(trKK*trLL)-trKL*trLL*Kc*(trKK*trLL)^-.5)/(trKK*trLL);
    P = bsxfun(@times,Grad(:),Ks);
    % Chain rule through eta = logbase.^logeta.
    gradf=real(P'*DD)'*log(logbase).*logbase.^logeta;
    f=-real(trKL/sqrt(trKK*trLL));
end

end

% From acf6bd4fae82938d35d39b5ccede3fc8930702de Mon Sep 17 00:00:00 2001
% From: ajbrockmeier
% Date: Wed, 12 Dec
% 2018 18:56:17 -0500
% Subject: [PATCH 3/4] Create cmd_simple_test.m
% ---
%  cmd_simple_test.m | 131 ++++++++++++++++++++++++++++++++++++++++++++++
%  1 file changed, 131 insertions(+)
%  create mode 100644 cmd_simple_test.m
%
% diff --git a/cmd_simple_test.m b/cmd_simple_test.m
% new file mode 100644
% index 0000000..e9f1b93
% --- /dev/null
% +++ b/cmd_simple_test.m
% @@ -0,0 +1,131 @@
% A simple script to run the metric learning code
% Austin J. Brockmeier (2013-2015)
% ajbrockmeier at the domain of gmail.com
%
% For each synthetic dataset the script repeats nmonte random
% train/validation/test splits, learns weights with CAML and/or CAML_approx,
% and prints 1-NN and k-NN test accuracy (mean and standard deviation) for
% the unweighted distance and for each learned metric.

fprintf('Reminder to add minFunc to path\n');
fprintf('via addpath(genpath(...))\n');

%User defined parameters
Qs=[1 1];%number of kernels in the sum kernel when Q is greater than 1 this is a nonlinear kernel
run_approx=[1 0];%run CAML approx instead (useful when the number of sample is too big)
%approx has less storage but is much slower since 10000 iterations are run
if numel(Qs)~=numel(run_approx)%numel(Qs) should be equal to numel(run_approx)
    % error() with a single argument does not expand escapes, so no '\n' here
    error('Please have same number of parameter choices');
end
%Classification paradigm
prop_test=1/3;
prop_valid=1/3;% this is currently unused (in the paper I use it to pick best neighborhood)
nmonte = 20;
Nrep=1; % number of replicate runs
knn=3;%neighborhood size for knn

for dataset={'Two Gaussians','4 Gaussians (XOR)'}
    % Dataset generation
    N=200;
    useful_dim=2;%only the first dimensions are meaningful
    extra_dim=28;
    N1=round(N/2);
    N2=N-N1;
    switch dataset{1}
        case {'xor','4 Gaussians (XOR)'}
            % The case label must match the name used in the for-loop list;
            % with only 'xor' this branch was never taken.
            all_labels=cat(1,ones(N1,1),-ones(N2,1));
            centers=[1 1; -1 -1; 1 -1; -1 1];
            vals=[1 -1];
            % class +1 draws from centers 1-2, class -1 from centers 3-4
            idx=cat(1,randi(2,N1,1),2+randi(2,N2,1));
            all_centers=cat(2,centers(idx,:),vals(randi(2,N,extra_dim)));
            X=.5*randn(N,useful_dim+extra_dim)+all_centers;
        otherwise
            all_labels=cat(1,ones(N1,1),-ones(N2,1));
            X=randn(N,useful_dim+extra_dim)+kron(all_labels, [ones(1,useful_dim) zeros(1,extra_dim)]);
    end
    P=size(X,2);
    N = size(X,1);
    X=bsxfun(@rdivide,X,sqrt(sum(X.^2)));%make variance constant!

    % %% precompute Euclidean distances
    DDall=zeros(N,N,P);
    for ii=1:P %calculate all of the distance matrices
        G=X(:,ii)*X(:,ii)';
        Di=-2*G+bsxfun(@plus,diag(G),diag(G)');
        DDall(:,:,ii)= Di;
    end


    orig_results = zeros(nmonte,2);
    new_results = zeros(nmonte,2*numel(Qs));
    unit_weights=cell(nmonte,numel(Qs));

    Diso=sum(DDall,3);% original distances
    for mmm = 1:nmonte
        %set up test and train partitions
        % sortii{1}=train, sortii{2}=test, sortii{3}=validation
        sortii=cell(3,1);
        N_keep=round(N*(1-prop_test-prop_valid));
        N_keep2=round(N*(1-prop_test));
        outoforder=randperm(N);
        sortii{1} = outoforder(1:N_keep);
        sortii{3} = outoforder(1+N_keep:N_keep2);
        sortii{2} = outoforder(1+N_keep2:end);

        labels=all_labels(sortii{1});
        test_labels=all_labels(sortii{2});
        cv_labels=all_labels(sortii{3});

        %Original distance
        [~,iii]=sort(Diso(sortii{1},sortii{2}));%un weighted original
        acc_1nn=mean(labels(iii(1,:))==test_labels); %1NN
        acc_knn=mean(mode(labels(iii(1:knn,:)))'==test_labels);%kNN
        orig_results(mmm,:)=[acc_1nn acc_knn];

        %CAML
        DD=reshape(DDall(sortii{1},sortii{1},:),numel(sortii{1})^2,[]);
        w_norm=mean(DD,1);% average squared distance
        DD=bsxfun(@rdivide,DD,w_norm); %normalize distances
        arez=[];
        for qii=1:numel(Qs)
            Q=Qs(qii);
            if run_approx(qii)==1
                % CAML_approx builds squared differences from the features
                % itself, so scale features by 1/sqrt(w_norm) to give it the
                % same normalised distances as CAML; eta is mapped back by
                % w_norm below for both learners.
                Xtrain=bsxfun(@rdivide,X(sortii{1},:),sqrt(w_norm));
                eta = CAML_approx(Xtrain,labels,Q,-1,4,10000);
            else
                eta = CAML(-1,DD,labels,Q);
            end
            w=bsxfun(@rdivide,eta,w_norm');% weights in un-normalised distance units
            %Compute new distance
            if Q==1
                Dnew=sqrt(sum(bsxfun(@times,DDall,permute(w,[3 2 1])),3));
            else
                Knew=reshape(sum(exp(-reshape(DDall(:,:,:),N^2,[])*w),2),N,N);
                K1=diag(Knew)*ones(1,size(Knew,1));
                Dnew = sqrt(-2*Knew+K1+K1');
            end
            [~,iii]=sort(Dnew(sortii{1},sortii{2}));
            acc_1nn=mean(labels(iii(1,:))==test_labels);%1NN
            acc_knn=mean(mode(labels(iii(1:knn,:)))'==test_labels); %knn
            arez=cat(2,arez,[acc_1nn acc_knn]);
            unit_weights{mmm,qii}=w;
        end%end-Q
        new_results(mmm,:)=arez;
    end % end-Monte Carlo
    % Output results
    M=[orig_results,new_results];
    rez=mean(M);
    rez2=std(M);
    method_str=cell(size(M,2),1);
    method_str(1:2)={'Orig. 1NN','Orig. kNN'};
    for ii=1:numel(Qs)
        if Qs(ii)==1
            postfix='';
        else
            postfix=sprintf(' Q=%i',Qs(ii));
        end
        if run_approx(ii)==1
            prefix='~';
        else
            prefix='';
        end
        method_str{3+2*(ii-1)}=[prefix,'CAML 1NN',postfix];
        method_str{4+2*(ii-1)}=[prefix,'CAML kNN',postfix];
    end
    fprintf('%s\n Accuracy (%c correct)\n',dataset{1},'%');
    for ii=1:numel(method_str)
        % char(177) is the plus-minus sign; passing it as an argument keeps
        % the source file free of non-ASCII bytes.
        fprintf('%s:%s%i%c%.1f\n',method_str{ii},[char(9),char(9)],round(rez(ii)*100),char(177),100*rez2(ii));
    end
end

% From 90a7f4a294985679d60d91d61102a265ecf1d60b Mon Sep 17 00:00:00 2001
% From: ajbrockmeier
% Date: Wed, 12 Dec 2018 19:01:33 -0500
% Subject: [PATCH 4/4] Update README.md
% ---
%  README.md | 5 +++++
%  1 file changed, 5 insertions(+)
%
% diff --git a/README.md b/README.md
% index 963d765..6de26cf 100644
% --- a/README.md
% +++ b/README.md
% @@ -6,3 +6,8 @@
%  [(author PDF)](http://cnel.ufl.edu/files/1389293595.pdf)
%
%  ![](http://cnel.ufl.edu/~ajbrockmeier/metric/test.png)
% +
% +### MATLAB code:
% +#### [cmd_simple_test.m](./cmd_simple_test.m) is a test script
% +#### [CAML.m](./CAML.m) is the metric learning optimization algorithm for a batch of data.
% +#### [CAML_approx.m](./CAML_approx.m) uses mini-batches for large datasets with vectors in Euclidean space.