Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions CAML.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
function eta = CAML(logeta_init,DD,labels,Q)
% CAML  Learn kernel weights by minimizing the objective in my_objfungrad.
%
%   eta = CAML(logeta_init,DD,labels,Q)
%
%   Inputs
%     logeta_init  scalar initial value of log10(eta) used for every weight
%     DD           (N^2)-by-d_x matrix; column k is an N-by-N distance
%                  matrix for dimension k, stacked column-wise
%     labels       N class labels (row or column vector)
%     Q            number of kernels summed together; for Q>1 the starting
%                  point is randomly perturbed so the kernels differ
%
%   Output
%     eta          d_x-by-Q matrix of weights, eta = 10.^logeta
%
%   Requires minFunc on the MATLAB path.
%
% Copyright 2015 Austin J. Brockmeier
% ajbrockmeier at the domain of gmail.com

logbase=10;
[NN,d_x]=size(DD);
% Round before squaring: sqrt(NN)^2==NN can hold in floating point even
% when NN is not a perfect square, which would leave N non-integer.
N=round(sqrt(NN));
if N^2~=NN
    error('reshaped distance matrix is not square');
end
labels=labels(:);
if numel(labels)~=N
    error('number of labels (%d) does not match the %d samples implied by DD',numel(labels),N);
end

