here's a recording; check it out!
and here's a pdf of the slides
eurosat/all is a dataset of 27,000 64x64 satellite images taken with 13 spectral bands. each image is labelled with one of ten classes.
for the purpose of classification these 13 channels aren't equally useful, and the information in them varies across resolutions. if we were designing a sensor we might choose to use different channels at different resolutions.
how can we explore the trade off between mixed resolutions and whether to use a channel at all?
let's start with a simple baseline to see what performance we get. we won't spend too much time on this model, we just want something that we can iterate on quickly.
the simple model shown below, trained on a 'training' split with adam, hits 0.942 top-1 accuracy on a 2nd 'validation' split in 5 epochs. that'll do for a start.
let's check the effect of including different combos of input channels. we'll do so by introducing a channel mask.
a mask of all ones denotes using all channels and gives our baseline performance
mask | validation accuracy |
[1,1,1,1,1,1,1,1,1,1,1,1,1] | 0.942 |
a mask of all zeros denotes using no channels and acts as a sanity check; it gives the performance of random chance, which is in line with what we expect given the balanced training set. (note: we standardise the input data so that it has zero mean per channel (with the mean and standard deviation parameters fit against training data only) so we can get this effect)
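as a sketch of what this masking looks like (a minimal numpy version; the function names are mine, not the post's actual code):

```python
import numpy as np

def standardise(x, mean, std):
    # x: (batch, height, width, channels); mean/std fit on training data only
    return (x - mean) / std

def apply_channel_mask(x, mask):
    # mask: (channels,) of {0, 1}; a 0 entry zeroes that channel entirely,
    # which after standardisation is equivalent to feeding the channel mean
    return x * mask

x = np.random.default_rng(0).normal(size=(4, 64, 64, 13))
mean = x.mean(axis=(0, 1, 2))
std = x.std(axis=(0, 1, 2))
mask = np.ones(13)
mask[11] = 0  # drop channel 11
masked = apply_channel_mask(standardise(x, mean, std), mask)
```

with standardised inputs, zeroing a channel is the same as feeding its training-set mean, which is why the all-zeros mask behaves like "no information at all".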
mask | validation accuracy |
[1,1,1,1,1,1,1,1,1,1,1,1,1] | 0.942 |
[0,0,0,0,0,0,0,0,0,0,0,0,0] | 0.113 |
but what about if we drop just one channel? i.e. a mask of all ones except for a single zero.
channel to drop | validation accuracy |
0 | 0.735 |
1 | 0.528 |
2 | 0.661 |
3 | 0.675 |
4 | 0.809 |
5 | 0.724 |
6 | 0.749 |
7 | 0.634 |
8 | 0.874 |
9 | 0.934 |
10 | 0.593 |
11 | 0.339 |
12 | 0.896 |
from this we can see that the performance hit from losing a single channel is not always the same. in particular, consider channel 11; if we drop that channel we take a huge hit! does that mean keeping only channel 11 should give reasonable performance?
mask | validation accuracy |
[1,1,1,1,1,1,1,1,1,1,1,1,1] | 0.942 |
[0,0,0,0,0,0,0,0,0,0,0,0,0] | 0.113 |
[0,0,0,0,0,0,0,0,0,0,0,1,0] (keep just 11) | 0.260 |
bbbzzzttt (or other appropriate annoying buzzer noise). channel 11 is contributing to the classification but it's not being used independently. in general this is exactly the behaviour we want from a neural network, but what can we do to explore the effect of removing this dependence?
consider using a dropout idea, just with input channels instead of intermediate nodes.
what behaviour do we get if we drop channels out during training? i.e. with 50% probability we replace an entire input channel with 0s?
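a minimal numpy sketch of this channel dropout idea (names mine; here one keep/drop decision is made per channel for the whole batch, though in practice you might sample per example):

```python
import numpy as np

def channel_dropout(x, rng, drop_prob=0.5):
    # with probability drop_prob, replace an entire input channel with 0s.
    # applied during training only; a sketch, not the original code.
    keep = rng.random(x.shape[-1]) >= drop_prob  # one decision per channel
    return x * keep

rng = np.random.default_rng(42)
x = np.ones((2, 64, 64, 13))
y = channel_dropout(x, rng)
```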
things take longer to train and we get a slight hit in accuracy...
dropout? | validation accuracy |
no | 0.942 |
yes | 0.934 |
...but now when we mask out one channel at a time we don't get a big hit for losing any particular one.
channel to drop | accuracy (no dropout) | accuracy (with dropout) |
0 | 0.735 | 0.931 |
1 | 0.528 | 0.931 |
2 | 0.661 | 0.936 |
3 | 0.675 | 0.935 |
4 | 0.809 | 0.937 |
5 | 0.724 | 0.934 |
6 | 0.749 | 0.931 |
7 | 0.634 | 0.927 |
8 | 0.874 | 0.927 |
9 | 0.934 | 0.927 |
10 | 0.593 | 0.927 |
11 | 0.339 | 0.933 |
12 | 0.896 | 0.937 |
now that we have a model that is robust to any combo of channels, what do we see if we use a simple genetic algorithm (GA) to evolve the channel mask to use with this pre-trained network? a mask that represents using all channels will be the best, right? right?
we'll evolve the GA using the network trained above, but based on its performance on a 3rd "ga_train" split, using the inverse loss as a fitness function.
amusingly the GA finds that the mask [1,1,0,1,0,0,0,1,1,0,1,0,1] does marginally better than all channels, while only using half of them!
mask | split | accuracy |
[1,1,1,1,1,1,1,1,1,1,1,1,1] (all) | ga_validate | 0.934 |
[1,1,0,1,0,0,0,1,1,0,1,0,1] (ga) | ga_validate | 0.936 |
important note: we can imagine the best performance overall would come from having the GA evolve not the channels to use with this pre-trained model, but the channels to use when training from scratch. that would though require a lot more model training; basically a full training cycle per fitness evaluation :( in the approach described here we only have to train a single model and then have the GA just run inference.
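for concreteness, here's a minimal GA of the kind described, over 13-bit channel masks (pure python sketch; `toy_fitness` is a stand-in for the inverse loss of the frozen network on ga_train, and all names and hyperparameters here are mine):

```python
import random

N_CHANNELS, POP, GENS = 13, 20, 30

def toy_fitness(mask):
    # stand-in for (1.0 / loss) of the frozen model; pretends channels
    # 0-7 help and the rest slightly hurt (illustrative only)
    return sum(mask[:8]) - 0.1 * sum(mask[8:])

def mutate(mask, p=0.1):
    # flip each bit with probability p
    return [1 - b if random.random() < p else b for b in mask]

def crossover(a, b):
    # single point crossover
    cut = random.randrange(1, N_CHANNELS)
    return a[:cut] + b[cut:]

random.seed(0)
pop = [[random.randint(0, 1) for _ in range(N_CHANNELS)] for _ in range(POP)]
for _ in range(GENS):
    pop.sort(key=toy_fitness, reverse=True)
    elite = pop[: POP // 2]  # keep the best half
    pop = elite + [mutate(crossover(random.choice(elite), random.choice(elite)))
                   for _ in range(POP - len(elite))]
best = max(pop, key=toy_fitness)
```

the real version only needs inference per fitness evaluation, which is what makes it cheap.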
taking the idea of channel selection a step further, what if we got the GA to not only decide whether to use a channel or not, but also what resolution it should be in?
consider some example images across resolutions....
example images (just RGB channels shown): orig x64 | x32 | x16 | x8
we could then weight the use of a channel based on resolution; the higher the resolution the more the channel "costs" to use, with not using the channel at all being "free".
to support this we can change the GA to represent members not as a string of {0, 1}s but instead a sequence of {0, x8, x16, x32, x64} values per channel where these represent...
resolution | description | channel cost |
x64 | use original (64, 64) version of input | 0.8 |
x32 | use a 1/2 res (32, 32) version of input | 0.4 |
x16 | use a 1/4 res (16, 16) version of input | 0.2 |
x8 | use a 1/8 res (8, 8) version of input | 0.1 |
0 | don't use channel | 0 |
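the per-member cost is then just a sum of lookups against the table above (a sketch; names mine):

```python
# cost lookup matching the table above; a genome entry is either
# 0 (don't use the channel) or the target resolution to use it at
COSTS = {64: 0.8, 32: 0.4, 16: 0.2, 8: 0.1, 0: 0.0}

def channel_cost(genome):
    return sum(COSTS[g] for g in genome)

# e.g. the solution evolved later in the post
genome = [16, 64, 64, 16, 32, 0, 8, 64, 8, 0, 8, 0, 32]
print(round(channel_cost(genome), 2))
```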
the change in the encoding of our GA is trivial, just 5 values per channel instead of 2, but before we look at that; how do we change our network?
we can do it without having to add too many extra parameters by using the magic of fully convolutional networks :)
notice how the main trunk of our first network was a series of 2d convolutions followed by a global spatial mean. a network like that can take input at any of the resolutions we need, so we can simply reuse it multiple times!
so we can have our network...
note: try as i might i can't get steps 2 to 4 to run parallelised in a pmap. i asked on github about it and it looks to be something you can't do at the moment.
when we consider channel cost vs loss there is no single best solution; it's a classic example of a pareto front, where we see a tradeoff between the channel_cost and the loss.
consider this sampling of 1,000 random channel masks...
the GA needs to operate with a fitness that's a single scalar; for now we just use a simple combo of (1.0 / loss) - channel_cost
running with this fitness function we evolve the solution
[x16, x64, x64, x16, x32, ignore, x8, x64, x8, ignore, x8, ignore, x32]
it's on the pareto front, as we'd hope, and it's interesting that it includes a mix of resolutions including ignoring 3 channels completely :)
different mixings of loss and channel_cost would result in different GA solutions along the front
all the code is on github
a classic piece of advice i hear for people using tpus is that they should "try using a bigger batch size".
this got me thinking; i wonder how big a batch size i could reasonably use? how would the optimisation go? how fast could i get things?
let's train a model on the eurosat/rgb dataset. it's a 10 way classification problem on 64x64 images
with a training split of 80% we have 21,600 training examples. we'll use another 10% for validation (2,700 images) and just not use the final 10% test split #hack
for the model we'll use a simple stack of convolutions with channel sizes 32, 64, 128 and 256, a stride of 2 for spatial reduction, all with gelu activation. after the convolutions we'll do a simple global spatial pooling, a single 128d dense layer with gelu and then a 10d logit output. a pretty vanilla architecture of ~400K params. nothing fancy.
a v3-32 tpu pod slice is 4 hosts, each with 8 tpu devices.
21,600 training examples total => 5,400 examples per host => 675 examples per device.
this number of images easily fits on a device. great.
now usually augmentation is something we do randomly per batch, but for this hack we're interested in seeing how big a batch we can run. so why not fill out the dataset a bit by just running a stack of augmentations before training?
for each image we'll do 90, 180 and 270 deg rotations along with left/right flips for a total of 8 augmented images for each original image. e.g.....
this now gives us 172,800 images total => 43,200 per host => 5,400 per tpu device, which still fits no problem.
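the 8 variants per image are the dihedral rotations/flips; a minimal numpy sketch (the real version ran on device with nested pmaps and vmaps, and all names here are mine):

```python
import numpy as np

def augment_8(img):
    # the 4 rotations (0/90/180/270 deg) of the image, and of its
    # left/right flip, give the 8 variants used to expand the dataset
    variants = []
    for base in (img, np.fliplr(img)):
        for k in range(4):
            variants.append(np.rot90(base, k))
    return np.stack(variants)

img = np.arange(64 * 64 * 3).reshape(64, 64, 3)
aug = augment_8(img)  # shape (8, 64, 64, 3)
```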
side note: turns out doing this augmentation was one of the most fun parts of this hack :) see this tweet thread for some more info on how i used nested pmaps and vmaps to do it!
one motivation i had for this hack was to compare adam to lamb. i'd seen lamb referred to in the past; would it perform better for this model/dataset size? turns out it does! a simple sweep comparing lamb, adam and sgd shows lamb consistently doing the best. definitely one to add to the tuning mix from now on.
not only does the augmented data fit sharded across devices but we can replicate both the model parameters and the optimiser state as well. this is important for speed since the main training loop doesn't have to do any host/device communication. taking a data parallel approach means the only cross device comms is a gradient psum.
for training we run an inner loop just pumping the params = update(params) step.
an outer loop runs the inner loop 100 times before doing a validation accuracy check.
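the shape of that loop, as a pure python sketch with toy stand-ins for the pmapped update step and the validation check (all names here are mine):

```python
def train(params, update, validation_accuracy, n_outer=13, n_inner=100):
    # inner loop: just pump the params = update(params) step on-device;
    # outer loop: check validation accuracy every n_inner steps
    history = []
    for _ in range(n_outer):
        for _ in range(n_inner):
            params = update(params)
        history.append(validation_accuracy(params))
    return params, history

# toy stand-ins just to show the control flow
params, history = train(0.0, lambda p: p + 1, lambda p: p / 1300)
```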
the inner loop runs at 1.5s for the 100 iterations and since each iteration is a forward & backwards pass for all 172,800 images across all hosts that's 11M images processed per second. 🔥🔥🔥
at this speed the best result of 0.95 on validation takes 13 outer loops; i.e. all done in under 20s. o_O !!
when reviewing runs i did laugh to see sgd with momentum make a top 10 entry.
new t-shirt slogan: "sgd with momentum; always worth a try"
all the code in hacktastic undocumented form on github
this 4 (and a bit) part tute series starts with jax fundamentals, builds up to describing a data parallel approach to training on a cloud tpu pod slice, and finishes with a tpu pod slice implementation of ensemble nets.... all with the goal of solving 1d y=mx+b
and though it may seem like a bit of overkill it turns out it's a good example to work through so that we can focus on the library support without having to worry about the modelling.
in this first section we introduce some jax fundamentals; e.g. make_jaxpr, grad, jit, vmap & pmap.
colab: 01 pmap jit vmap oh my.ipynb
in part 2 we use the techniques from part 1 to solve y=mx+b in pure jax. we'll also introduce pytrees and various tree_utils for manipulating them.
we run first on a single device and work up to using pmap to demonstrate a simple data parallelism approach. along the way we'll do a small detour to a tpu pod slice to illustrate the difference in a multi host setup.
( note: the experience as described here for a pod slice isn't publicly available yet; sign up via the JAX on Cloud TPU Interest Form to get more info. see also this JAX on Cloud TPUs (NeurIPS 2020) talk )
colab: 02 y mx b on a tpu.ipynb
next we introduce haiku as a way of defining our model and optax as a library providing standard optimisers. to illustrate their use we'll do a minimal port of our model and training loop to use them.
colab: 03 y mx b in haiku.ipynb
in part 4 we'll reimplement ensemble nets for this trivial model, continuing to do things in a way that supports running on a tpu pod slice.
colab: 04 y mx b haiku ensemble.ipynb
to wrap up we acknowledge that though tpu pod slices and data parallel approaches are fun we could have just solved this in a single calculation using the normal equation... :D
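for completeness, that one-step solve for a 1d line is just the closed form least squares (normal equation) solution; a pure python sketch:

```python
# fit y = m x + b in a single calculation via the normal equation,
# no gradient descent (or tpu pod slice) required
def fit_line(xs, ys):
    n = len(xs)
    sx, sy = sum(xs), sum(ys)
    sxx = sum(x * x for x in xs)
    sxy = sum(x * y for x, y in zip(xs, ys))
    m = (n * sxy - sx * sy) / (n * sxx - sx * sx)
    b = (sy - m * sx) / n
    return m, b

xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0, 3.0, 5.0, 7.0]  # generated from y = 2x + 1
m, b = fit_line(xs, ys)
```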
colab: 05 booooooooooooooooring.ipynb
OOD detection is an often overlooked part of many projects. it's super important to understand when your model is operating on data it's not familiar with!
though this can be framed as the problem of detecting that an input differs from training data, discussed more in this 'data engineering concerns for machine learning products' talk, for this post we're going to look at the problem more from the angle of the model needing to be able to express its own lack of confidence.
a core theme of this post is treating OOD as a function of the output of the model. when i've played with this idea before, the best results i've had came from using entropy as a proxy for confidence. in the past that hasn't been in the context of OOD directly, but instead as a way of prioritising data annotation (i.e. uncertainty sampling for active learning)
using entropy as a stable measure of confidence though requires a model be well calibrated.
neural nets are sadly notorious for not being well calibrated under the default ways we train these days.
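for reference, the entropy in question is just the shannon entropy of the softmax output (a sketch; names mine):

```python
import math

def entropy(probs):
    # shannon entropy of a softmax output (natural log);
    # higher => more uniform => less prediction confidence
    return -sum(p * math.log(p) for p in probs if p > 0)

confident = [0.97, 0.01, 0.01, 0.01]   # low entropy
uncertain = [0.25, 0.25, 0.25, 0.25]   # max entropy for 4 classes
```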
the first approach i was shown for calibrating a neural net is to train the model as normal, then finetune just the final classifier layer on a held out set; the general idea being that this is effectively just a logistic regression on features learnt by the rest of the network, so it should come with the better calibration guarantees that logistic regression provides (see this great overview from sklearn). while this has always worked well for me, it does require two steps of training, as well as a separate held out set to be managed.
another approach to calibration i revisited this week comes from this great overview paper from 2017 on calibration of modern neural networks. it touches on platt scaling, which i've also used successfully in the past, but includes something even simpler that i didn't really notice until a colleague pointed it out: don't bother tuning the entire last layer, just fit a single temperature rescaling of the output logits, i.e. what can be thought of as a single parameter version of platt scaling. what a great simple idea!
the main purpose of this post though is to reproduce some ideas around out of distribution and calibration when you use focal loss. this was described in this great piece of work calibrating deep neural networks using focal loss. we'll implement that, but the other two as well for comparison.
( as an aside, another great simple way of determining confidence is using an ensemble! i explore this idea a bit in my post on ensemble nets where i train ensembles as a single model using jax vmap )
let's start with cifar10 and split the data up a bit differently than the standard train and test splits. we'll hold out automobile & cat as a "hard" ood set ( i've intentionally chosen these two since i know they cause model confusion, as observed in this keras.io tute on metric learning ). the remaining classes are split 0.7 for train and 0.1 for each of validation, calibration and test. we'll additionally generate an "easy" ood set which will just be random images.
in terms of how we'll use these splits...
- train will be the main dataset to train against.
- validation will be used during training for learning rate stepping, early stopping, etc.
- calibration will be used when we want to tune a model in some way for calibration.
- test will be used for final held out analysis.
- ood_hard and ood_easy will be used as the representative out of distribution sets.
example images: in distribution | out of distribution (hard) | out of distribution (easy)
we'll then compare...
- model_1, trained on train, tuning against validate.
- model_2, made by finetuning the last layer of model_1 against calibrate; i.e. the logistic regression approach to calibration.
- model_3, made by just fitting a scalar temperature on the last layer of model_1; i.e. the temperature scaling approach.
- finally, focal loss variants trained on train using validate for tuning.
at each step we'll check the entropy distributions across the various splits, including the easy and hard ood sets
for a baseline we'll use a resnet18 objax model trained against train using validate for simple early stopping.
# define model
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)
# train against all model vars
trainable_vars = model.vars()
$ python3 train_basic_model.py --loss-fn cross_entropy --model-dir m1
learning_rate 0.001 validation accuracy 0.372 entropy min/mean/max 0.0005 0.8479 1.9910
learning_rate 0.001 validation accuracy 0.449 entropy min/mean/max 0.0000 0.4402 1.9456
learning_rate 0.001 validation accuracy 0.517 entropy min/mean/max 0.0000 0.3566 1.8013
learning_rate 0.001 validation accuracy 0.578 entropy min/mean/max 0.0000 0.3070 1.7604
learning_rate 0.001 validation accuracy 0.680 entropy min/mean/max 0.0000 0.3666 1.7490
learning_rate 0.001 validation accuracy 0.705 entropy min/mean/max 0.0000 0.3487 1.6987
learning_rate 0.001 validation accuracy 0.671 entropy min/mean/max 0.0000 0.3784 1.8276
learning_rate 0.0001 validation accuracy 0.805 entropy min/mean/max 0.0000 0.3252 1.8990
learning_rate 0.0001 validation accuracy 0.818 entropy min/mean/max 0.0000 0.3224 1.7743
learning_rate 0.0001 validation accuracy 0.807 entropy min/mean/max 0.0000 0.3037 1.8580
we can check the distribution of entropy values of this trained model against our various splits
some observations...
recall that a higher entropy => a more uniform distribution => less prediction confidence. as such our goal is to have the entropy of the ood sets as high as possible without dropping the test set too much.
as mentioned before, the main approach i'm familiar with is retraining the classifier layer. we'll start with model_1 and use the so-far-unseen calibration set for fine tuning it.
# define model
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)
# restore from model_1
objax.io.load_var_collection('m1/weights.npz', model.vars())
# train against last layer only
classifier_layer = model[-1]
trainable_vars = classifier_layer.vars()
$ python3 train_model_2.py --input-model-dir m1 --output-model-dir m2
learning_rate 0.001 calibration accuracy 0.808 entropy min/mean/max 0.0000 0.4449 1.8722
learning_rate 0.001 calibration accuracy 0.809 entropy min/mean/max 0.0002 0.5236 1.9026
learning_rate 0.001 calibration accuracy 0.808 entropy min/mean/max 0.0004 0.5468 1.9079
learning_rate 0.001 calibration accuracy 0.811 entropy min/mean/max 0.0005 0.5496 1.9214
learning_rate 0.001 calibration accuracy 0.812 entropy min/mean/max 0.0001 0.5416 1.9020
learning_rate 0.001 calibration accuracy 0.812 entropy min/mean/max 0.0002 0.5413 1.9044
learning_rate 0.001 calibration accuracy 0.811 entropy min/mean/max 0.0003 0.5458 1.8990
learning_rate 0.001 calibration accuracy 0.814 entropy min/mean/max 0.0001 0.5472 1.9002
learning_rate 0.001 calibration accuracy 0.812 entropy min/mean/max 0.0001 0.5455 1.9124
learning_rate 0.001 calibration accuracy 0.813 entropy min/mean/max 0.0001 0.5503 1.9092
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0001 0.5485 1.9011
learning_rate 0.0001 calibration accuracy 0.814 entropy min/mean/max 0.0001 0.5476 1.9036
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0001 0.5483 1.9051
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5479 1.9038
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5482 1.9043
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5482 1.9059
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5479 1.9068
learning_rate 0.0001 calibration accuracy 0.813 entropy min/mean/max 0.0000 0.5480 1.9058
learning_rate 0.0001 calibration accuracy 0.812 entropy min/mean/max 0.0000 0.5488 1.9040
learning_rate 0.0001 calibration accuracy 0.814 entropy min/mean/max 0.0000 0.5477 1.9051
if we compare (top-1) accuracy of model_1 vs model_2 we see a slight improvement in model_2, attributable i suppose to the slight extra training (?)
$ python3 calculate_metrics.py --model m1/weights.npz
train accuracy 0.970
validate accuracy 0.807
calibrate accuracy 0.799
test accuracy 0.803
$ python3 calculate_metrics.py --model m2/weights.npz
train accuracy 0.980
validate accuracy 0.816
calibrate accuracy 0.814
test accuracy 0.810
more importantly though; how do the entropy distributions look?
some observations...
in model_3 we create a simple single parameter layer that represents rescaling, append it to the pretrained resnet and train just that layer.
class Temperature(objax.module.Module):
    def __init__(self):
        super().__init__()
        self.temperature = objax.variable.TrainVar(jnp.array([1.0]))

    def __call__(self, x):
        return x / self.temperature.value

# define model
model = objax.zoo.resnet_v2.ResNet18(in_channels=3, num_classes=10)
# restore from model_1
objax.io.load_var_collection('m1/weights.npz', model.vars())
# add a temp rescaling layer
temperature_layer = layers.Temperature()
model.append(temperature_layer)
# train against just this layer
trainable_vars = temperature_layer.vars()
$ python3 train_model_3.py --input-model-dir m1 --output-model-dir m3
learning_rate 0.01 temp 1.4730 calibration accuracy 0.799 entropy min/mean/max 0.0006 0.5215 1.9348
learning_rate 0.01 temp 1.5955 calibration accuracy 0.799 entropy min/mean/max 0.0013 0.5828 1.9611
learning_rate 0.01 temp 1.6118 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5911 1.9643
learning_rate 0.01 temp 1.6162 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5933 1.9652
learning_rate 0.01 temp 1.6118 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5911 1.9643
learning_rate 0.01 temp 1.6181 calibration accuracy 0.799 entropy min/mean/max 0.0015 0.5943 1.9655
learning_rate 0.01 temp 1.6126 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5915 1.9645
learning_rate 0.01 temp 1.6344 calibration accuracy 0.799 entropy min/mean/max 0.0016 0.6025 1.9687
learning_rate 0.01 temp 1.6338 calibration accuracy 0.799 entropy min/mean/max 0.0016 0.6022 1.9686
learning_rate 0.01 temp 1.6018 calibration accuracy 0.799 entropy min/mean/max 0.0013 0.5860 1.9623
learning_rate 0.001 temp 1.6054 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5878 1.9630
learning_rate 0.001 temp 1.6047 calibration accuracy 0.799 entropy min/mean/max 0.0013 0.5874 1.9629
learning_rate 0.001 temp 1.6102 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5903 1.9640
learning_rate 0.001 temp 1.6095 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5899 1.9638
learning_rate 0.001 temp 1.6112 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5908 1.9642
learning_rate 0.001 temp 1.6145 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5924 1.9648
learning_rate 0.001 temp 1.6134 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5919 1.9646
learning_rate 0.001 temp 1.6144 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5924 1.9648
learning_rate 0.001 temp 1.6105 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5904 1.9640
learning_rate 0.001 temp 1.6151 calibration accuracy 0.799 entropy min/mean/max 0.0014 0.5927 1.9650
some observations...
- accuracy is unchanged from model_1, since the temp rescaling never changes the ordering of predictions. ( just this property alone could be enough reason to use this approach in some cases! )
- it does better than model_2 in terms of raising the entropy of the ood sets!
not bad for fitting a single parameter :D
as a final experiment we'll use focal loss for training instead of cross entropy.
note: i rolled my own focal loss function, which is always dangerous! it's not impossible (i.e. it's likely) that i've done something wrong in terms of numerical stability; please let me know if you see a dumb bug :)
import jax
import jax.numpy as jnp

def focal_loss_sparse(logits, y_true, gamma=1.0):
    # log probability of the true class for each example
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    log_probs = log_probs[jnp.arange(len(log_probs)), y_true]
    probs = jnp.exp(log_probs)
    # focal loss down weights examples the model already classifies confidently
    elementwise_loss = -1 * ((1 - probs)**gamma) * log_probs
    return elementwise_loss
note that focal loss includes a gamma parameter. we'll trial 1.0, 2.0 and 3.0, in line with the experiments mentioned in calibrating deep neural networks using focal loss.
$ python3 train_basic_model.py --loss-fn focal_loss --gamma 1.0 --model-dir m4_1
learning_rate 0.001 validation accuracy 0.314 entropy min/mean/max 0.0193 1.0373 2.0371
learning_rate 0.001 validation accuracy 0.326 entropy min/mean/max 0.0000 0.5101 1.8546
learning_rate 0.001 validation accuracy 0.615 entropy min/mean/max 0.0000 0.6680 1.8593
learning_rate 0.001 validation accuracy 0.604 entropy min/mean/max 0.0000 0.5211 1.8107
learning_rate 0.001 validation accuracy 0.555 entropy min/mean/max 0.0001 0.5127 1.7489
learning_rate 0.0001 validation accuracy 0.758 entropy min/mean/max 0.0028 0.6008 1.8130
learning_rate 0.0001 validation accuracy 0.780 entropy min/mean/max 0.0065 0.6590 1.9305
learning_rate 0.0001 validation accuracy 0.796 entropy min/mean/max 0.0052 0.6390 1.8589
learning_rate 0.0001 validation accuracy 0.790 entropy min/mean/max 0.0029 0.5886 1.9151
$ python3 train_basic_model.py --loss-fn focal_loss --gamma 2.0 --model-dir m4_2
learning_rate 0.001 validation accuracy 0.136 entropy min/mean/max 0.0001 0.2410 1.7906
learning_rate 0.001 validation accuracy 0.530 entropy min/mean/max 0.0000 0.8216 2.0485
learning_rate 0.001 validation accuracy 0.495 entropy min/mean/max 0.0003 0.7524 1.9086
learning_rate 0.001 validation accuracy 0.507 entropy min/mean/max 0.0000 0.6308 1.8460
learning_rate 0.001 validation accuracy 0.540 entropy min/mean/max 0.0000 0.5899 2.0196
learning_rate 0.001 validation accuracy 0.684 entropy min/mean/max 0.0001 0.6189 1.9029
learning_rate 0.001 validation accuracy 0.716 entropy min/mean/max 0.0000 0.6509 1.8822
learning_rate 0.001 validation accuracy 0.606 entropy min/mean/max 0.0001 0.7096 1.8312
learning_rate 0.0001 validation accuracy 0.810 entropy min/mean/max 0.0002 0.6208 1.9447
learning_rate 0.0001 validation accuracy 0.807 entropy min/mean/max 0.0004 0.6034 1.8420
$ python3 train_basic_model.py --loss-fn focal_loss --gamma 3.0 --model-dir m4_3
learning_rate 0.001 validation accuracy 0.288 entropy min/mean/max 0.0185 1.1661 1.9976
learning_rate 0.001 validation accuracy 0.319 entropy min/mean/max 0.0170 1.0925 2.1027
learning_rate 0.001 validation accuracy 0.401 entropy min/mean/max 0.0396 0.9795 2.0361
learning_rate 0.001 validation accuracy 0.528 entropy min/mean/max 0.0001 0.8534 1.9115
learning_rate 0.001 validation accuracy 0.528 entropy min/mean/max 0.0002 0.7405 1.8509
learning_rate 0.0001 validation accuracy 0.748 entropy min/mean/max 0.0380 0.9204 1.9697
learning_rate 0.0001 validation accuracy 0.760 entropy min/mean/max 0.0764 0.9807 1.9861
learning_rate 0.0001 validation accuracy 0.772 entropy min/mean/max 0.0906 0.9646 2.0556
learning_rate 0.0001 validation accuracy 0.781 entropy min/mean/max 0.0768 0.8973 1.9264
learning_rate 0.0001 validation accuracy 0.780 entropy min/mean/max 0.0871 0.8212 1.9814
some observations...
- the focal loss models behave more like model_2 than model_1 in terms of calibration.
- they push up the entropy on the ood data, but with a drop in overall validation accuracy. maybe the separation isn't as much as i'd like, but happy to experiment more given this result.
interesting!! focal loss is definitely going in my "calibration toolkit" :D
recall the distribution plot of entropies from the model_3 experiment
we can eyeball that the entropy values are low for the in-distribution sets (train, validate, calibrate and test) and high for the two ood sets (easy and hard), but how do we actually turn this difference into an in-distribution classifier? ( recall this is not a standard supervised learning problem; we only have the in-distribution instances to train against and nothing from the ood sets at training time )
amusingly i spent a bunch of time hacking around with all sorts of density estimation techniques but it turned out the simplest baseline i started with was the best :/
the baseline works on a strong, but valid enough, assumption about the entropies: entropies for in-distribution instances will be lower than entropies for OOD instances. given this assumption we can just use the simplest "classifier" of all; thresholding
to get a sense of the precision/recall tradeoff of this approach we can do the following...
- fit a min/max scaler on the entropies of the train set; this scales all train values to (0, 1)
- apply that scaler to the entropies of test, ood_easy and ood_hard; including (0, 1) clipping
- take x = 1.0 - x, since we want a low entropy to denote a positive instance of in-distribution
- treat test instances as positives and ood instances as negatives.
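those steps can be sketched as follows (pure python, on toy entropy values; the names and the example threshold are mine):

```python
def fit_min_max(values):
    # min/max scaler fit on the train entropies only; applied elsewhere
    # with clipping to (0, 1)
    lo, hi = min(values), max(values)
    return lambda v: min(max((v - lo) / (hi - lo), 0.0), 1.0)

train_entropies = [0.1, 0.2, 0.3, 0.9]  # toy values
scale = fit_min_max(train_entropies)

def in_distribution_score(entropy):
    # x = 1.0 - x; low entropy => high confidence => high in-dist score
    return 1.0 - scale(entropy)

def is_in_distribution(entropy, threshold=0.5):
    # the "classifier" is just a threshold on the score
    return in_distribution_score(entropy) >= threshold
```

sweeping the threshold over (0, 1) is what traces out the ROC and P/R curves below.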
with this simple approach we get the following ROC, P/R and reliability plots.
i've included the attempts i made to use a kernel density estimator in fit_kernel_density_to_entropy.py, so if you manage to get a result better than this trivial baseline please let me know :)
though this post is about using entropy as a way of measuring OOD there are other ways too.
the simplest way i've handled OOD detection in the past is to include an explicit OTHER label. this has worked in cases where i've had data available to train against, e.g. data that was (somehow) sampled but wasn't of interest for the actual supervised problem at hand. but i can see lots of situations where this wouldn't be possible.
another approach i'm still hacking on is doing density estimation on the inputs; this relates very closely to the diversity sampling ideas for active learning.
these experiments were written in objax, an object oriented wrapper for jax
Mathematics for Machine Learning by Marc Peter Deisenroth, A. Aldo Faisal & Cheng Soon Ong.
this is my personal favorite book on the general math required for machine learning, the way things are described really resonate with me. available as a free pdf but i got a paper copy to support the authors after reading the first half.
Linear Algebra and Learning from Data by Gilbert Strang.
this is gilbert's most recent work. it's really great, he's such a good teacher, and his freely available lectures are even better. it's a shorter text than his other classic intro below with more of a focus on how things are connected to modern machine learning techniques.
Introduction to Linear Algebra by Gilbert Strang.
this was my favorite linear algebra book for a long time before his 'learning from data' came out. this is a larger book with a more comprehensive view of linear algebra.
Think Stats: Probability and Statistics for Programmers by Allen Downey.
this book focuses on practical computation methods for probability and statistics. i got a lot out of working through this one. it's all in python and available for free. ( exciting update! as part of writing this post i've discovered there's a new edition to read!)
Doing Bayesian Data Analysis by John Kruschke
on the bayesian side of things this is the book i've most enjoyed working through. i've only got the first edition which was R and BUGS but i see the second edition is R, JAGS and Stan. it'd be fun i'm sure to work through it doing everything in numpyro. i might do that in all my free time. haha. "free time" hahaha. sob.
The Elements of Statistical Learning by Hastie, Tibshirani and Friedman
this is still one of the most amazing fundamental machine learning books i've ever had. in fact i've purchased this book twice and given it away both times :/ i might buy another copy some time soon, even though it's been freely available to download for ages. an amazing piece of work.
Probabilistic Graphical Models by Daphne Koller & Nir Friedman
this is an epic textbook that i'd love to understand better. i've read a couple of sections in detail but not the entire tome yet.
Pattern Recognition and Machine Learning by Christopher Bishop
this is probably the best overall machine learning text book i've ever read. such a beautiful book and the pdf is FREE FOR DOWNLOAD!!!
Machine Learning: A Probabilistic Perspective by Kevin Murphy
this is my second favorite general theory text on machine learning. i got kevin to sign my copy when he was passing my desk once but someone borrowed it and never gave it back :( so if you see a copy with my name on the spine let me know!
Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow by Aurélien Géron
this is the book i point most people to when they are interested in getting up to speed with modern applied machine learning without too much concern for the theory. it's very up to date (as much as a book can be) with the latest libraries and, most importantly, provides a good overview of not just neural stuff but fundamental scikit-learn as well.
Machine Learning Engineering by Andriy Burkov
a great book focussing on the operations side of running a machine learning system. i'm a bit under half way through the free online version and very likely to buy a physical copy to finish it and support the author. great stuff and, in many ways, a more impactful book than any of the theory books here.
Introduction to Data Mining by Pang-Ning Tan, Michael Steinbach & Vipin Kumar
this is another one that was also on my list from ten years ago and though its section on neural networks is a bit of a chuckle these days there is still a bunch of really great fundamental stuff in this book. very practical and easy to digest. i also see there's a second edition now. i reckon this would complement the "hands on" book above very well.
Speech and Language Processing by Dan Jurafsky & James Martin
still the best overview of NLP there is (IMHO). can't wait to read the 3rd edition which apparently will cover more modern stuff (e.g. transformers) but until then, for the love of god though, please don't be one of those "this entire book is irrelevant now! just fine tune BERT" people :/
Numerical Optimization by Jorge Nocedal & Stephen J. Wright
this book is super hard core and maybe more an operations research book than machine learning. though i've not read it cover to cover the couple of bits i've worked through really taught me a lot. i'd love to understand the stuff in this text better; it's so so fundamental to machine learning (and more)
Deep Learning by Ian Goodfellow
writing a book specifically on deep learning is very dangerous since things move so fast but if anyone can do it, ian can... i think ian's approach to explaining neural networks from the ground up is one of my favorites. i got the first edition hardback but it's free to download from the website.
Probabilistic Robotics by Sebastian Thrun, Wolfram Burgard and Dieter Fox
when i first joined a robotics group i bought a stack of ML/robotics books and this was by far the best. it's good intro stuff, and maybe already dated in places given its age (the 2006 edition i have) but i still got a bunch from it.
TinyML by Pete Warden & Daniel Situnayake
this was a super super fun book to tech review! neural networks on microcontrollers?!? yes please!
Evolutionary Computation by David Fogel
this is still my favorite book on evolutionary algorithms; i've had this for a loooong time now. i still feel like evolutionary approaches are due for a big big comeback any time now....
the good thing about writing a list is you get people telling you cool ones you've missed :)
the top three i've chosen (that are in the mail) are...
Causal Inference in Statistics by Judea Pearl, Madelyn Glymour & Nicholas P. Jewell
recommended by animesh who quite rightly points out the lack of causality coverage in the machine learning books above.
Information Theory, Inference and Learning Algorithms by David MacKay
i've seen this book mentioned a number of times and was most recently recommended by my colleague dane so it's time to get it.
it's been about two years since i first saw the awesome very slow movie player project by bryan boyer. i thought it was such an excellent idea but never got around to buying the hardware to make one. more recently though i've seen a couple of references to the project so i decided it was finally time to make one.
one interesting concern about an eink very slow movie player is the screen refresh. simpler eink screens refresh by doing a full cycle of a screen of white or black before displaying the new image. i hated the idea of an ambient slow player doing this every few minutes as it switched frames, so i wanted to make sure i got a piece of hardware that could do incremental update.
after a bit of shopping around i settled on a 6 inch HD screen from waveshare
it ticks all the boxes i wanted
this screen also supports grey scale, but only with a flashy full cycle redraw, so i'm going to stick to just black and white since it supports the partial redraw.
note: even though the partial redraw is basically instant it does suffer from a ghosting problem; when you draw a white pixel over a black one things are fine, but if you draw black over white, in the partial redraw, you get a slight ghosting of gray that is present until a full redraw :/
so how do you display an image when you can only show black and white? dithering! here's an example of a 384x288 RGB image dithered using PIL's implementation of the Floyd-Steinberg algorithm
original RGB vs dithered version |
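as a quick sketch of the mechanics (using a synthetic gradient image rather than an actual movie frame), PIL's convert to 1-bit mode applies Floyd-Steinberg error diffusion by default:

```python
from PIL import Image

# synthetic 64x64 RGB gradient as a stand-in for a movie frame
img = Image.new("RGB", (64, 64))
img.putdata([(x * 4, y * 4, 128) for y in range(64) for x in range(64)])

# convert("1") uses Floyd-Steinberg error diffusion by default
dithered = img.convert("1")

print(dithered.mode)  # "1", i.e. one bit per pixel
# after dithering, only pure black (0) and white (255) pixels remain
print(set(dithered.getdata()) <= {0, 255})
```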
it makes intuitive sense that you could have small variations in the exact locations of the dots as long as you get the densities generally right. so there's a reasonable question then; how do you dither in such a way that you get a good result, but with minimal pixel changes from a previous frame? (since we're motivated on these screens to change as little as possible)
there are two approaches i see
1) spend 30 minutes googling for a solution that no doubt someone came up with 20 years ago that can be implemented in 10 lines of c running at 1000fps ...
2) .... or train a jax based GAN to generate the dithers with a loss balancing a good dither vs no pixel change. :P
when building a very slow movie player the most critical decision is... what movie to play? i really love the 1979 classic alien, it's such a great dark movie, so i thought i'd go with it. the movie is 160,000 frames so at a play back rate of a frame every 200 seconds it'll take just over a year to finish.
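a quick sanity check on that playback arithmetic:

```python
frames = 160_000
seconds_per_frame = 200

# total playback time in days
days = frames * seconds_per_frame / (60 * 60 * 24)
print(round(days, 1))  # 370.4 days — just over a year
```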
note that in this type of problem there is no concern around overfitting. we have access to all data going in and so it's fine to overfit as much as we like; as long as we're minimising whatever our objective is we're good to go.
i started with a unet that maps 3 channel RGB images to a single channel dither.
v1 architecture |
i tinkered a bit with the architecture but didn't spend too much time tuning it. for the final v3 result i ended with a pretty vanilla stack of encoders & decoders (with skip connections connecting an encoder to the decoder at the same spatial resolution). each encoder/decoder block uses a residual like shortcut around a couple of convolutions. nearest neighbour upsampling gave a nicer result than deconvolutions in the decoder for the v3 result. also, gelu is my new favorite activation :)
for v1 i used a binary cross entropy loss of P(white) per pixel ( since it's what worked well for my bee counting project )
as always i started by overfitting to a single example to get a baseline feel for capacity required.
v1 overfit result |
when scaling up to the full dataset i switched to training on half resolution images against a patch size of 128. working on half resolution consistently gave a better result than working with the full resolution.
as expected though this model gave us the classic type of problem we see with straight unet style image translation; we get a reasonable sense of the shapes, but no fine details around the dithering.
v1 vanilla unet with upsampling example |
side notes:
v1 vanilla unet with deconvolution example |
for v2 i added a GAN objective in an attempt to capture finer details
v2 architecture |
i started with the original pix2pix objective but reasonably quickly moved to use a wasserstein critic style objective since i've always found it more stable.
the generator (G) was the same as the unet above with the discriminator (D) running patch based. at this point i also changed the reconstruction loss from a binary objective to just L1. i ended up using batchnorm in D, but not G. to be honest i only did a little bit of manual tuning, i'm sure there's a better result hidden in the hyperparameters somewhere.
so, for this version, the loss for G has two components
1. D(G(rgb))          # fool D
2. L1(G(rgb), dither) # reconstruct the dither
very quickly (i.e. in < 10mins ) we get a reasonable result that is starting to show some more detail than just the blobby reconstruction.
v2 partial trained eg |
note: if the loss weight of 1) is 0 we degenerate to v1 (which proved a useful intermediate debugging step). at this point i didn't want to tune too much since the final v3 is coming...
for v3 we finally introduce a loss relating the previous frame (which was one of the main intentions of the project in the first place)
now G takes not just the RGB image, but the dither of the previous frame.
v3 architecture |
the loss for G now has three parts
1. D(G(rgb_t1)) => real     # fool D
2. L1(G(rgb_t1), dither_t1) # reconstruct the dither
3. L1(G(rgb_t1), dither_t0) # don't change too much from the last frame
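a minimal sketch of how these three terms might combine in jax (the weights and function names here are illustrative, not the values from the actual training run):

```python
import jax.numpy as jnp

# illustrative loss weights; the real run would tune these
ADV_W, RECON_W, TEMPORAL_W = 1.0, 10.0, 5.0

def generator_loss(critic_scores, pred_dither, dither_t1, dither_t0):
    adv = -jnp.mean(critic_scores)                         # 1. fool the wasserstein critic
    recon = jnp.mean(jnp.abs(pred_dither - dither_t1))     # 2. reconstruct the dither
    temporal = jnp.mean(jnp.abs(pred_dither - dither_t0))  # 3. don't stray from the last frame
    return ADV_W * adv + RECON_W * recon + TEMPORAL_W * temporal

# dummy shapes just to show it runs
loss = generator_loss(jnp.ones((4,)), jnp.zeros((4, 32, 32)),
                      jnp.zeros((4, 32, 32)), jnp.ones((4, 32, 32)))
print(float(loss))  # 1.0*(-1.0) + 10.0*0.0 + 5.0*1.0 = 4.0
```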
normally with a network that takes as input the same thing it's outputting we have to be careful to include things like teacher forcing. but since we don't intend to use this network for any kind of rollouts we can just always feed the "true" dithers in where required. having said that, rolling out the dithers from this network would be interesting :D
the third loss objective, not changing too many pixels from the last frame, works well for generally stationary shots but is disastrous for scene changes :/
consider the following graph for a sequence of frames showing the pixel difference between frames.
when there is a scene change we observe a clear "spike" in pixel diff. my first thought was to look for these and do a full redraw for them. it's very straightforward to find them (using a simple z-score based anomaly detector on a sliding window) but the problem is that it doesn't pick up the troublesome case of a panning shot where we don't have a scene change exactly. in these cases there is no abrupt scene change, but there are a lot of pixels changing so we end up seeing a lot of ghosting.
i spent ages tinkering with the best way to approach this before deciding that a simple approach of num_pixels_changed_since_last_redraw > threshold was good enough to decide if a full redraw was required (with a cooldown to ensure we're not redrawing all the time)
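a sketch of that redraw heuristic (the threshold and cooldown constants are made up for illustration; the real values would depend on the screen):

```python
# illustrative constants, not the values from the actual player
PIXEL_THRESHOLD = 5000   # pixels changed since last full redraw
COOLDOWN_FRAMES = 10     # minimum frames between full redraws

class RedrawDecider:
    def __init__(self):
        self.pixels_since_redraw = 0
        self.frames_since_redraw = COOLDOWN_FRAMES  # allow a redraw immediately

    def step(self, num_pixels_changed):
        self.pixels_since_redraw += num_pixels_changed
        self.frames_since_redraw += 1
        if (self.pixels_since_redraw > PIXEL_THRESHOLD
                and self.frames_since_redraw >= COOLDOWN_FRAMES):
            self.pixels_since_redraw = 0
            self.frames_since_redraw = 0
            return True   # do a flashy full redraw
        return False      # a partial update is fine

d = RedrawDecider()
print(d.step(100))     # False: not enough pixels changed
print(d.step(10_000))  # True: threshold exceeded, do the full redraw
print(d.step(10_000))  # False: lots changed, but still in cooldown
```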
the v3 network gets a very good result very quickly; unsurprisingly since the dither at time t0 provided to G is a pretty good estimate of the dither at t1 :) i.e. G can get a good result simply by copying it!
the following scenario shows this effect...
consider three sequential frames, the middle one being a scene change.
at the very start of training the reconstruction loss is dominant and we get blobby outlines of the frame.
but as the contribution from the dither at time t0 kicks in things look good in general but the frames at the scene change end up being a ghosted mix: an attempt to copy through the old frame along with dithering the new one. (depending on the relative strength of the loss terms of G).
so the v3 version generally works and i'm sure with some more tuning i could get a better result but, as luck would have it, i actually find the results from v2 more appealing when testing on the actual eink screen. so even though the intention was to do something like v3 i'm going to end up running something more like v2 (as shown in these couple of examples (though the resolution does it no justice (not to mention the fact the player will run about 5000 times slower than these gifs)))
i ran for a few weeks with a prototype that lived balanced precariously on a piece of foam below its younger sibling pi zero eink screen running game of life. eventually i cut up some pieces of an old couch and made a simple wooden frame. a carpenter, i am not :/
prototype | frame |
ensemble nets are a method of representing an ensemble of models as one single logical model. we use jax's vmap operation to batch over not just the inputs but additionally sets of model parameters. we propose some approaches for training ensemble nets and introduce logit dropout as a way to improve ensemble generalisation as well as provide a method of calculating model confidence.
update: though i originally developed this project using vmap, which is how things are described below, the latest version of the code is a port to use pmap so we can run on a tpu pod slice, not just one machine
as part of my "embedding the chickens" project i wanted to use random projection embedding networks to generate pairs of similar images for weak labelling. since this technique works really well in an ensemble i did some playing around and got the ensemble running pretty fast in jax. i wrote it up in this blog post. since doing that i've been wondering how to not just run an ensemble net forward pass but how you might train one too...
for this problem we'll work with the eurosat/rgb dataset. eurosat/rgb is a 10 way classification task across 27,000 64x64 RGB images
here's a sample image from each of the ten classes...
as a baseline we'll start with a simple non ensemble network. it'll consist of a pretty vanilla convolutional stack, global spatial pooling, one dense layer and a final 10 way classification layer.
layer | shape |
input | (B, 64, 64, 3) |
conv2d | (B, 31, 31, 32) |
conv2d | (B, 15, 15, 64) |
conv2d | (B, 7, 7, 96) |
conv2d | (B, 3, 3, 96) |
global spatial pooling | (B, 96) |
dense | (B, 96) |
logits (i.e. dense with no activation) | (B, 10) |
all convolutions use 3x3 kernels with a stride of 2. the conv layers and the single dense layer use a gelu activation. batch size is represented by B.
we use no batch norm, no residual connections, nothing fancy at all. we're more interested in the training than getting the absolute best value. this network is small enough that we can train it fast but it still gives reasonable results. residual connections would be trivial to add but batch norm would be a bit more tricky given how we'll build the ensemble net later.
we'll use objax to manage the model params and orchestrate the training loops.
training for the baseline will be pretty standard but let's walk through it so we can call out a couple of specific things for comparison with an ensemble net later...
( we'll use 2 classes in these diagrams for ease of reading though the eurosat problem has 10 classes. )
walking through left to right...
a batch of images (B, H, W, 3) goes in, the first conv layer brings it to (B, H/2, W/2, 32), and the stack continues down to the logits (B, 2)
we'll train on 80% of the data, do hyperparam tuning on 10% (validation set) and report final results on the remaining 10% (test set)
for hyperparam tuning we'll use ax on very short runs of 30min for all trials. for experiment tracking we'll use wandb
the hyperparams we'll tune for the baseline will be...
param | description |
max_conv_size | conv layers will be sized as [32, 64, 128, 256] up to a max size of max_conv_size. i.e. a max_conv_size of 75 would imply sizes [32, 64, 75, 75] |
dense_kernel_size | how many units in the dense layer before the logits |
learning_rate | learning rate for optimiser |
we'd usually make choices like the conv sizes being powers of 2 instead of a smooth value but i was curious about the behaviour of ax for tuning. also we didn't bother with a learning rate schedule; we just use simple early stopping (against the validation set)
the best model of this group gets an accuracy of 0.913 on the validation set and 0.903 on the test set. ( usually not a fan of accuracy but the classes are pretty balanced so accuracy isn't a terrible thing to report. )
model | validation | test |
baseline | 0.913 | 0.903 |
so what then is an ensemble net?
logically we can think about our models as being functions that take two things 1) the parameters of the model and 2) an input example. from these they return an output.
# pseudo code
model(params, input) -> output
we pretty much always though run a batch of B inputs at once. this can be easily represented as a leading axis on the input and allows us to make better use of accelerated hardware as well as providing some benefits regarding learning w.r.t. gradient variance.
jax's vmap function makes this trivial to implement by vectorising a call to the model across a vector of inputs to return a vector of outputs.
# pseudo code
vmap(partial(model, params))(b_inputs) -> b_outputs
interestingly we can use this same functionality to batch not across independent inputs but instead across independent sets of M model params. this effectively means we run the M models in parallel. we'll call these M models sub models from now on.
# pseudo code
vmap(lambda params: model(params, input))(m_params) -> m_outputs
and there's no reason why we can't do both batching across both a set of inputs as well as a set of model params at the same time.
# pseudo code
vmap(vmap(model, in_axes=(None, 0)), in_axes=(0, None))(m_params, b_inputs) -> m_b_outputs
for a lot more specifics on how i use jax's vmap to support this see my prior post on jax random embedding ensemble nets.
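to make the double batching concrete, here's a runnable toy version using a single linear layer as the "model" (names and shapes here are illustrative only, not from the actual code):

```python
import jax
import jax.numpy as jnp

# toy "model": a single linear layer, just to show the vmap mechanics
def model(params, x):
    return x @ params["w"] + params["b"]

B, M, D, C = 5, 3, 8, 2  # batch size, num sub models, input dim, num classes
key = jax.random.PRNGKey(0)

b_inputs = jax.random.normal(key, (B, D))
m_params = {
    "w": jax.random.normal(key, (M, D, C)),  # a leading M axis on every param
    "b": jnp.zeros((M, C)),
}

# inner vmap: map over the batch with the params fixed
# outer vmap: map over the M param sets with the batch fixed
ensemble_forward = jax.vmap(jax.vmap(model, in_axes=(None, 0)), in_axes=(0, None))

m_b_outputs = ensemble_forward(m_params, b_inputs)
print(m_b_outputs.shape)  # (M, B, C) == (3, 5, 2)
```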
and did somebody say TPUs? turns out we can make ensemble nets run super fast on TPUs by simply swapping the vmap calls for pmap ones! using pmap on a TPU will have each sub model run in parallel on its own core! see this colab for example code running pmap on TPUs
let's walk through this in a bit more detail with an ensemble net with two sub models.
the batch of images (B, H, W, 3) is fed to every sub model. the first conv layer introduces a leading M axis to represent the outputs from the M models and results in (M, B, H/2, W/2, 32). this M axis is carried all the way through to the logits (M, B, 2). we now have (M, B, 2) logits but we need (B, 2) to compare against the (B,) labels. with logits this reduction is very simple; just sum over the M axis!
this gives us a way to train the sub models to act as a single ensemble unit as well as a way to run inference on the ensemble net in a single forward pass.
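the ensembling reduction itself is a one liner; a sketch with random logits:

```python
import jax
import jax.numpy as jnp

M, B, C = 4, 8, 2  # num sub models, batch size, num classes
m_b_logits = jax.random.normal(jax.random.PRNGKey(0), (M, B, C))

# collapse the ensemble to one set of logits per example by summing over M
ensemble_logits = m_b_logits.sum(axis=0)  # (B, C)
y_pred = ensemble_logits.argmax(axis=-1)  # (B,) ensemble predictions
print(ensemble_logits.shape, y_pred.shape)
```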
we'll refer to this approach as single_input since we are starting with a single image for all sub models.
an alternative approach to training is to provide a separate image per sub model. how would things differ if we did that?
this time the input carries a leading M axis since it's a different batch per sub model: (M, B, H, W, 3). the conv stack carries the M axis through (M, B, H/2, W/2, 32) all the way down to the logits (M, B, 2). we now have M separate labels for the M inputs so we don't have to combine the logits at all; we can just calculate the mean loss across the (M, B) training instances.
we'll call this approach multi_input.
note that this way of feeding separate images only really applies to training; for inference, if we want the representation of the ensemble, it only really makes sense to send a batch of (B) images, not (M, B).
let's do some tuning as before but with a couple of additional hyper parameters that this time we'll sweep across.
we'll do each of the six combos of [(single, 2), (single, 4), (single, 8), (multi, 2), (multi, 4), (multi, 8)] and tune for 30 min for each.
when we poke around and facet by the various params there's only one that makes a difference; single_input mode consistently does better than multi_input.
in hindsight this is not surprising i suppose since single_input mode is effectively training one network with M× the parameters (with an odd summing-of-logits kind of bottleneck)
when we check the best single_input 4 sub model ensemble net we get an accuracy of 0.920 against the validation set and 0.901 against the test set
model | validation | test |
baseline | 0.913 | 0.903 |
single_input | 0.920 | 0.901 |
looking at the confusion matrix the only real thing to note is the slight confusion between 'Permanent Crop' and 'Herbaceous Vegetation' which is reasonable given the similarity in RGB.
we can also review the confusion matrices of each of the 4 sub models run as individuals; i.e. not working as an ensemble. we observe the quality of each isn't great with accuracies of [0.111, 0.634, 0.157, 0.686]. again makes sense since they had been trained only to work together. that first model really loves 'Forests', but don't we all...
the performance of the multi_input ensemble isn't quite as good with a validation accuracy of 0.902 and test accuracy of 0.896. the confusion matrix looks similar to the single_input mode version.
model | validation | test |
baseline | 0.913 | 0.903 |
single_input | 0.920 | 0.901 |
multi_input | 0.902 | 0.896 |
this time though the output of each of the 4 sub models individually is much stronger with accuracies of [0.842, 0.85, 0.84, 0.83, 0.86]. this makes sense since they were trained to not predict as one model. it is nice to see at least that the ensemble result is higher than any one model. and reviewing their confusion matrices they seem to specialise in different aspects with differing pairs of confused classes.
the main failing of the single_input approach is that the sub models are trained to always operate together; that breaks some of the core ideas of why we do ensembles in the first place. as i was thinking about this i was reminded that the core idea of dropout is quite similar; when nodes in a dense layer are running together we can drop some out to ensure other nodes don't overfit to expecting them to always behave in a particular way.
so let's do the same with the sub models of the ensemble. my first thought around this was that the most logical place would be at the logits. we can zero out the logits of a random half of the models during training and, given the ensembling is implemented by summing the logits, this effectively removes those models from the ensemble. the biggest con though is the waste of the forward pass of running those sub models in the first place. during inference we don't have to do anything in terms of masking & there's no need to do any rescaling ( that i can think of ).
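a sketch of what zeroing the logits of a random half of the sub models might look like (function and variable names are mine, not from the actual code):

```python
import jax
import jax.numpy as jnp

def drop_logits(key, m_b_logits):
    # zero the logits of a random half of the sub models (training time only);
    # since the ensemble sums logits over M this removes them from the ensemble
    M = m_b_logits.shape[0]
    keep = jax.random.permutation(key, M)[: M // 2]
    mask = jnp.zeros((M, 1, 1)).at[keep].set(1.0)
    return m_b_logits * mask

m_b_logits = jnp.ones((4, 8, 10))  # M=4 sub models, batch of 8, 10 classes
dropped = drop_logits(jax.random.PRNGKey(42), m_b_logits)

# exactly half the sub models still contribute non-zero logits
print(int((dropped.sum(axis=(1, 2)) != 0).sum()))  # 2
```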
so how does it do? accuracy is 0.914 against the validation set and 0.911 against the test set; the best result so far! TBH though, these numbers are pretty close anyways so maybe we were just lucky ;)
model | validation | test |
baseline | 0.913 | 0.903 |
single_input | 0.920 | 0.901 |
multi_input | 0.902 | 0.896 |
logit drop | 0.914 | 0.911 |
the sub models are all now doing OK with accuracies of [0.764, 0.827, 0.772, 0.710]. though the sub models aren't as strong as the sub models of the multi_input mode, the overall performance is the best. great! seems like a nice compromise between the two!
the main problem i had with dropping logits is that there is a wasted forward pass for half the sub models. then i realised why run the models at all? instead of dropping logits we can just choose, through advanced indexing, a random half of the models to run a forward pass through. this has the same effect of running a random half of the models at a time but only requires half the forward pass compute. this approach of dropping models is what the code currently does. (though the dropping of logits is in the git history)
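a sketch of the advanced indexing idea: slice a random half of the param sets before the forward pass (again a toy linear "model", not the real network):

```python
import jax
import jax.numpy as jnp

def model(params, x):
    return x @ params["w"]  # toy linear "model"

M, B, D, C = 4, 8, 16, 10
key = jax.random.PRNGKey(0)
m_params = {"w": jax.random.normal(key, (M, D, C))}
b_inputs = jax.random.normal(key, (B, D))

# advanced indexing: pick a random half of the param sets so we only
# pay for half the forward passes (vs zeroing logits after the fact)
idx = jax.random.permutation(key, M)[: M // 2]
half_params = jax.tree_util.tree_map(lambda p: p[idx], m_params)

forward = jax.vmap(jax.vmap(model, in_axes=(None, 0)), in_axes=(0, None))
logits = forward(half_params, b_inputs)
print(logits.shape)  # (M // 2, B, C)
```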
ensembles also provide a clean way of measuring confidence of a prediction. if the variance of predictions across sub models is low it implies the ensemble as a whole is confident. alternatively if the variance is high it implies the ensemble is not confident.
with the ensemble model that has been trained with logit dropout we can get an idea of this variance by considering the ensemble in a hold-one-out fashion; we can obtain M different predictions from the ensemble by running it as if each of the M sub models was not present (using the same idea as the logit dropout).
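since the ensemble prediction is a sum of logits, holding out sub model i is just a subtraction; a sketch of computing per-class mean and stddev this way (variable names are illustrative):

```python
import jax
import jax.numpy as jnp

M, B, C = 4, 8, 10
m_b_logits = jax.random.normal(jax.random.PRNGKey(0), (M, B, C))

# the ensemble with sub model i held out is the total logit sum minus model i's logits
total = m_b_logits.sum(axis=0, keepdims=True)            # (1, B, C)
loo_probs = jax.nn.softmax(total - m_b_logits, axis=-1)  # (M, B, C)

# low stddev across the M held-out ensembles => a confident prediction
p_mean = loo_probs.mean(axis=0)  # (B, C)
p_std = loo_probs.std(axis=0)    # (B, C)
print(p_mean.shape, p_std.shape)
```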
consider a class that the ensemble is very good at; e.g. 'Sea & Lake'. given a batch of 8 of these images across an ensemble net with 4 sub models we get the following prediction mean and stddevs.
idx | y_pred | mean(P(class)) | std(P(class)) |
0 | Sea & Lake | 0.927 | 0.068 |
1 | Sea & Lake | 1.000 | 0.000 |
2 | Sea & Lake | 1.000 | 0.000 |
3 | Sea & Lake | 0.999 | 0.001 |
4 | Sea & Lake | 0.989 | 0.019 |
5 | Sea & Lake | 1.000 | 0.000 |
6 | Sea & Lake | 1.000 | 0.000 |
7 | Sea & Lake | 1.000 | 0.000 |
whereas when we look at a class the model is not so sure of, e.g. 'Permanent Crop', we can see that the lower probability cases have a higher variance across the models.
idx | y_pred | mean(P(class)) | std(P(class)) |
0 | Industrial Buildings | 0.508 | 0.282 |
1 | Permanent Crop | 0.979 | 0.021 |
2 | Permanent Crop | 0.703 | 0.167 |
3 | Herbaceous Vegetation | 0.808 | 0.231 |
4 | Permanent Crop | 0.941 | 0.076 |
5 | Permanent Crop | 0.979 | 0.014 |
6 | Permanent Crop | 0.833 | 0.155 |
7 | Permanent Crop | 0.968 | 0.025 |
check it out in this tutorial and let me know what you think!