Extreme Classification in Healthcare

>>It’s my great pleasure to introduce David Sontag as our first speaker
of the session. And, actually, I first met him in New York when
we were both with NYU. I now notice that
he is now with MIT as an Assistant Professor. Already when I was with NYU, I was very impressed by how he applies very interesting, innovative
machine learning techniques to a domain that we probably, now or at some point, care about, and
that is healthcare. And, yeah, I’m looking forward to seeing what you’re
working on these days.>>Well, thanks very much to the organizers for inviting me. In fact, they didn’t invite me. They invited Anna Choromanska
to give this talk. They invited me to
give another talk, but Anna unfortunately
couldn’t make it today. She has a young baby at home. And so today I’ll be taking her place to
talk about this work. It’s joint work between myself, Yacine Jernite who was
a PhD student of mine at NYU, and Anna Choromanska. And, really, Yacine made the bulk of these slides;
he couldn’t make it either, but really the credit
should go to him. This work appeared
in our ICML 2017 paper. So, the high-level questions
are as follows. How can we do extreme
classification quickly? This is the topic which has been discussed already
quite a bit this morning. And for us, we’ve been very much motivated by
language modeling, although the techniques I’ll
tell you about today will be really applicable to many other settings
such as classification. So, language modeling, we
would like given some text to learn a distribution over text such that
we could model, for example, draw
the new sentence from the model, or we could ask given
some existing words what’s the likelihood of
seeing some of the next word. So, let’s look at
this simple example here. The rat scared the cat. So, if we had just seen
the first four words, the rat scared the, we might ask what
is the likelihood of the next word in that sentence,
in this case, cat. Well, by the chain rule, any set of words, w1 through wn, the distribution
over them could be factorized as
probability of word one, probability of word
two given word one, probability of word
three given word one comma word two and so
on, just by the chain rule. So, in order to compute
each conditional likelihood, we need to get some distribution
over the next word, given all the previous words. Neural language models typically summarize those previous words
with some context vector, in this case, we’ll refer
to that as the context vector v. And it might, for example, be done by taking word embeddings for each of
the words in the sentence, pushing them through a recurrent neural network and looking at the hidden state
at the last word, which in this case is the. And that would summarize
all the previous words in some D-dimensional vector. And given that
D-dimensional vector, then the question is
what is that next word? We want to predict one class out of a class of size
the whole vocabulary, which could be in a millions. Now, often that
conditional distribution of that last word p of cat
given the context v, the rat scared the, would be parametrized
with a softmax function. So it would look
something like this, where we say that for
every word in the vocabulary, we have some vector. So, that’s here denoted as a uw, which is also D-dimensional, and then the conditional
distribution of w given the context is
just given to you by a renormalization of
the exponential dot-product of the context with
that word vector uw, renormalized so that we have a valid distribution
over all possible words. And this is nothing new. This
is the standard approach to language modelling that’s used across natural
language processing. Now, to compute
that numerator, for example, that exponential term, it
takes running time order D, where D is the dimensionality of that context and word embedding. But in order to compute the full conditional
distribution, we need that
denominator as well. So, it’s going to be order D time to compute
each of these terms, and then the denominator is
going to be order D times v, the vocabulary size
running time, because we have to sum
over all possible words in the vocabulary to get a valid distribution over
the vocabulary. And that’s the problem.
That ends up being one of the main bottlenecks for learning
neural language models. So, there have been
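To make this bottleneck concrete, here is a minimal sketch, with invented sizes and random parameters rather than anything from the talk, of the flat softmax distribution and why its normalizer costs order D times V:

```python
import numpy as np

# Toy sketch of the flat softmax next-word distribution p(w | context).
# V = vocabulary size, D = context/embedding dimension (invented values).
rng = np.random.default_rng(0)
V, D = 1000, 16

U = rng.normal(size=(V, D))      # one D-dimensional output vector u_w per word
v_context = rng.normal(size=D)   # context vector, e.g. an RNN's last hidden state

def flat_softmax(U, v):
    """p(w | v) = exp(u_w . v) / sum_{w'} exp(u_{w'} . v).

    Each score u_w . v costs O(D), but the normalizer sums over
    all V words, so the full distribution costs O(D * V)."""
    scores = U @ v               # O(D * V): the expensive part
    scores -= scores.max()       # subtract the max for numerical stability
    exp_scores = np.exp(scores)
    return exp_scores / exp_scores.sum()

p = flat_softmax(U, v_context)
# p is a valid distribution over all V words
```

A single unnormalized score is cheap; it is the denominator, touching every word in the vocabulary, that dominates training time.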
a number of approaches in the NLP community for trying to address this bottleneck
during training. You might have heard of
things like negative sampling, self-normalizing models,
and hierarchical softmax. And in our work, we’re
going to be focused on that last approach,
hierarchical softmax. And one of the really nice properties of this approach
is that it gives us a valid language model
so we can exactly compute the probability of
any next word in the sentence. And it’s going to be fast, not just at training but also at test time, a property that some of the other approaches
might not have. So, what is this hierarchical softmax that I’m talking about? Well, the key idea is rather than describing the distribution over all of the words
using a softmax, we’re going to, in essence, follow a trajectory
through a tree, and only when we
get to the leaf of the tree will we know what is the word that
we want to predict. And at each step, we’re going to be
just making binary, or in our work, it’ll
be M-ary decisions. And so the total running time
in order to make a prediction is going to be much smaller than
the vocabulary size, potentially logarithmic
in the vocabulary size. Let’s see how this works. Let’s stick with
the previous example. We want to figure out,
given that context, what is the likelihood of
the next word being cat. Well, the first thing
we’re going to do is we’re going to look at
the root of the tree, and we’re going to ask what is the likelihood of
branching, in this case, it’s a 3-ary tree, so what’s the likelihood
of branching to n2 as being the next step
in this procedure. And I’m going to be asking about each one of these steps
until you get to the leaf corresponding
to the label cat. So, we’re going to factorize. We’re going to stipulate that the distribution of cat
given the context is now going to be factorized according to each decision that has to be
made to get to that leaf. So first, probability
of n2 given n1, then probability of n4 given n2, then probability
of cat given n4. And if you multiply each of
the probabilities together, you get the overall likelihood of getting to the label cat. You can now imagine
a generative process where we want to
sample a new word, given this conditional context. We’re going to sample from that first distribution
and sort of work our way through that tree
to get to the leaf. So, the running time
then goes from, well, at each decision, it
takes running time D times m where m is again the branching factor
of this tree. D is the dimensionality
of the context. And if the tree
is of depth log v, the vocabulary size, meaning
it’s a balanced tree, then the running time
is now logarithmic in the vocabulary size
in order to sample from the conditional distribution
or to estimate the conditional likelihood
of a particular label, in this case, the label again corresponds to a word
in the vocabulary. So, this hierarchical softmax
is widely used in NLP, but the question which
everyone in the community always asks is what
should that tree be? What is its structure?
Where is it coming from? And there have been a number of approaches to
try to address that question. The first one has been
one of clustering. So, for example, you might use off-the-shelf algorithm
like Word2vec to learn some word embeddings and
then do some clustering of these word embeddings to
get a candidate tree and then train the parameters of that tree or maybe
iterate this procedure in some way in order to try to find trees that are
semantically meaningful, so that words with similar meanings
are in similar locations along the bottom of the tree. So, that’s one approach. Another approach
that’s widely used is based on Huffman coding, which is completely wacky,
except that it works. So, this Huffman coding idea
says that we’re going to completely
ignore the semantics of what these labels are. We’re just going to try to learn a tree structure such
that very frequently occurring labels, and this won’t be
a balanced tree, are high up
in the tree, and not-so-frequently occurring labels are low in the tree, so that if you were to completely ignore
the contexts and just ask the question of what’s the expected number of branches you need to
do to get to a word, it’s as low as possible. And, finally, in even more
recent work, for example, we’ve used this in my own lab, people have just gone
for random hierarchies. Just giving up on
anything and just saying, “Okay, we’re going to
randomly create a tree.” So, for example, in my work, I’ve used a sort of square-root-V tree where we
have square-root-V branching and then
it’s not very deep. So we just randomly
assign words to the leaves and learn a model, and that also seems to
work reasonably well. But this leaves open
the question of, “Okay, these things work
reasonably well, but what is there
still to be gained?” If we were to actually try to think hard about
this problem and learn that tree structure,
could we do better? And this is one of the first works to really
address that problem. And the conclusion, I’ll
just tell you that now, the conclusion is that there is a substantial gap in performance between either the
clustering-based approaches or the Huffman coding-based
approaches and the flat softmax. And by cleverly learning
that tree structure, we completely eliminate
that gap, that’s the conclusion, while still attaining most of the computational speedups.
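As a concrete illustration of the hierarchical factorization described above, here is a toy sketch; the tree, word set, and parameters are all invented for the example, with each internal node holding its own softmax parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
D, m = 8, 2  # context dimension and branching factor (invented values)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# A depth-2 binary tree over 4 toy words: each word is identified by the
# sequence of child indices chosen on the root-to-leaf path.
paths = {"cat": [0, 1], "dog": [0, 0], "rat": [1, 0], "time": [1, 1]}
# One (m x D) matrix of split parameters per internal node, keyed by the
# prefix of child choices that reaches that node.
nodes = {(): rng.normal(size=(m, D)),
         (0,): rng.normal(size=(m, D)),
         (1,): rng.normal(size=(m, D))}

def word_prob(word, v):
    """Multiply the m-way softmax decisions along the path to `word`.

    Each decision costs O(m * D); a balanced tree has depth O(log_m V),
    so the total cost is O(m * D * log_m V) instead of O(D * V)."""
    p, prefix = 1.0, ()
    for child in paths[word]:
        p *= softmax(nodes[prefix] @ v)[child]
        prefix += (child,)
    return p

v = rng.normal(size=D)
total = sum(word_prob(w, v) for w in paths)
# the leaf probabilities still form a valid distribution over the words
```

Because each internal softmax sums to one, the product over root-to-leaf paths automatically sums to one over the leaves, which is exactly why this remains a valid language model.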
So, how do we get there? Well, there’s been quite a bit of
earlier work on learning tree structures in
similar settings notably by Choromanska and
John Langford, sitting here. And that itself built on even earlier work by
Kearns and Mansour. But, they considered
settings where the contexts were unchanging. So, for example
if you think about a typical bag-of-words
classification setting, where we have a fixed feature
vector representation, we’re just trying
to learn how to predict from a large number of classes what
the classes should be. Then those algorithms
would apply. But here, the story’s
a little bit different. Because, remember when you’re
learning a language model, you’re not just learning
this tree structure, but you’re also learning the recurrent neural network
parameters as well, right? So, you have to learn
the word embedding. You have to learn the parameters
of your gates and so on. So, what that means is
that the representation of your contacts are changing as well during
your learning algorithm. And what we don’t want to do is a two-step algorithm
where you fix the representation and learn
the tree and then alternate. What we want to ask
in this work is: is there a way to jointly
learn everything? And so we’re going to give
an algorithm to do so. We’re going to provide very
weak theoretical guarantees. So, we have some theory
in our paper, but really it leaves
a lot to be desired. But we do show that this
works very well in practice. And we’re going to
evaluate this in both the language
modeling setting, and in an extreme
classification setting. So, our overall approach
will actually be a little bit
similar to how you would tackle learning
a decision tree, right? So, if you think about
learning a decision tree, you might think about some criteria for how you
would build that decision tree. You think about
things like purity or accuracy as measures for
how to choose the next split. And you might build
that decision tree in a top-down fashion. We’re going to give you
the same sort of objective here except it’s going to be a little bit
different to take into consideration some of
the subtleties of this problem. And our learning algorithm won’t be a greedy
top-down algorithm. It’s going to have steps in it of fitting the language model, completely redoing
the tree and alternating. And what we’re going
to show is that using this objective
leads to good trees. So, here’s now where I get
into some of the details. So, first I never told
you how we branch. So, let’s first look at
the setting of just two classes. We’ll call them the class for the word “lost” and
the one for “time.” Suppose we have some data
where the data consists of sentences that give you the relevant contexts
for the next word we would like to predict. And let’s suppose
that next word was lost or time. So, if we just had
two classes then our tree is just going to be
a binary tree of depth one. Where we only have
to make one decision go left or go right, in order to get to one
of these two classes. And so there one could use just a sigmoid function
where you take that context, and you hit it with this linear classifier
parameterized by a weight vector h, and that gives you
some distribution, according to
the sigmoid or logistic, over which class is the right
class for that context. And if you had more than
two classes, that is to say, if rather than
binary this were an M-ary tree, then you would
just use a softmax instead. And so that then gives you
the probability of each one of these classes given
the relevant context. So, now suppose that you have more than
two classes and again, for this example,
I’m going to stick with M equals two
so a binary tree. But, in our experiments
we go with M up to 65. So, now the story is
going to be as follows. Consider the
following data point. “I ate some” gives
you some context, this is how we do prediction. So, not learning but just prediction supposing
that we have learned the model. We start at the root. The root node is associated with some parameters in this case
since it’s a binary tree, there’s just a single hyperplane; we’re using linear splits. That hyperplane tells us, by taking the dot-product
of the hyperplane with the context vector, whether
to go left or go right. Then here, say you go left; you make a similar decision
now for node n2, node n4 and
eventually you get to your prediction which
is the label pie, right? So, then the likelihood of pie is just given to
you by the product of each one of those softmaxes for the decisions needed
to reach that label. So, now let’s ask the question of “How do we
learn that tree structure?”. Well, we gave an objective
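Before turning to learning, the prediction procedure just described can be sketched as a greedy descent through the tree; the tree, labels, and parameters here are invented toys, and note that greedy routing is an approximation to exact most-likely-label inference, which would require search over paths:

```python
import numpy as np

rng = np.random.default_rng(1)
D, m = 8, 2  # context dimension and branching factor (invented values)

# Internal nodes keyed by path prefix; entries are either child-node keys
# or leaf labels (strings).
children = {(): [(0,), (1,)],
            (0,): ["pie", "bread"],
            (1,): ["soup", "rice"]}
# One (m x D) matrix of linear split parameters per internal node.
weights = {n: rng.normal(size=(m, D)) for n in children}

def predict(v):
    """Start at the root; at each node take the highest-scoring child."""
    node = ()
    while node in children:              # stop once we leave the internal nodes
        scores = weights[node] @ v       # one m-way linear split, O(m * D)
        node = children[node][int(np.argmax(scores))]
    return node                          # a leaf, i.e. a predicted label

label = predict(rng.normal(size=D))
```

Each context only touches one root-to-leaf path, so prediction is logarithmic in the number of labels for a balanced tree.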
function which will quantify how good the split
at any one node is. And then, our overall objective
function is going to be the sum of those
per node objectives. So, now I’ll tell you what each of those per
node objectives is. For each node n in your tree, again this is an M-ary tree
with k overall classes, where k, remember,
could be something like the vocabulary size if this were language modeling. So, for each node in
the tree we’re going to look to see, first, what is the distribution of classes that reach
that node in the tree. So, if you were to imagine a prediction where
each data point follows its way down, in this case we’re
going to look at the path from the root to the leaf corresponding to the correct label for
that data point’s class. And so we’re going
to look at all the data points that reach some corresponding
intermediate node along the path from the root
to its correct label. And that will give you
the distribution q_i. And obviously, if
this distribution q_i is very peaked on one class,
then you’re basically there. You basically know what
the right class is. And if instead this is
nearly a uniform distribution, then
you’re very uncertain. And your goal here of course, just like with
learning any type of decision tree is to make
this as peaked as possible, meaning you’re making progress towards figuring out
what the right class is. And so in this example
here we have four classes, and this might be the root. And then we say
this q just gives us the distribution
of data points that reached that corresponding
node with that class. Now, the next quantity
of relevance is going to measure, now that you know which examples, which labels,
reached this node, where they are sent to
next, all right? Which child
is each going to? And that’s going
to be essential for characterizing how good is
the splits at this node. And for this we’re
going to measure that by this conditional
distribution p of j given i, where we’re conditioning on examples
of each class i. We’re going to just look at
data points of some class, and then ask what is
the distribution of data points of that class
that go to each child. And intuitively again
we want each of these
conditional distributions to be as peaked as possible which means that
we’re sort of routing the data points
in an optimal way. And here where we’re using a softmax to make
those routing decisions, this is just given to you by
the following expectation. And finally we can put
those two pieces together, p of j given i and q of i, to get p of j, which is
simply the average, the fraction of
examples that go to child j of this
corresponding node. Now, this is just a summary of the three definitions
I gave you, and using those three
definitions now, we have our overall objective which is given on
the very bottom here. So, the objective for node n
is a sum over the q_i, where again we’re summing over all possible
classes or labels i, and then a sum over
the children j of that node of the absolute difference p_j minus p
of j given i, all right? And so, we want this to
be as high as possible where as high as possible means for this average quantity, we would like all of that mass to be routed towards
a single child for each class. Now, this has a number
of properties. So, for example what
one can show is, one can define this notion
of purity alpha n, and one can show that this objective function
correlates with good purity. Namely, that if you were to maximize this objective
as much as possible it would lead toward
splits that are pure. Moreover, one could characterize another notion of balance. How much are we balancing across the classes
along the tree? And once again, one can
show that if you maximize this objective
as much as possible, it also leads to good balance. So, in our paper we have some theoretical results which motivate those two
properties I just gave. And also show that
there exist trees, that are not too large, that satisfy these properties. So, our overall learning
algorithm is then as follows. We’re going to have-
this is going to be a batch SGD algorithm, and for each batch
of examples, we’re going to completely
redo the tree structure. We have a greedy algorithm which
chooses the new tree structure, that is to say, relabeling
each of the leaves with new classes, that we show is greedily maximizing
the sum of
these node objectives given
our current model parameters, subject to well-formedness
constraints. And so for each batch
we’re going to first completely redo
the tree structure, then we’re going to
take gradient steps with respect to the model parameters, so the recurrent network,
the word embeddings and also the split parameters. Now, importantly each
one of these steps both choosing
the new tree structure and optimizing
the parameters of your language model are all optimizing the same
objective function. And so that really starts to
get at this goal of having a unified objective that
deals with both the fact that the representations
are changing and the fact that we want to
find really good trees. So, here’s an example
of what it looks like. Suppose we see a new example
for some class; we look in the current tree
to see what leaf that class is at. We compute the gradient of this objective for
every node along the path from the root
to its true label. We update node statistics, we take a gradient step, and then we update
the tree structure. And so here, you’ll see a
global change like this, where “pie” moves from
the far right-hand side. We do a new split, we get this new tree structure
and we iterate. So, in just the
remaining two minutes, I’m going to talk a bit about how this algorithm
does in practice. So, first we looked at a language modeling problem, where we took a large corpus from
Project Gutenberg, a 50-million-word corpus. We used a vocabulary
of size 250,000, so there are 250,000 classes here we’re trying to predict. And for this work, we used a really simple log-bilinear n-gram neural language model. We learned an m equals
65-ary tree of depth three, and what I’m showing here is
a little bit hard to see, but what I’m showing is
the tree structure that’s learned and visualized in
terms of what are the classes, the most frequent classes that are underneath each node of the tree to try to get a sense of whether
we’re actually learning, whether there’s any interesting
semantic grouping of the classes that’s learned by this algorithm and
indeed we see that. So, for example, on the far right, we see that the classes this, these, those, and so on
are grouped together. Over here we see
the classes put, set, lay, cut are grouped
together and so on. But, more importantly than the semantics is really
how it does in practice. It’s a real shame that
these plots aren’t showing, but I’ll walk you through
what they’re saying. In this plot on the left, the x-axis is epochs
of training and the y-axis is perplexity,
a measure of
held-out likelihood, which we want to be
as small as possible. So, this red line that
you see up here is the held-out perplexity
of using a random tree. This blue line that you
can see pretty clearly is the held-out perplexity
of using a flat softmax. So, it’s going to be much
slower to train and use at test time because it’s not
hierarchically structured. And you see that there exists
a pretty big gap there. Although, I’m not
showing it here, the clustering-tree type of approach is actually worse
than all of these lines. The green line which is
almost impossible to see, I’m going to show you it
by following my cursor. It initially begins
by somewhat being similar to the random hierarchy, until roughly the second epoch
of training where it settles on a good
tree structure. And from there on out, its learning curve is identical to that of the flat softmax. And so what we see here is that by learning
the tree structure, in fact we’re able
to fully recover the quality of model
that we’re able to get using just a flat softmax. Next, we looked at
a text classification problem where we actually
modeled it almost identically to the density estimation
setting previously considered, with the only difference being
that here, to do prediction, we don’t want to sample; we actually want to
do MAP inference. So, for prediction
here, we’re going to be doing MAP inference for the class, which we do by simple branch and bound with a depth-first search through
this hierarchical tree, in order to predict
for each new example. We evaluated on this Flickr
tag classification setting, where the input is an image like this and your goal
is to figure out what are the tags that
this input corresponds to. Now most people use this dataset to try to predict
the tag directly from the image or to
do something about relating the image to the caption. I want to first point
out that we could have used this approach from
the image itself. All right. So, one could imagine using our algorithm to
simultaneously learn the parameters of
a convolutional neural network and a hierarchical softmax
to predict the class. That would be a good use case. We wanted to stick with language because there’s
something that my student is more interested in
and so we didn’t actually use the image
itself rather we use the text of
the caption and ask we predict the tags from
the texts of the caption, using a simple bag-of-words embedding where
we’re simultaneously learning the word embeddings
at the same time. And indeed the same story
applies here, where the learned tree
that I’m showing you here, for arities 5 and 20,
obtains precision at one, which is substantially better than you could have gotten with a Huffman tree or a couple of other baseline algorithms
that we compare to. What’s particularly
interesting however is that, as the dimensionality
of your word embeddings increases
that gap seems to shrink. So, I think that’s something that’ll be interesting for us to think
about in the future. In particular what
this motivates is that maybe this approach of
learning the tree structure of the hierarchical softmax
is most important in settings where you’re more resource
constrained, where you can’t afford to have
very large dimensionalities, and so the decisions
that are made at each step of the tree are
hard, and because they’re hard, learning the tree structure
is more important. So, I’ll just conclude by
saying this is one of the first works that show how to learn the tree structure of hierarchical softmax together with the underlying
representations. There are a ton of open questions, particularly on
the theoretical side and there are also some
interesting empirical work to be done and
some new directions there. So, just to spur some interest, there’s really
no reason why we had to use a tree-structured model. One could imagine
using a forest, one could imagine allowing each class to appear
a couple of times rather than just
once in the tree and one could even
imagine using the sort of same structure
learning algorithm for other deep learning
applications. So, thank you.>>Okay, thank you. So, we have maybe time only for one or two quick questions
if there are any. Yeah maybe this
one and then yeah.>>So the question
is about applying hierarchical softmax for
multi-label classification.>>Naively, it could
be used there as well. I don’t know
the best way to do it. I’d have to think about it. That’s a really
interesting question. David’s asking whether
the structure could be reused. One of our original thoughts
was that the structure itself might be interesting from a linguistic point of view, particularly in the
language modeling setting. I don’t think
there’s evidence from our results to
justify that statement. A second question is whether the structure
could be reused for other predictive problems and I think the short answer
is yes if the way that the representations relate to the labels that you’re
predicting is generalizable.>>Okay. Let’s
thank our speaker. 