% Label agreement matrix, scaled to unit trace and then double-centered
% (equivalent to H*L*H with H the centering matrix).
L=double(bsxfun(@eq,labels,labels'));
L=L/trace(L);
oN=ones(N,1);
avec=oN/N;
nu=L*avec;
Lc=bsxfun(@minus,bsxfun(@minus,L,nu),nu')+avec'*nu;

objfungrad=@(X) my_objfungrad(X,DD,N,d_x,logbase,Lc);
options=[];
options.Display='off';

% Starting point in log10 space; jitter only when several kernels are used.
if Q>1
    x0=logeta_init*(ones(d_x,Q)+(rand(d_x,Q)-.5)*.5);
else
    x0=logeta_init*(ones(d_x,Q));
end

[logeta,~]=minFunc(objfungrad,x0(:),options);
eta=reshape(logbase.^logeta,d_x,[]);
end

function [f, gradf]= my_objfungrad(logeta,DD,N,d_x,logbase,Lc)
% MY_OBJFUNGRAD  Objective and gradient passed to the optimizer in CAML.
%
%   logeta   log-weights (base logbase), d_x*Q values
%   DD       (N^2)-by-d_x stacked distance matrices
%   N        number of samples
%   d_x      number of input dimensions
%   logbase  base of the logarithmic parameterization
%   Lc       N-by-N double-centered label matrix
%
%   f      = -(log(trKL) - log(trKK)/2), with trKL = sum(Kc.*Lc) and
%            trKK = sum(Kc.^2), Kc being the double-centered kernel matrix
%   gradf  = gradient of f with respect to logeta (column vector)
%
%   If the kernel matrix contains NaN, f is NaN and gradf is 0*logeta.

weights=reshape(logbase.^logeta,d_x,[]);
kernel_terms=exp(-DD*weights);% dominant cost: N^2 * d_x * Q
K=reshape(sum(kernel_terms,2),N,N);

if any(isnan(K(:)))
    f=NaN;
    gradf=0*logeta;
    return;
end

% Double-center K (Kc = H*K*H, H the centering matrix).
col_ones=ones(N,1);
uniform=col_ones/N;
row_avg=K*uniform;
Kc=bsxfun(@minus,bsxfun(@minus,K,row_avg),row_avg')+uniform'*row_avg;

% Sums of all entries, i.e. trace(K*H*L*H) and trace(K*H*K*H).
trKL=col_ones'*(Kc.*Lc)*col_ones;
trKK=col_ones'*(Kc.^2)*col_ones;
f=-real(log(trKL)-log(trKK)/2);

% Sensitivity with respect to each entry of K, then chain rule through
% exp(-DD*eta) and through eta = logbase.^logeta.
dK=Lc/trKL-Kc/trKK;
weighted_terms=bsxfun(@times,dK(:),kernel_terms);
gradf=reshape(real(weighted_terms'*DD)',[],1)*log(logbase).*logbase.^logeta;

end

61 changes: 61 additions & 0 deletions CAML_approx.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
function eta = CAML_approx(X,labels,Q,logeta_init,Npos,niter)
% CAML_APPROX  Mini-batch gradient descent version of the metric learning.
%
%   eta = CAML_approx(X,labels,Q,logeta_init,Npos,niter)
%
%   Inputs
%     X            samples-by-d_x data matrix
%     labels       one class label per row of X (row or column vector)
%     Q            number of kernels summed together
%     logeta_init  scalar initial value of log10(eta)
%     Npos         number of same-class samples per batch (anchor included);
%                  an equal number of other-class samples is drawn, so each
%                  batch has 2*Npos rows
%     niter        number of gradient steps
%
%   Output
%     eta          d_x-by-Q matrix of weights, eta = 10.^theta
%
%   Each iteration draws a random anchor, Npos-1 samples sharing its label
%   and Npos samples with a different label (all with replacement), then
%   takes a fixed-size step along the negative gradient from sum_grad.
%
% Copyright 2015 Austin J. Brockmeier
% ajbrockmeier at the domain of gmail.com

logbase=10;
step_size=.01;
Nbatch=Npos*2;
H=eye(Nbatch)-1/Nbatch;% centering matrix for one batch
d_x=size(X,2);

% Force a column so the logical masks and concatenations below have
% consistent orientation regardless of how the caller stored the labels.
labels=labels(:);
n_samples=numel(labels);
sample_idx=(1:n_samples)';

% First kernel starts exactly at logeta_init; the others are jittered.
theta0=logeta_init*(ones(d_x,Q)+cat(2,zeros(d_x,1),(rand(d_x,Q-1)-.5)*.1));
for ii=1:niter
    m=randi(n_samples,1);
    idx_pos=find(labels==labels(m) & sample_idx~=m);
    idx_neg=find(labels~=labels(m));
    if isempty(idx_neg)
        error('CAML_approx needs at least two distinct labels to form a batch');
    end
    if numel(idx_pos)>0
        m_pos=idx_pos(randi(numel(idx_pos),Npos-1,1));
    else
        % The anchor is the only member of its class: repeat it so the
        % batch still has exactly Nbatch rows and matches the size of H.
        m_pos=repmat(m,Npos-1,1);
    end
    m_neg=idx_neg(randi(numel(idx_neg),Npos,1));
    batch=[m;m_pos(:);m_neg(:)];
    XX=X(batch,:);
    labs=labels(batch);
    L=bsxfun(@eq,labs,labs');
    [~,gradf]=sum_grad(theta0,XX,Nbatch,d_x,logbase,H*L*H);
    theta0=theta0-step_size*gradf;
end
eta=reshape(logbase.^theta0,d_x,[]);

end


function [f, gradf]= sum_grad(logeta,X,N,d_x,logbase,Lc)
% SUM_GRAD  Objective and gradient for one mini-batch of raw vectors.
%
%   logeta   d_x-by-Q log-weights (base logbase)
%   X        batch data, one sample per row
%   N        number of rows in the batch
%   d_x      number of input dimensions
%   logbase  base of the logarithmic parameterization
%   Lc       N-by-N double-centered label matrix
%
%   f      = -trKL/sqrt(trKK*trLL), where trKL = sum(Kc.*Lc),
%            trKK = sum(Kc.^2), trLL = sum(Lc.^2) and Kc is the
%            double-centered sum-of-exponentials kernel matrix
%   gradf  = gradient of f with respect to logeta, same shape as logeta
%
%   If the kernel matrix contains NaN, prints a marker, returns f = NaN
%   and gradf = 0*logeta.

% Per-dimension squared differences for every ordered pair of rows,
% stacked so that reshape(.,N,N) recovers an N-by-N matrix.
rep_rows=repmat(1:size(X,1),N,1);
tiled=repmat(X,N,1);
repeated=X(rep_rows(:),:);
DD=bsxfun(@minus,tiled,repeated).^2;

weights=reshape(logbase.^logeta,d_x,[]);
kernel_terms=exp(-DD*weights);
K=reshape(sum(kernel_terms,2),N,N);

if any(isnan(K(:)))
    fprintf('whoa')
    f=NaN;
    gradf=0*logeta;
    return;
end

% Double-center K (Kc = H*K*H, H the centering matrix).
col_ones=ones(N,1);
uniform=col_ones/N;
row_avg=K*uniform;
Kc=bsxfun(@minus,bsxfun(@minus,K,row_avg),row_avg')+uniform'*row_avg;

% Sums of all entries: trace(K*H*L*H), trace(K*H*K*H), trace(L*H*L*H).
trKL=col_ones'*(Kc.*Lc)*col_ones;
trKK=col_ones'*(Kc.^2)*col_ones;
trLL=col_ones'*(Lc.^2)*col_ones;

f=-real(trKL/sqrt(trKK*trLL));

% Sensitivity of the alignment to each entry of K, then chain rule through
% exp(-DD*eta) and through eta = logbase.^logeta.
dK=(Lc*sqrt(trKK*trLL)-trKL*trLL*Kc*(trKK*trLL)^-.5)/(trKK*trLL);
weighted_terms=bsxfun(@times,dK(:),kernel_terms);
gradf=real(weighted_terms'*DD)'*log(logbase).*logbase.^logeta;

end

5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,8 @@
[(author PDF)](http://cnel.ufl.edu/files/1389293595.pdf)

![](http://cnel.ufl.edu/~ajbrockmeier/metric/test.png)

### MATLAB code:
#### [cmd_simple_test.m](./cmd_simple_test.m) is a test script
#### [CAML.m](./CAML.m) is the metric learning optimization algorithm for a batch of data.
#### [CAML_approx.m](./CAML_approx.m) uses mini-batches for large datasets with vectors in Euclidean space.
131 changes: 131 additions & 0 deletions cmd_simple_test.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
% A simple script to run the metric learning code
% Austin J. Brockmeier (2013-2015)
% ajbrockmeier at the domain of gmail.com

fprintf('Reminder to add minFunc to path\n');
fprintf('via addpath(genpath(...))\n');

%User defined parameters
% Qs(k) is the number of kernels summed together for configuration k;
% with Q greater than 1 the resulting kernel is nonlinear.
Qs=[1 1];
% run_approx(k)==1 runs CAML_approx (mini-batches) instead of CAML for
% configuration k. Useful when there are too many samples for the full
% distance tensor; it needs less storage but is much slower because
% 10000 iterations are run.
run_approx=[1 0];
if numel(Qs)~=numel(run_approx)%numel(Qs) should be equal to numel(run_approx)
    % error() with a single argument prints escape sequences literally,
    % so the message carries no trailing '\n'.
    error('Please have same number of parameter choices in Qs and run_approx');
end
%Classification paradigm
prop_test=1/3;
prop_valid=1/3;% this is currently unused (in the paper I use it to pick best neighborhood)
nmonte = 20;% number of Monte Carlo train/test splits
Nrep=1; % number of replicate runs
knn=3;%neighborhood size for knn

% For each synthetic dataset: generate data, precompute per-dimension
% squared distances, then over nmonte random splits compare 1NN/kNN
% accuracy of the unweighted distance against the learned metric(s).
for dataset={'Two Gaussians','4 Gaussians (XOR)'}
    % Dataset generation
    N=200;
    useful_dim=2;%only the first dimensions are meaningful
    extra_dim=28;
    N1=round(N/2);
    N2=N-N1;
    switch dataset{1}
        case '4 Gaussians (XOR)'
            % The case string must equal the loop value exactly, otherwise
            % this branch is never taken and both runs use two Gaussians.
            all_labels=cat(1,ones(N1,1),-ones(N2,1));
            centers=[1 1; -1 -1; 1 -1; -1 1];
            vals=[1 -1];
            % class +1 uses centers 1-2, class -1 uses centers 3-4
            idx=cat(1,randi(2,N1,1),2+randi(2,N2,1));
            all_centers=cat(2,centers(idx,:),vals(randi(2,N,extra_dim)));
            X=.5*randn(N,useful_dim+extra_dim)+all_centers;
        otherwise
            all_labels=cat(1,ones(N1,1),-ones(N2,1));
            X=randn(N,useful_dim+extra_dim)+kron(all_labels, [ones(1,useful_dim) zeros(1,extra_dim)]);
    end
    P=size(X,2);
    N=size(X,1);
    X=bsxfun(@rdivide,X,sqrt(sum(X.^2)));% scale every column to unit Euclidean norm

    % Precompute squared Euclidean distances, one N-by-N slice per dimension
    DDall=zeros(N,N,P);
    for ii=1:P
        G=X(:,ii)*X(:,ii)';
        DDall(:,:,ii)=-2*G+bsxfun(@plus,diag(G),diag(G)');
    end

    orig_results=zeros(nmonte,2);
    new_results=zeros(nmonte,2*numel(Qs));
    unit_weights=cell(nmonte,numel(Qs));

    Diso=sum(DDall,3);% original (unweighted) distances
    for mmm=1:nmonte
        % Set up partitions: {1}=train, {2}=test, {3}=validation
        sortii=cell(3,1);
        N_keep=round(N*(1-prop_test-prop_valid));
        N_keep2=round(N*(1-prop_test));
        outoforder=randperm(N);
        sortii{1}=outoforder(1:N_keep);
        sortii{3}=outoforder(1+N_keep:N_keep2);
        sortii{2}=outoforder(1+N_keep2:end);

        labels=all_labels(sortii{1});
        test_labels=all_labels(sortii{2});
        cv_labels=all_labels(sortii{3});

        % Baseline: rank training points by original distance to each test point
        [~,iii]=sort(Diso(sortii{1},sortii{2}));
        acc_1nn=mean(labels(iii(1,:))==test_labels); %1NN
        acc_knn=mean(mode(labels(iii(1:knn,:)))'==test_labels);%kNN
        orig_results(mmm,:)=[acc_1nn acc_knn];

        %CAML
        DD=reshape(DDall(sortii{1},sortii{1},:),numel(sortii{1})^2,[]);
        w_norm=mean(DD,1);% average squared distance per dimension
        DD=bsxfun(@rdivide,DD,w_norm); %normalize distances
        for qii=1:numel(Qs)
            Q=Qs(qii);
            if run_approx(qii)==1
                eta=CAML_approx(X(sortii{1},:),labels,Q,-1,4,10000);
            else
                eta=CAML(-1,DD,labels,Q);
            end
            % Undo the per-dimension normalization so w applies to DDall.
            % NOTE(review): CAML_approx is given raw X rather than the
            % normalized DD, yet its eta is rescaled the same way here;
            % confirm this is intended.
            w=bsxfun(@rdivide,eta,w_norm');
            %Compute new distance
            if Q==1
                Dnew=sqrt(sum(bsxfun(@times,DDall,permute(w,[3 2 1])),3));
            else
                % Kernel-induced distance from the sum of Q exponential kernels
                Knew=reshape(sum(exp(-reshape(DDall(:,:,:),N^2,[])*w),2),N,N);
                K1=diag(Knew)*ones(1,size(Knew,1));
                Dnew=sqrt(-2*Knew+K1+K1');
            end
            [~,iii]=sort(Dnew(sortii{1},sortii{2}));
            acc_1nn=mean(labels(iii(1,:))==test_labels);%1NN
            acc_knn=mean(mode(labels(iii(1:knn,:)))'==test_labels); %kNN
            new_results(mmm,2*qii-1:2*qii)=[acc_1nn acc_knn];
            unit_weights{mmm,qii}=w;
        end%end-Q
    end % end-Monte Carlo

    % Output results: mean and standard deviation over the Monte Carlo runs
    M=[orig_results,new_results];
    rez=mean(M);
    rez2=std(M);
    method_str=cell(size(M,2),1);
    method_str(1:2)={'Orig. 1NN','Orig. kNN'};
    for ii=1:numel(Qs)
        if Qs(ii)==1
            postfix='';
        else
            postfix=sprintf(' Q=%i',Qs(ii));
        end
        if run_approx(ii)==1
            prefix='~';
        else
            prefix='';
        end
        method_str{3+2*(ii-1)}=[prefix,'CAML 1NN',postfix];
        method_str{4+2*(ii-1)}=[prefix,'CAML kNN',postfix];
    end
    fprintf('%s\n Accuracy (%c correct)\n',dataset{1},'%');
    for ii=1:numel(method_str)
        % char(177) is the plus-minus sign; passing it through %c keeps the
        % source file free of non-ASCII characters.
        fprintf('%s:%s%i%c%.1f\n',method_str{ii},[char(9),char(9)],round(rez(ii)*100),char(177),100*rez2(ii));
    end
end