>>It’s my great pleasure to introduce David Sontag as our first speaker

of the session. And, actually, I first met him in New York when

we were both at NYU. I notice that

you are now with MIT as an Assistant Professor. Already back at NYU, I was very impressed by how he applies very interesting, innovative

machine learning techniques to a domain that we probably, now or at some point, care about and

that is healthcare. And, yeah, I'm looking forward to seeing what you're

working on these days.>>Well, thanks very much to the organizers for inviting me. In fact, they didn't invite me. They invited Anna Choromanska

to give this talk. They invited me to

give another talk, but Anna unfortunately

couldn’t make it today. She has a young baby at home. And so today I’ll be taking her place to

talk about this work. It’s joint work between myself, Yacine Jernite who was

a PhD student of mine at NYU, and Anna Choromanska. And really, Yacine, who made the bulk of these slides, couldn't make it either, but really the credit should go to him. This work appeared

in our ICML 2017 paper. So, the high-level questions

are as follows. How can we do extreme

classification quickly? This is the topic which has been discussed already

quite a bit this morning. And for us, we’ve been very much motivated by

language modeling, although the techniques I’ll

tell you about today will be really applicable to many other settings

such as classification. So, in language modeling, we would like, given some text, to learn a distribution over text such that we could, for example, draw a new sentence from the model, or we could ask, given some existing words, what the likelihood of seeing the next word is. So, let's look at

this simple example here: "The rat scared the cat." So, if we had just seen the first four words, "the rat scared the," we might ask what is the likelihood of the next word in that sentence, in this case, "cat." Well, by the chain rule, for any set of words w1 through wn, the distribution

over them could be factorized as

probability of word one, probability of word

two given word one, probability of word

three given word one and word two, and so

on, just by the chain rule. So, in order to compute

each conditional likelihood, we need to get some distribution

over the next word, given all the previous words. Neural language models typically summarize those previous words

with some context vector; in this case, we'll refer to that as the context vector v. And it might, for example, be computed by taking word embeddings for each of the words in the sentence, pushing them through a recurrent neural network, and looking at the hidden state at the last word, which in this case is "the." And that would summarize all the previous words in some D-dimensional vector. And given that D-dimensional vector, the question is: what is the next word? We want to predict one class out of a set of classes the size of the whole vocabulary, which could be in the millions. Now, often that

conditional distribution of that last word p of cat

given the context v, the rat scared the, would be parametrized

with a softmax function. So it would look

something like this, where we say that for

every word in the vocabulary, we have some vector, here denoted u_w, which is also D-dimensional, and then the conditional distribution of w given the context is just given to you by the exponential of the dot-product of the context with that word vector u_w, renormalized so that we have a valid distribution

over all possible words. And this is nothing new. This

is the standard approach to language modelling that’s used across natural

language processing. Now, to compute

that numerator, for example, that exponential term, it

takes running time order D, where D is the dimensionality of that context and word embedding. But in order to compute the full conditional

distribution, we need that

denominator as well. So, it's going to take order-D time to compute each of these terms, and then the denominator is going to take order D times V time, where V is the vocabulary size, because we have to sum

over all possible words in the vocabulary to get a valid distribution over

the vocabulary. And that’s the problem.
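As a rough illustration of that bottleneck (this is not code from the talk; the sizes and names are invented, and the context vector is assumed to come from somewhere like an RNN), the flat softmax looks like this:

```python
import numpy as np

# Minimal sketch of the flat softmax over a vocabulary. One numerator is O(D),
# but the normalizer sums over every word, so the full distribution is O(D * V).
D, V = 256, 10_000                       # embedding size, vocab size (in the talk, V ~ 250,000)
rng = np.random.default_rng(0)
U = rng.normal(size=(V, D))              # one D-dimensional output vector u_w per word
v_context = rng.normal(size=D)           # context vector, e.g. an RNN's last hidden state

scores = U @ v_context                   # O(D * V): dot product with every word vector
scores -= scores.max()                   # for numerical stability
probs = np.exp(scores) / np.exp(scores).sum()   # valid distribution over all V words

p_next = probs[123]                      # probability of one particular next word
```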

That ends up being one of the main bottlenecks for learning

neural language models. So, there have been

a number of approaches in the NLP community for trying to address this bottleneck

during training. You might have heard of

things like negative sampling, self-normalizing models,

and hierarchical softmax. And in our work, we’re

going to be focused on that last approach,

hierarchical softmax. And one of the really

nice things about this hierarchical softmax

is that it gives us a valid language model

so we can exactly compute the probability of

any next word in the sentence. And it's going to be fast, not just at training time but also at test time, which some of the other approaches might not be. So, what is this hierarchical softmax that I'm talking about? Well, the key idea is that rather than describing the distribution over all of the words

using a single flat softmax, we're going to predict in stages. We're going to, in essence, follow a trajectory

through a tree, and only when we

get to the leaf of the tree will we know what is the word that

we want to predict. And at each step, we’re going to be

just making binary or, in our work, M-ary decisions. And so the total running time

in order to make a prediction is going to be much smaller than

the vocabulary size, potentially logarithmic

in the vocabulary size. Let’s see how this works. Let’s stick with

the previous example. We want to figure out,

given that context, what is the likelihood of

the next word being cat. Well, the first thing

we’re going to do is we’re going to look at

the root of the tree, and we're going to ask what is the likelihood of branching (in this case, it's a 3-ary tree), so what's the likelihood of branching to n2 as the next step in this procedure. And we're going to keep asking about each one of these steps until we get to the leaf corresponding to the label cat. So, we're going to factorize. We're going to stipulate that the distribution of cat given the context is now going to be factorized into each decision that has to be made to get to that leaf. So first, probability

of n2 given n1, then probability of n4 given n2, then probability

of cat given n4. And if you multiply each of

the probabilities together, you get the overall likelihood of getting to the label cat. You can now imagine

a generative process where we want to

sample a new word, given this conditional context. We’re going to sample from that first distribution

and sort of work our way through that tree

to get to the leaf. So, what happens to the running time? At each decision, it takes running time D times M, where M is again the branching factor of this tree and D is the dimensionality of the context. And if the tree is of depth log V, where V is the vocabulary size, meaning it's a balanced tree, then the running time is now logarithmic in the vocabulary size, both to sample from the conditional distribution and to estimate the conditional likelihood of a particular label, where the label again corresponds to a word in the vocabulary.
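As a rough sketch of that computation (illustrative only; the tree layout, node names, and per-node weight matrices are invented assumptions, not the talk's actual model):

```python
import numpy as np

# Hierarchical softmax sketch with an M-ary tree. Each internal node n has an
# M x D weight matrix H[n]; a word's probability is the product of the per-node
# softmax branch probabilities along its root-to-leaf path, so the cost is
# O(D * M * depth) instead of O(D * V).
D, M = 64, 3
rng = np.random.default_rng(0)
H = {"n1": rng.normal(size=(M, D)),      # root
     "n2": rng.normal(size=(M, D)),
     "n4": rng.normal(size=(M, D))}
# Hypothetical path for the word "cat": which branch to take at each node.
path_to_cat = [("n1", 1), ("n2", 2), ("n4", 0)]

def branch_probs(node, v):
    s = H[node] @ v
    e = np.exp(s - s.max())
    return e / e.sum()                   # distribution over the M children

def word_prob(path, v):
    p = 1.0
    for node, branch in path:            # multiply the decisions along the path
        p *= branch_probs(node, v)[branch]
    return p

v_context = rng.normal(size=D)
print(word_prob(path_to_cat, v_context)) # p(cat | context) under this tree
```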

So, this hierarchical softmax is widely used in NLP, but the question which

everyone in the community always asks is what

should that tree be? What is its structure?

Where is it coming from? And there have been a number of approaches to

try to address that question. The first one has been

one of clustering. So, for example, you might use an off-the-shelf algorithm like word2vec to learn some word embeddings, then do some clustering of those word embeddings to get a candidate tree, and then train the parameters of that tree, or maybe iterate this procedure in some way, in order to try to find trees that are semantically meaningful, so that words with similar meanings end up in similar locations along the bottom of the tree. So, that's one approach. Another approach

that's widely used is based on Huffman coding, which is completely wacky, except that it works. This Huffman coding idea says that we're going to completely ignore the semantics of what these labels are. We're just going to try to learn a tree structure, and this won't be a balanced tree, such that very frequently occurring labels are high up in the tree and not so frequently occurring labels are low in the tree, so that if you were to completely ignore the contexts and just ask what's the expected number of branches you need to traverse to get to a word, it's as low as possible.
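For concreteness, such a frequency-based hierarchy could be built with the standard Huffman construction, roughly like this (a sketch with made-up word counts, not the talk's implementation):

```python
import heapq
import itertools

# Build a Huffman hierarchy from word counts: repeatedly merge the two least
# frequent subtrees, so frequent words end up close to the root.
counts = {"the": 5000, "cat": 40, "rat": 25, "scared": 10, "aardvark": 1}
tiebreak = itertools.count()
heap = [(c, next(tiebreak), word) for word, c in counts.items()]
heapq.heapify(heap)

while len(heap) > 1:
    c1, _, left = heapq.heappop(heap)
    c2, _, right = heapq.heappop(heap)
    heapq.heappush(heap, (c1 + c2, next(tiebreak), (left, right)))

huffman_tree = heap[0][2]                # nested tuples; the leaves are the words
print(huffman_tree)
```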

And finally, in even more recent work, for example we've used this in my own lab, people have just gone

for random hierarchies. Just giving up on anything and just saying, "Okay, we're going to randomly create a tree." So, for example, in my work, I've used a sort of square-root-of-V tree, where we have a square-root-of-V branching factor and the tree is not very deep. So we just randomly assign words to the leaves and learn a model, and that also seems to work reasonably well. But this leaves open

the question of, “Okay, these things work

reasonably well, but what is there

still to be gained?” If we were to actually try to think hard about

this problem and learn that tree structure,

could we do better? And this is one of the first works to really

address that problem. And I'll just tell you the conclusion now: there is a substantial gap in performance between either the clustering-based or the Huffman coding-based approaches and the flat softmax, and by cleverly learning that tree structure, we completely eliminate that gap, while still attaining most of the computational speedups.

So, how do we get there? Well, there’s been quite a bit of

earlier work on learning tree structures in

similar settings, notably by Choromanska and John Langford, who is sitting here. And that itself built on even earlier work by Kearns and Mansour. But they considered settings where the contexts were unchanging. So, for example,

if you think about a typical bag-of-words

classification setting, where we have a fixed feature

vector representation, we’re just trying

to learn how to predict, from a large number of classes, what the class should be. Then those algorithms would apply. But here, the story's

a little bit different. Because, remember when you’re

learning a language model, you’re not just learning

this tree structure, but you’re also learning the recurrent neural network

parameters as well, right? So, you have to learn

the word embeddings. You have to learn the parameters of your gates and so on. So, what that means is that the representation of your context is changing as well during your learning algorithm. And what we don't want to do is a two-step algorithm

where you fix the representation and learn

the tree and then alternate. What we want to ask in this work is: is there a way to jointly learn everything? And so we're going to give

an algorithm to do so. We’re going to provide very

weak theoretical guarantees. So, we have some theory

in our paper, but really it leaves

a lot to be desired. But we do show that this

works very well in practice. And we’re going to

evaluate this in both the language

modeling setting, and in an extreme

classification setting. So, our overall approach

will actually be a little bit

similar to how you would tackle learning

a decision tree, right? So, if you think about learning a decision tree, you might think about some criteria for how you

would build that decision tree. You think about

things like purity or accuracy as measures for

how to choose the next split. And you might build

that decision tree in a top-down fashion. We’re going to give you

the same sort of objective here except it’s going to be a little bit

different to take into consideration some of

the subtleties of this problem. And our learning algorithm won’t be a greedy

top-down algorithm. It's going to have steps of fitting the language model, completely redoing the tree, and alternating between them. And what we're going

to show is that using this objective

leads to good trees. So, here’s now where I get

into some of the details. So, first I never told

you how we branch. So, let’s first look at

the setting of just two classes. We'll call one the class for the word "lost" and the other the class for "time." Suppose we have some data, where the data consists of sentences that give you the relevant context for predicting the next word, and let's suppose that next word was "lost" or "time." So, if we just had two classes, then our tree is just going to be a binary tree of depth one, where we only have to make one decision, go left or go right, in order to get to one of these two classes. And so there one could use just a sigmoid function, where you take that context and hit it with this linear classifier parameterized by a weight vector H, and that gives you some distribution, according to the sigmoid or logistic function, over which class is the right one for that context. And if you had more than two classes, that is to say, if rather than binary this were an M-ary tree, then you'd just use a softmax instead. And so that then gives you the probability of each one of these classes given the relevant context. So, now suppose that you have more than

two classes and again, for this example,

I’m going to stick with M equals two

so a binary tree. But in our experiments, we go with M up to 65. So, now the story is

going to be as follows. Consider the

following data point. “I ate some” gives

you some context; this is how we do prediction, so not learning but just prediction, supposing that we have already learned the model. We start at the root. The root node is associated with some parameters; in this case, since it's a binary tree, there's just a single hyperplane, because we're using linear splits. That hyperplane, by taking the dot-product of the hyperplane with the context vector, decides whether to go left or go right. Then, say you go left here; you make a similar decision now for node n2, then node n4, and eventually you get to your prediction, which is the label "pie," right? So then the likelihood of "pie" is just given to you by the product of each one of those softmax probabilities for the decisions needed to reach that label. So, now let's ask the question of "How do we

learn that tree structure?”. Well, we gave an objective

function which will quantify how good the split

at any one node is. And then, our overall objective

function is going to be the sum of those

per node objectives. So, now I’ll tell you what each of those per

node objectives is. For each node n in your tree, again this is an M-ary tree with K classes overall, where K, remember, could be something like the vocabulary size if this were language modeling. So, for each node in

the tree, we're going to look to see, first, what is the distribution of classes that reach that node in the tree. So, imagine a prediction where each data point follows its way down the tree; in this case, we're going to look at the path from the root to the leaf corresponding to that data point's correct class. And so we're going to look at all the data points that reach some corresponding intermediate node along the path from the root to their correct label. And that will give you the distribution q_i, the fraction of data points of class i that reach that node. And obviously, if this distribution is very peaked on one class, then you're basically there; you basically know what the right class is. And if this is close to a uniform distribution, then you're very uncertain. And your goal here, of course, just like with learning any type of decision tree, is to make this as peaked as possible, meaning you're making progress towards figuring out what the right class is. And so in this example here we have four classes, and this might be the root. And then we say this q just gives us the distribution over classes of the data points that reached that corresponding node. Now, the next quantity

of relevance is going to measure, now that you know which examples of which labels reached this node, where they are sent to next, all right? Which child is each one going to? And that's going to be essential for characterizing how good the split at this node is. And for this we're

going to measure that by this conditional

distribution p of j given i, where we're conditioning on examples of each class i. We're going to just look at

data points of some class, and then ask what is

the distribution of data points of that class

that go to each child. And intuitively, again, we want each of these

conditional distributions to be as peaked as possible which means that

we’re sort of routing the data points

in an optimal way. And here where we’re using a softmax to make

those routing decisions, this is just given to you by

the following expectation. And finally we can put

those two pieces together, p of j given i and q of i, to get p of j, which is simply the average, the fraction of examples that go to child j of this corresponding node. Now, this is just a summary of the three definitions

I gave you, and using those three

definitions now, we have our overall objective which is given on

the very bottom here. So, the objective for node n

is a sum over classes i, weighted by q_i, of a sum over the children j of that node of the absolute difference between p of j and p of j given i, all right? And so, we want this to be as high as possible, where as high as possible means that, for this average quantity, we would like all of the mass to be routed towards a single child for each class.
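To make that concrete, here is a small sketch of how the three quantities and the per-node score could be computed from counts (my notation and made-up numbers, assuming the absolute-difference form just described; counts[i, j] is the number of examples of class i at node n that get routed to child j):

```python
import numpy as np

# Per-node objective J(n) = sum_i q_i * sum_j |p_j - p_{j|i}|, from a counts matrix.
counts = np.array([[30.,  0.,  0.],      # class 0: all mass to child 0 (good)
                   [ 0., 20.,  0.],      # class 1: all mass to child 1 (good)
                   [ 5.,  5.,  5.],      # class 2: spread evenly (bad)
                   [ 0.,  0., 10.]])     # class 3: all mass to child 2 (good)

n_per_class = counts.sum(axis=1)
q = n_per_class / n_per_class.sum()            # q_i: class distribution at this node
p_j_given_i = counts / n_per_class[:, None]    # routing distribution for each class
p_j = q @ p_j_given_i                          # overall fraction sent to each child

objective = (q[:, None] * np.abs(p_j - p_j_given_i)).sum()
print(round(objective, 3))                     # higher = classes routed to distinct children
```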

Now, this has a number of properties. So, for example, what

one can show is, one can define this notion

of purity alpha n, and one can show that this objective function

correlates with good purity. Namely, that if you were to maximize this objective

as much as possible it would lead toward

splits that are pure. Moreover, one can characterize another notion of balance: how much are we balancing across the classes along the tree? And once again, one can show that if you maximize this objective as much as possible, it also leads to good balance. So, in our paper we have some theoretical results which motivate those two

properties I just gave. And also show that

there exist trees, that are not too large, that satisfy these properties. So, our overall learning

algorithm is then as follows. This is going to be a batch SGD algorithm, and for each batch of examples, we're going to completely redo the tree structure. We have a greedy algorithm which chooses the new tree structure, that is to say, relabels each of the leaves with new classes, and that we show greedily maximizes the derivative of the sum of these node objectives with respect to our model parameters, subject to the constraint that the tree remains well-formed. And so for each batch, we're going to first completely redo the tree structure, then we're going to take gradient steps with respect to the model parameters, so the recurrent network, the word embeddings, and also the split parameters. Now, importantly, each one of these steps, both choosing the new tree structure and optimizing the parameters of your language model, is optimizing the same objective function. And so that really starts to

get at this goal of having a unified objective that

deals with both the fact that the representations

are changing and the fact that we want to

find really good trees. So, here’s an example

of what it looks like. Suppose we see a new example for some class; we look in the current tree to see what leaf it is at. We compute the gradient of this objective for every node along the path from the root to its true label. We update the node statistics, we take a gradient step, and then we update the tree structure. And so here, you'll see a global change like this, where "pie" moves from the far right-hand side; we do a new split, we get this new tree structure, and we iterate.
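As a toy sketch of that alternation (a deliberately simplified stand-in, not the paper's algorithm: a single binary split, a fixed bag-of-context-features matrix in place of an RNN, and a crude leaf re-assignment rule; the real method uses the per-node objective over a full tree and also enforces that the tree stays well-formed):

```python
import numpy as np

# Alternate, per mini-batch, between (1) re-assigning classes to leaves and
# (2) gradient steps on the split parameters under the current assignment.
rng = np.random.default_rng(0)
D, K, N = 8, 4, 200
X = rng.normal(size=(N, D))                      # context vectors (stand-in for an RNN)
y = rng.integers(K, size=N)                      # class labels
h = rng.normal(size=D) * 0.01                    # split hyperplane at the single node
leaf_of = {k: k % 2 for k in range(K)}           # which side each class currently sits on

for epoch in range(5):
    for start in range(0, N, 50):                # mini-batches
        xb, yb = X[start:start + 50], y[start:start + 50]
        p_right = 1 / (1 + np.exp(-xb @ h))      # routing probability at the node
        # (1) greedy re-assignment: put each class on the side it currently prefers
        for k in range(K):
            mask = yb == k
            if mask.any():
                leaf_of[k] = int(p_right[mask].mean() > 0.5)
        # (2) gradient step on the log-likelihood of routing to the assigned side
        target = np.array([leaf_of[k] for k in yb])
        grad = ((p_right - target)[:, None] * xb).mean(axis=0)
        h -= 0.5 * grad
```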

So, in just the remaining two minutes, I'm going to talk a bit about how this algorithm

does in practice. So, first we looked at a language modeling problem, where we took a large corpus from Project Gutenberg, a 50-million-word corpus. We use a vocabulary of size 250,000, so there are 250,000 classes here we're trying to predict. And for this work, we used a really simple log-bilinear n-gram neural language model. We learned an M equals 65-ary tree of depth three, and what I'm showing here is

a little bit hard to see, but what I’m showing is

the tree structure that’s learned and visualized in

terms of what are the classes, the most frequent classes that are underneath each node of the tree to try to get a sense of whether

we’re actually learning, whether there’s any interesting

semantic grouping of the classes that’s learned by this algorithm and

indeed we see that. So, for example, on the far right we see that the classes "this," "these," "those," and so on are grouped together. Over here we see

the classes put, set, lay, cut are grouped

together and so on. But, more importantly than the semantics is really

how it does in practice. It’s a real shame that

these plots aren’t showing, but I’ll walk you through

what they're saying. In this plot on the left, the x-axis is epochs of training and the y-axis is perplexity, a measure of held-out likelihood, which we want to be as small as possible. So, this red line that

you see up here is the held-out perplexity

of using a random tree. This blue line that you

can see pretty clearly is the held-out perplexity

of using a flat softmax. So, it's going to be much slower to train and to use at test time, because of the fact that it's not hierarchically structured. And you see that there exists a pretty big gap there. Although I'm not showing it here, the clustering-based tree approach is actually worse than all of these lines. The green line, which is

almost impossible to see, I’m going to show you it

by following my cursor. It initially begins somewhat similar to the random hierarchy, until roughly the second epoch of training, where it settles on a good tree structure. And from there on out, its learning curve is identical to that of the flat softmax. And so what we see here is that by learning

the tree structure, in fact we’re able

to fully recover the quality of model

that we’re able to get using just a flat softmax. Next, we looked at

a text classification problem where we actually

modeled it almost identically to the density estimation

setting previously considered, with the only difference being that here, to do prediction, we don't want to sample; we actually want to do MAP inference. So, for prediction here, we're going to be doing MAP inference for the class, which we do by a simple branch and bound with a depth-first search through this hierarchical tree, in order to predict for each new example.
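A small sketch of what such a depth-first branch and bound could look like (the tree layout, node names, and parameters below are invented for illustration): since every branch probability is at most 1, the probability accumulated along a partial path upper-bounds every leaf beneath it, so whole subtrees can be pruned.

```python
import numpy as np

# MAP prediction in a hierarchical softmax by depth-first branch and bound.
rng = np.random.default_rng(0)
D, M = 16, 3
tree = {"n1": ["n2", "n3", "cat"],       # children of each internal node; names that
        "n2": ["dog", "pie", "n4"],      #   are keys of `tree` are internal nodes,
        "n3": ["rat", "bat", "cow"],     #   everything else is a leaf label
        "n4": ["tea", "jam", "ant"]}
H = {n: rng.normal(size=(M, D)) for n in tree}

def branch_probs(node, v):
    s = H[node] @ v
    e = np.exp(s - s.max())
    return e / e.sum()

def map_label(node, v, prefix=1.0, best=(0.0, None)):
    probs = branch_probs(node, v)
    for j in np.argsort(-probs):                 # most promising branch first
        bound = prefix * probs[j]
        if bound <= best[0]:                     # prune: no leaf below can beat the best
            continue
        child = tree[node][j]
        if child in tree:                        # internal node: recurse
            best = map_label(child, v, bound, best)
        else:                                    # leaf: exact probability, new best
            best = (bound, child)
    return best

print(map_label("n1", rng.normal(size=D)))       # (probability, predicted label)
```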

We evaluated on this Flickr tag classification setting, where the input is an image like this and your goal

is to figure out what are the tags that

this input corresponds to. Now most people use this dataset to try to predict

the tag directly from the image or to

do something about relating the image to the caption. I want to first point out that we could have used this approach from the image itself. All right. So, one could imagine using our algorithm to

simultaneously learn the parameters of

a convolutional neural network and a hierarchical softmax

to predict the class. That would be a good use case. We wanted to stick with language because that's something my student is more interested in, and so we didn't actually use the image itself; rather, we used the text of the caption and predicted the tags from the text of the caption, using a simple bag-of-words embedding, where we're simultaneously learning the word embeddings at the same time.
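A tiny sketch of what such a bag-of-words context could look like (the vocabulary, sizes, and function names are invented; the embedding matrix is the part that would be learned jointly with the rest of the model):

```python
import numpy as np

# Bag-of-words caption embedding: average the word embeddings of the caption.
# E is a parameter that receives gradients through the hierarchical softmax.
vocab = {"a": 0, "dog": 1, "on": 2, "the": 3, "beach": 4}
D = 16
E = np.random.default_rng(0).normal(size=(len(vocab), D))

def caption_context(caption):
    ids = [vocab[w] for w in caption.split() if w in vocab]
    return E[ids].mean(axis=0)                   # average embedding = context vector

v_context = caption_context("a dog on the beach")   # fed into the tree's softmaxes
```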

And indeed, the same story applies here, where the learned tree

that I'm showing you here, here for arities 5 and 20, obtains a precision-at-one which is substantially better than you could have gotten with a Huffman tree or a couple of other baseline algorithms

that we compare to. What’s particularly

interesting however is that, as the dimensionality

of your word embeddings increases

that gap seems to shrink. So, I think that’s something that’ll be interesting for us to think

about in the future. In particular, what this motivates is that maybe this approach of learning the hierarchical softmax structure is most important in settings where you're more resource constrained, where you can't afford to have very large dimensionalities, and so the decisions that are made at each step of the tree are hard, and because they're hard, learning the tree structure is more important. So, I'll just conclude by

saying this is one of the first works that show how to learn the tree structure of hierarchical softmax together with the underlying

representations. There are a ton of open questions, particularly on the theoretical side, and there is also some interesting empirical work to be done and some new directions there. So, just to spur some interest, there's really

no reason why we had to use a tree-structured model. One could imagine

using a forest, one could imagine allowing each class to appear

a couple of times rather than just

once in the tree and one could even

imagine using the sort of same structure

learning algorithm for other deep learning

applications. So, thank you.>>Okay, thank you. So, we have maybe time only for one or two quick questions

if there are any. Yeah maybe this

one and then yeah.>>So the question

is about applying hierarchical softmax for

multi-label classification.>>Naively, it could be used there as well. I don't know the best way to do it; I'd have to think about it. That's a really interesting question. David's asking whether the structure could be reused. Our original thought was that the structure itself might be interesting from a linguistic point of view, particularly in the language modeling setting. I don't think there's evidence from our results to justify that statement. A second question is whether the structure

could be reused for other predictive problems and I think the short answer

is yes if the way that the representations relate to the labels that you’re

predicting is generalizable.>>Okay. Let’s

thank our speaker.