# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

package AI::MXNet::Optimizer;
use strict;
use warnings;
use AI::MXNet::NS;
use AI::MXNet::Base;
use AI::MXNet::NDArray;
use AI::MXNet::Random;
use List::Util qw(max);

=head1 NAME

    AI::MXNet::Optimizer - Common Optimization algorithms with regularizations.

=head1  DESCRIPTION

    Common Optimization algorithms with regularizations.
=cut

use Mouse;
use AI::MXNet::Function::Parameters;
my %opt_registry;
method get_opt_registry()
{
    return \%opt_registry;
}

method register()
{
    my $name = $self;
    ($name) = $name =~ /::(\w+)$/;
    {  no strict 'refs'; *{__PACKAGE__."::$name"} = sub { shift; $self->new(@_)  }; }
    $name = lc $name;
    if(exists $opt_registry{ $name })
    {
        my $existing = $opt_registry{ $name };
        warn(
            "WARNING: New optimizer $self.$name"
            ."is overriding existing optimizer $existing.$name"
        );
    }
    $opt_registry{ $name } = $self;
}

=head2 create_optimizer

        Create an optimizer with specified name.

        Parameters
        ----------
        $name: Str
            Name of required optimizer. Should be the name
            of a subclass of Optimizer. Case insensitive.

        :$rescale_grad : Num
            Rescaling factor on gradient. Normally should be 1/batch_size.

        %kwargs: Hash
            Parameters for optimizer

        Returns
        -------
        opt : Optimizer
            The result optimizer.
=cut

method create_optimizer(Str $name, %kwargs)
{
    if(exists $opt_registry{ lc $name })
    {
        my $rescale_grad = delete($kwargs{rescale_grad})//1;
        return $opt_registry{ lc $name }->new(
            rescale_grad => $rescale_grad,
            %kwargs
        );
    }
    confess("Cannot find optimizer $name");
}

*create = \&create_optimizer;

has 'rescale_grad'        => (is => "rw", isa => "Num", default=>1);
has 'lr'                  => (is => "rw", isa => "Num");
has 'learning_rate'       => (is => "rw", isa => "Num", default => 0.01);
has 'lr_scheduler'        => (is => "rw", isa => "Maybe[AI::MXNet::LRScheduler]");
has 'wd'                  => (is => "rw", isa => "Num", default => 0);
has 'lr_mult'             => (is => "rw", isa => "HashRef", default => sub { +{} });
has 'wd_mult'             => (is => "rw", isa => "HashRef", , default => sub { +{} });
has 'num_update'          => (is => "rw", isa => "Int");
has 'begin_num_update'    => (is => "rw", isa => "Int", default => 0);
has '_index_update_count' => (is => "rw", isa => "HashRef", default => sub { +{} });
has 'clip_gradient'       => (is => "rw", isa => "Maybe[Num]");
has 'param_idx2name'      => (is => "rw", isa => "HashRef[Str]", default => sub { +{} });
has 'idx2name'            => (is => "rw", isa => "HashRef[Str]");
has 'sym'                 => (is => "rw", isa => "Maybe[AI::MXNet::Symbol]");
has 'param_dict'          => (is => "rw", isa => "HashRef", default => sub { +{} });

sub BUILD
{
    my $self = shift;
    if($self->lr_scheduler)
    {
        $self->lr_scheduler->base_lr($self->learning_rate);
    }
    $self->lr($self->learning_rate);
    $self->num_update($self->begin_num_update);
    $self->idx2name({ %{ $self->param_idx2name } });
    $self->set_lr_mult({});
    $self->set_wd_mult({});
}
# Create additional optimizer state such as momentum.
# override in implementations.
method create_state($index, $weight){}

# Update the parameters. override in implementations
method update($index, $weight, $grad, $state){}

# set lr scale is deprecated. Use set_lr_mult instead.
method set_lr_scale($args_lrscale)
{
    Carp::cluck("set lr scale is deprecated. Use set_lr_mult instead.");
}

=head2 set_lr_mult

        Set individual learning rate multipler for parameters

        Parameters
        ----------
        args_lr_mult : dict of string/int to float
            set the lr multipler for name/index to float.
            setting multipler by index is supported for backward compatibility,
            but we recommend using name and symbol.
=cut

method set_lr_mult(HashRef[Num] $args_lr_mult)
{
    $self->lr_mult({});
    if($self->sym)
    {
        my $attr = $self->sym->attr_dict();
        for my $name (@{ $self->sym->list_arguments() })
        {
            if(exists $attr->{ $name } and exists $attr->{ $name }{ __lr_mult__ })
            {
                $self->lr_mult->{ $name } = $attr->{ $name }{ __lr_mult__ };
            }
        }
    }
    $self->lr_mult({ %{ $self->lr_mult }, %{ $args_lr_mult } });
}

=head2 set_wd_mult

        Set individual weight decay multipler for parameters.
        By default wd multipler is 0 for all params whose name doesn't
        end with _weight, if param_idx2name is provided.

        Parameters
        ----------
        args_wd_mult : dict of string/int to float
            set the wd multipler for name/index to float.
            setting multipler by index is supported for backward compatibility,
            but we recommend using name and symbol.
=cut

method set_wd_mult(HashRef[Num] $args_wd_mult)
{
    $self->wd_mult({});
    for my $n (values %{ $self->idx2name })
    {
        if(not $n =~ /(?:_weight|_gamma)$/)
        {
            $self->wd_mult->{ $n } = 0;
        }
    }
    if($self->sym)
    {
        my $attr = $self->sym->attr_dict();
        for my $name (@{ $self->sym->list_arguments() })
        {
            if(exists $attr->{ $name } and exists $attr->{ $name }{ __wd_mult__ })
            {
                $self->wd_mult->{ $name } = $attr->{ $name }{ __wd_mult__ };
            }
        }
    }
    $self->wd_mult({ %{ $self->wd_mult }, %{ $args_wd_mult } });
}

method _update_count(Index $index)
{
    if(not exists $self->_index_update_count->{ $index })
    {
        $self->_index_update_count->{ $index } = $self->begin_num_update;
    }
    $self->_index_update_count->{ $index } += 1;
    $self->num_update(max($self->_index_update_count->{ $index }, $self->num_update));
}

method _get_lr(Index $index)
{
    my $lr;
    if($self->lr_scheduler)
    {
        $lr = $self->lr_scheduler->($self->num_update);
    }
    else
    {
        $lr = $self->lr;
    }

    if(exists $self->param_dict->{ $index })
    {
        $lr *= $self->param_dict->{ $index }->lr_mult;
    }
    elsif(exists $self->lr_mult->{ $index })
    {
        $lr *= $self->lr_mult->{ $index };
    }
    elsif(exists $self->idx2name->{ $index })
    {
        $lr *= $self->lr_mult->{ $self->idx2name->{ $index } }//1;
    }
    return $lr;
}

method _get_wd(Index $index)
{
    my $wd = $self->wd;
    if(exists $self->param_dict->{ $index })
    {
        $wd *= $self->param_dict->{ $index }->wd_mult;
    }
    elsif(exists $self->wd_mult->{ $index })
    {
        $wd *= $self->wd_mult->{ $index };
    }
    elsif(exists $self->idx2name->{ $index })
    {
        $wd *= $self->wd_mult->{ $self->idx2name->{ $index } }//1;
    }
    return $wd;
}

=head1 NAME

    AI::MXNet::SGD - A very simple SGD optimizer with momentum and weight regularization.
=cut

=head1 DESCRIPTION

    A very simple SGD optimizer with momentum and weight regularization.

    If the storage types of weight and grad are both 'row_sparse', and 'lazy_update' is True,
    **lazy updates** are applied by

        for row in grad.indices:
            rescaled_grad[row] = lr * rescale_grad * clip(grad[row], clip_gradient) + wd * weight[row]
            state[row] = momentum[row] * state[row] + rescaled_grad[row]
            weight[row] = weight[row] - state[row]

    The sparse update only updates the momentum for the weights whose row_sparse
    gradient indices appear in the current batch, rather than updating it for all
    indices. Compared with the original update, it can provide large
    improvements in model training throughput for some applications. However, it
    provides slightly different semantics than the original update, and
    may lead to different empirical results.

    Otherwise, **standard updates** are applied by::

        rescaled_grad = lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
        state = momentum * state + rescaled_grad
        weight = weight - state

    Parameters
    ----------
    learning_rate : Num, optional
        learning_rate of SGD

    momentum : Num, optional
       momentum value

    wd : Num, optional
        L2 regularization coefficient add to all the weights

    rescale_grad : Num, optional
        rescaling factor of gradient. Normally should be 1/batch_size.

    clip_gradient : Num, optional
        clip gradient in range [-clip_gradient, clip_gradient]

    param_idx2name : hash ref of Str/Int to Num, optional
        special treat weight decay in parameter ends with bias, gamma, and beta

    multi_precision: Bool, optional
        Flag to control the internal precision of the optimizer.
        False results in using the same precision as the weights (default),
        True makes internal 32-bit copy of the weights and applies gradients
        in 32-bit precision even if actual weights used in the model have lower precision.
        Turning this on can improve convergence and accuracy when training with float16.

    lazy_update: Bool, optional, default true
=cut

package AI::MXNet::SGD;
use Mouse;
extends 'AI::MXNet::Optimizer';

has 'kwargs'          => (is => "rw", isa => "HashRef[Num]");
has 'momentum'        => (is => "rw", isa => "Num", default => 0);
has 'multi_precision' => (is => "ro", isa => "Bool", default => 0);
has 'lazy_update'     => (is => "ro", isa => "Bool", default => 1);

sub BUILD
{
    my $self = shift;
    $self->kwargs({});
    if($self->momentum)
    {
        $self->kwargs->{momentum} = $self->momentum;
    }
    if($self->clip_gradient)
    {
        $self->kwargs->{clip_gradient} = $self->clip_gradient;
    }
}

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    my $momentum;
    my $weight_master_copy;
    my $stype = $self->lazy_update ? $weight->stype : 'default';
    if($self->multi_precision and $weight->dtype eq 'float16')
    {
        my $weight_master_copy = AI::MXNet::NDArray->array($weight, ctx => $weight->context, dtype => 'float32');
        if($self->momentum != 0)
        {
            $momentum = AI::MXNet::NDArray->zeros($weight->shape, stype => $stype, ctx => $weight->context, dtype => 'float32');
        }
        return [$momentum, $weight_master_copy];
    }
    if($weight->dtype eq 'float16' and not $self->multi_precision)
    {
        AI::MXNet::Logging->warning(
            "Accumulating with float16 in optimizer can lead to ".
            "poor accuracy or slow convergence. ".
            "Consider using multi_precision=True option of the ".
            "SGD optimizer"
        );
    }
    if($self->momentum != 0)
    {
        $momentum = AI::MXNet::NDArray->zeros($weight->shape, stype => $stype, ctx => $weight->context, dtype => $weight->dtype);
    }
    return $momentum;
}

method update(
    Index                     $index,
    AI::MXNet::NDArray        $weight,
    AI::MXNet::NDArray        $grad,
    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
)
{
    $self->_update_count($index);
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    my $kwargs = {
        out => $weight,
        lr  => $lr,
        wd  => $wd,
        rescale_grad => $self->rescale_grad,
        %{ $self->kwargs }
    };
    my $use_multi_precision = ref($state) eq 'ARRAY';
    if(not $use_multi_precision)
    {
        if(defined $state)
        {
            AI::MXNet::NDArray->sgd_mom_update(
                $weight, $grad, $state, $kwargs
            );
        }
        else
        {
            AI::MXNet::NDArray->sgd_update(
                $weight, $grad, $kwargs
            );
        }
    }
    else
    {
        if(defined $state->[0])
        {
            AI::MXNet::NDArray->mp_sgd_mom_update(
                $weight, $grad, $state->[0], $state->[1], $kwargs
            );
        }
        else
        {
            AI::MXNet::NDArray->mp_sgd_update(
                $weight, $grad, $state->[1], $kwargs
            );
        }
    }
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::Signum - The Signum optimizer that takes the sign of gradient or momentum.
=cut

=head1 DESCRIPTION

    The optimizer updates the weight by:

        rescaled_grad = rescale_grad * clip(grad, clip_gradient) + wd * weight
        state = momentum * state + (1-momentum)*rescaled_grad
        weight = (1 - lr * wd_lh) * weight - lr * sign(state)

    See the original paper at: https://jeremybernste.in/projects/amazon/signum.pdf

    This optimizer accepts the following parameters in addition to those accepted
    by AI::MXNet::Optimizer

    Parameters
    ----------
    momentum : Num, optional
       The momentum value.
    wd_lh : Num, optional
       The amount of decoupled weight decay regularization, see details in the original paper at:
       https://arxiv.org/abs/1711.05101
=cut

package AI::MXNet::Signum;
use Mouse;
extends 'AI::MXNet::Optimizer';

has 'momentum' => (is => "rw", isa => "Num", default => 0.9);
has 'wd_lh'    => (is => "rw", isa => "Num", default => 0);

method create_state(Index $index, AI::MXNet::NDArray $weight)
{

    my $momentum;
    if($self->momentum != 0)
    {
        $momentum = AI::MXNet::NDArray->zeros(
            $weight->shape,
            ctx => $weight->context,
            dtype=>$weight->dtype,
            stype=>$weight->stype
        );
    }
    return $momentum;
}

method update(
    Index                     $index,
    AI::MXNet::NDArray        $weight,
    AI::MXNet::NDArray        $grad,
    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
)
{
    $self->_update_count($index);
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    my %kwargs = (
        out => $weight,
        lr  => $lr,
        wd  => $wd,
        rescale_grad => $self->rescale_grad,
    );
    if($self->momentum > 0)
    {
        $kwargs{momentum} = $self->momentum;
    }
    if($self->clip_gradient)
    {
        $kwargs{clip_gradient} = $self->clip_gradient;
    }
    if($self->wd_lh)
    {
        $kwargs{wd_lh} = $self->wd_lh;
    }
    if(defined $state)
    {
        AI::MXNet::NDArray->signum_update(
            $weight, $grad, $state, %kwargs
        );
    }
    else
    {
        AI::MXNet::NDArray->signsgd_update(
            $weight, $grad, %kwargs
        );
    }
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::FTML - The FTML optimizer.
=cut

=head1 DESCRIPTION

    This class implements the optimizer described in
    *FTML - Follow the Moving Leader in Deep Learning*,
    available at http://proceedings.mlr.press/v70/zheng17a/zheng17a.pdf.

    This optimizer accepts the following parameters in addition to those accepted
    by AI::MXNet::Optimizer

    Parameters
    ----------
    beta1 : Num, optional
        0 < beta1 < 1. Generally close to 0.5.
    beta2 : Num, optional
        0 < beta2 < 1. Generally close to 1.
    epsilon : Num, optional
        Small value to avoid division by 0.
=cut

package AI::MXNet::FTML;
use Mouse;
extends 'AI::MXNet::Optimizer';

has 'beta1'   => (is => "rw", isa => "Num", default => 0.6);
has 'beta2'   => (is => "rw", isa => "Num", default => 0.999);
has 'epsilon' => (is => "rw", isa => "Num", default => 1e-8);

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return [
        AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype), # d_0
        AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype), # v_0
        AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype), # z_0
    ];
}

method update(
    Index                     $index,
    AI::MXNet::NDArray        $weight,
    AI::MXNet::NDArray        $grad,
    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    my $t = $self->_update_count($index);
    my %kwargs = (
        out => $weight,
        lr  => $lr,
        wd  => $wd,
        t   => $t,
        beta1 => $self->beta1,
        beta2 => $self->beta2,
        epsilon => $self->epsilon,
        rescale_grad => $self->rescale_grad
    );
    if($self->clip_gradient)
    {
        $kwargs{clip_grad} = $self->clip_gradient;
    }
    AI::MXNet::NDArray->ftml_update($weight, $grad, @{ $state }, \%kwargs);
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::LBSGD - The Large Batch SGD optimizer with momentum and weight decay.
=cut

=head1 DESCRIPTION

    The optimizer updates the weight by::

        state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
        weight = weight - state

    Parameters
    ----------
    momentum : Num, optional
       The momentum value.
    multi_precision: Bool, optional
       Flag to control the internal precision of the optimizer.
       0 results in using the same precision as the weights (default),
       1 makes internal 32-bit copy of the weights and applies gradients
                in 32-bit precision even if actual weights used in the model have lower precision.`<
                Turning this on can improve convergence and accuracy when training with float16.
    warmup_strategy: string ('linear', 'power2', 'sqrt'. , 'lars'   default : 'linear')
    warmup_epochs: unsigned, default: 5
    batch_scale:   unsigned, default: 1 (same as batch size*numworkers)
    updates_per_epoch: updates_per_epoch (default: 32, Default might not reflect true number batches per epoch. Used for warmup.)
    begin_epoch: unsigned, default 0, starting epoch.
=cut

package AI::MXNet::LBSGD;
use Mouse;
extends 'AI::MXNet::Optimizer';

has 'momentum'          => (is => 'rw', isa => 'Num', default => 0);
has 'multi_precision'   => (is => 'rw', isa => 'Bool', default => 0);
has 'warmup_startegy'   => (is => 'rw', isa => 'Str', default => 'linear');
has 'warmup_epochs'     => (is => 'rw', isa => 'Int', default => 5);
has 'batch_scale'       => (is => 'rw', isa => 'Num', default => 1);
has 'updates_per_epoch' => (is => 'rw', isa => 'Int', default => 32);
has 'begin_epoch'       => (is => 'rw', isa => 'Int', default => 0);
has 'num_epochs'        => (is => 'rw', isa => 'Int', default => 60);
has 'beta2'             => (is => 'rw', isa => 'Num', default => 0.999);
has 'epsilon'           => (is => 'rw', isa => 'Num', default => 1e-8);
has 'init_updates'      => (is => 'rw', init_arg => undef);
has [qw/lbmult
        cumgrads
        adaptive
        init_updates
        admult/]        => (is => 'rw', init_arg => undef);

sub BUILD
{
    my $self = shift;
    AI::MXNet::Logging->info('Running Large-Batch SGD Algorithm');
    AI::MXNet::Logging->info(
        '(Batch_scale=%f, warmup_epochs=%d, warmup_strategy=%s, updates_per_epoch=%d)',
        map { $self->$_ } qw/batch_scale warmup_epochs warmup_strategy updates_per_epoch/
    );
    $self->init_updates($self->begin_epoch * $self->updates_per_epoch);
    $self->lbmult(1);
    $self->cumgrads({});
    $self->adaptive(0);
    $self->admult(1);
}

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return [
        AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype), # d_0
        AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype), # v_0
        AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype), # z_0
    ];
    my $momentum;
    my $weight_master_copy;
    if($self->multi_precision and $weight->dtype eq 'float16')
    {
        $weight_master_copy = AI::MXNet::NDArray->array($weight, ctx=>$weight->context, dtype=>'float32');
        if($self->momentum != 0)
        {
            $momentum = AI::MXNet::NDArray->zeros(
                $weight->shape, ctx => $weight->context, dtype => 'float32',
                stype => $weight->stype
            );
        }
        return [$momentum, $weight_master_copy];
    }
    if($weight->dtype eq 'float16' and not $self->multi_precision)
    {
        AI::MXNet::Logging->warning(
            "Accumulating with float16 in optimizer can lead to "
            ."poor accuracy or slow convergence. "
            ."Consider using multi_precision=True option of the "
            ."LBSGD optimizer"
        );
    }
    if($self->momentum != 0)
    {
        $momentum = AI::MXNet::NDArray->zeros(
            $weight->shape, ctx => $weight->context, dtype => $weight->dtype,
            stype => $weight->stype
        );
    }
    return $momentum;
}

method _get_lbmult($nup)
{
    my $nwup = $self->warmup_epochs * $self->updates_per_epoch;
    my $strategy = $self->warmup_strategy;
    my $maxmult = $self->batch_scale;
    my $mult;
    if($nup >= $nwup)
    {
        $mult = $maxmult;
    }
    elsif($nwup <= 1)
    {
        $mult = 1;
    }
    else
    {
        if ($strategy eq 'linear')
        {
            $mult = 1 + ($maxmult - 1) * $nup / $nwup;
        }
        elsif($strategy eq 'power2')
        {
            $mult = 1 + ($maxmult-1) * ($nup*$nup)/($nwup*$nwup);
        }
        elsif($strategy eq 'sqrt')
        {
            $mult = 1 + ($maxmult - 1) * sqrt($nup / $nwup);
        }
        else
        {
            $mult = 1;
        }
    }
    return $mult;
}


method _get_lars($weight, $g, $wd)
{
    my $weight2 = $self->_l2norm($weight);
    my $grad2 = $self->_l2norm($g);
    my $lars = sqrt($weight2 / ($grad2 + $wd * $weight2 + 1e-18));
    if($lars < 0.01)
    {
        $lars = 0.01;
    }
    elsif($lars > 100)
    {
        $lars = 100;
    }
    return $lars;
}

method _l2norm($v)
{
    my $norm = AI::MXNet::NDArray->multiply($v, $v)->aspdl->sum;
    return $norm;
}

method  _reset_cum_gradient($index)
{
    $self->cumgrads->{$index}{cum_grad} = 0;
}

method _get_cum_gradient($index)
{
    if(exists $self->cumgrads->{$index})
    {
        return $self->cumgrads->{$index};
    }
    else
    {
        return {}
    }
}

method _put_cum_gradient($index, $cgrad)
{
    $self->cumgrads->{$index} = $cgrad;
}

method _cumulate_gradient($grad, $index)
{
    my $cgrad = $self->_get_cum_gradient($index);
    my ($num_cums, $cum_grad);
    if(%{ $cgrad })
    {
        my $num_cums = $cgrad->{num_cums};
        if($num_cums > 0)
        {
            $cum_grad = $cgrad->{cum_grad} + $grad;
            $num_cums += 1;
        }
        else
        {
            $cum_grad = $grad;
            $num_cums = $self->init_updates + 1;
        }
    }
    else
    {
        $cum_grad = $grad;
        $num_cums = $self->init_updates + 1;
    }
    $cgrad = {cum_grad => $cum_grad, num_cums => $num_cums};
    $self->_put_cum_gradient($index, $cgrad);
    return $cgrad;
}



method update(
    Index                     $index,
    AI::MXNet::NDArray        $weight,
    AI::MXNet::NDArray        $grad,
    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    my $t = $self->_update_count($index);
    my $cgrad = $self->_cumulate_gradient($grad, $index);
    if(($cgrad->{num_cums} % $self->batch_scale) == 0)
    {
        my $lbmult;
        $grad = $cgrad->{cum_grad} / $self->batch_scale;
        if($self->warmup_strategy eq 'lars')
        {
            $lbmult = $self->_get_lars($weight, $grad, $wd);
        }
        else
        {
            $lbmult = $self->_get_lbmult($cgrad->{num_cums});
        }
        $lr = $lr * $lbmult;
        my %kwargs = (
            out => $weight,
            lr  => $lr,
            wd  => $wd,
            rescale_grad => $self->rescale_grad
        );
        if($self->clip_gradient)
        {
            $kwargs{clip_gradient} = $self->clip_gradient;
        }
        if($self->momentum > 0)
        {
            $kwargs{momentum} = $self->momentum;
        }
        my $use_multi_precision = ref($state) eq 'ARRAY';
        if(not $use_multi_precision)
        {
            if(defined $state)
            {
                AI::MXNet::NDArray->sgd_mom_update($weight, $grad, $state, %kwargs);
            }
            else
            {
                AI::MXNet::NDArray->sgd_update($weight, $grad, %kwargs);
            }
        }
        else
        {
            if(defined $state->[0])
            {
                AI::MXNet::NDArray->mp_sgd_mom_update($weight, $grad, @{ $state }, %kwargs);
            }
            else
            {
                AI::MXNet::NDArray->mp_sgd_update($weight, $grad, $state->[1], %kwargs);
            }
        }
        $self->_reset_cum_gradient($index);
    }
    else
    {
        AI::MXNet::NDArray->sgd_update($weight, $grad, out => $weight, lr => 0, wd => $wd);
    }
}

__PACKAGE__->register;

package AI::MXNet::DCASGD;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::Optimizer';

=head1 NAME

    AI::MXNet::DCASGD - DCASGD optimizer with momentum and weight regularization.
=cut

=head1 DESCRIPTION

    DCASGD optimizer with momentum and weight regularization.

    Implements paper "Asynchronous Stochastic Gradient Descent with
                    Delay Compensation for Distributed Deep Learning"

    Parameters
    ----------
    learning_rate : Num, optional
        learning_rate of SGD

    momentum : Num, optional
       momentum value

    lamda : NUm, optional
       scale DC value

    wd : Num, optional
        L2 regularization coefficient add to all the weights

    rescale_grad : Num, optional
        rescaling factor of gradient. Normally should be 1/batch_size.

    clip_gradient : Num, optional
        clip gradient in range [-clip_gradient, clip_gradient]

    param_idx2name : hash ref of Str/Int to Num, optional
        special threating of weight decay for parameters that end with bias, gamma, and beta
=cut
has 'momentum'        => (is => 'ro', isa => 'Num', default => 0);
has 'lamda'           => (is => 'ro', isa => 'Num', default => 0.04);
has 'weight_previous' => (is => 'rw', init_arg => undef);

sub BUILD
{
    my $self = shift;
    $self->weight_previous({});
}

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
        return [
            $self->momentum ? AI::MXNet::NDArray->zeros(
                $weight->shape, ctx => $weight->context, dtype => $weight->dtype
            ) : undef,
            $weight->copy
        ];
}

method update(
    Index                     $index,
    AI::MXNet::NDArray        $weight,
    AI::MXNet::NDArray        $grad,
    Maybe[AI::MXNet::NDArray|ArrayRef[Maybe[AI::MXNet::NDArray]]] $state
)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
    $grad *= $self->rescale_grad;
    if($self->clip_gradient)
    {
        $grad = AI::MXNet::NDArray->clip(
            $grad,
            -$self->clip_gradient,
            $self->clip_gradient
        );
    }
    my ($mom, $weight_previous) = @{ $state };
    if(defined $mom)
    {
        $mom *= $self->momentum;
        $mom += -$lr * (
                $grad + $wd * $weight
                    +
                $self->lamda * $grad * $grad * ($weight - $weight_previous)
        );
    }
    else
    {
        assert($self->momentum == 0);
        $mom = -$lr * (
                $grad + $wd * $weight
                    +
                $self->lamda * $grad * $grad * ($weight - $weight_previous)
        );
    }
    $weight_previous .= $weight;
    $weight += $mom;
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::NAG - SGD with Nesterov weight handling.
=cut

=head1 DESCRIPTION

    It is implemented according to
    https://github.com/torch/optim/blob/master/sgd.lua
=cut

package AI::MXNet::NAG;
use Mouse;
extends 'AI::MXNet::SGD';

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    my $momentum;
    my $weight_master_copy;
    my $do_multi_precision = ($self->multi_precision and $weight->dtype eq 'float16');
    if($do_multi_precision)
    {
        if($self->momentum != 0)
        {
            $momentum = AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>'float32');
        }
        $weight_master_copy = AI::MXNet::NDArray->array($weight, ctx=>$weight->context, dtype=>'float32');
        return [$weight_master_copy, $momentum];
    }
    else
    {
        if($self->momentum != 0)
        {
            $momentum = AI::MXNet::NDArray->zeros($weight->shape, ctx => $weight->context, dtype=>$weight->dtype);
        }
        return $momentum;
    }
}

method update($index, $weight, $grad, $state)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
    my $use_multi_precision = (defined $state and not Scalar::Util::blessed($state) and ref($state eq 'ARRAY'));
    if(not $use_multi_precision)
    {
        $grad *= $self->rescale_grad;
        if(defined $self->clip_gradient)
        {
            $grad = AI::MXNet::NDArray->clip($grad, -$self->clip_gradient, $self->clip_gradient);
        }
        if($self->momentum == 0)
        {
            $weight += -$lr * ($grad + $wd * $weight);
        }
        else
        {
            my $mom = $state;
            $mom *= $self->momentum;
            $grad += $wd * $weight;
            $mom += $grad;
            $grad += $self->momentum * $mom;
            $weight += -$lr * $grad;
        }
    }
    else
    {
        my $grad32 = AI::MXNet::NDArray->array($grad, ctx=>$grad->context, dtype=>'float32');
        $grad32 *= $self->rescale_grad;
        if(defined $self->clip_gradient)
        {
            $grad32 = AI::MXNet::NDArray->clip($grad32, -$self->clip_gradient, $self->clip_gradient);
        }
        my $mom = $state->[1];
        my $weight32 = $state->[0];
        if($self->momentum == 0)
        {
            $weight32 += -$lr * ($grad32 + $wd * $weight32);
        }
        else
        {
            $mom *= $self->momentum;
            $grad32 += $wd * $weight32;
            $mom += $grad32;
            $grad32 += $self->momentum * $mom;
            $weight32 += -$lr * $grad32;
        }
        my $tmp = $weight32->astype($weight->dtype);
        $tmp->copyto($weight);
    }
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::SGLD - Stochastic Gradient Riemannian Langevin Dynamics.
=cut

=head1 DESCRIPTION

    Stochastic Gradient Riemannian Langevin Dynamics.

    This class implements the optimizer described in the paper *Stochastic Gradient
    Riemannian Langevin Dynamics on the Probability Simplex*, available at
    https://papers.nips.cc/paper/4883-stochastic-gradient-riemannian-langevin-dynamics-on-the-probability-simplex.pdf.

    Parameters
    ----------
    learning_rate : Num, optional
        learning_rate of SGD

    wd : Num, optional
        L2 regularization coefficient add to all the weights

    rescale_grad : Num, optional
        rescaling factor of gradient. Normally should be 1/batch_size.

    clip_gradient : Num, optional
        clip gradient in range [-clip_gradient, clip_gradient]
=cut

package AI::MXNet::SGLD;
use Mouse;

extends 'AI::MXNet::Optimizer';

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return undef;
}

method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    AI::MXNet::NDArray|Undef $state
)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
    $grad *= $self->rescale_grad;
    if($self->clip_gradient)
    {
        $grad = AI::MXNet::NDArray->clip(
            $grad,
            -$self->clip_gradient,
             $self->clip_gradient
        );
    }
    $weight +=  - $lr/2 * ($grad + $wd * $weight)
                    +
                AI::MXNet::Random->normal(
                        0, sqrt($lr),
                        shape => $weight->shape,
                        ctx   => $weight->context,
                        dtype => $weight->dtype
                );
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::Adam - Adam optimizer as described in [King2014]_.
=cut

=head1 DESCRIPTION

    Adam optimizer as described in [King2014]_.

    .. [King2014] Diederik Kingma, Jimmy Ba,
       *Adam: A Method for Stochastic Optimization*,
       http://arxiv.org/abs/1412.6980

    Parameters
    ----------
    learning_rate : Num, optional
        Step size.
        Default value is set to 0.001.
    beta1 : Num, optional
        Exponential decay rate for the first moment estimates.
        Default value is set to 0.9.
    beta2 : Num, optional
        Exponential decay rate for the second moment estimates.
        Default value is set to 0.999.
    epsilon : Num, optional
        Default value is set to 1e-8.

    wd : NUm, optional
        L2 regularization coefficient add to all the weights
    rescale_grad : Num, optional
        rescaling factor of gradient. Normally should be 1/batch_size.

    clip_gradient : Num, optional
        clip gradient in range [-clip_gradient, clip_gradient]
=cut
package AI::MXNet::Adam;
use Mouse;

extends 'AI::MXNet::Optimizer';

has 'kwargs'   => (is => "rw", isa => "HashRef[Num]");
has '+learning_rate' => (default => 0.001);
has 'beta1'    => (is => "rw", isa => "Num", default => 0.9);
has 'beta2'    => (is => "rw", isa => "Num", default => 0.999);
has 'epsilon'  => (is => "rw", isa => "Num", default => 1e-8);
has 'lazy_update' => (is => 'rw', isa => 'Bool', default => 1);

sub BUILD
{
    my $self = shift;
    $self->kwargs({
        beta1   => $self->beta1,
        beta2   => $self->beta2,
        epsilon => $self->epsilon
    });
    if($self->clip_gradient)
    {
        $self->kwargs->{clip_gradient} = $self->clip_gradient;
    }
}

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    my $stype = $self->lazy_update ? $weight->stype : 'default';
    return [AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context,
                dtype => $weight->dtype,
                stype => $stype
            ),  # mean
            AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context,
                dtype => $weight->dtype,
                stype => $stype
            )  # variance
    ];
}

method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    ArrayRef[AI::MXNet::NDArray] $state
)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
    my $t = $self->_index_update_count->{$index};
    my $coef1 = 1 - $self->beta1**$t;
    my $coef2 = 1 - $self->beta2**$t;
    $lr *= sqrt($coef2)/$coef1;
    my ($mean, $var) = @{ $state };
    AI::MXNet::NDArray->adam_update(
        $weight, $grad, $mean, $var,
        {
            out => $weight,
            lr  => $lr,
            wd  => $wd,
            rescale_grad => $self->rescale_grad,
            %{ $self->kwargs }
        }
    );
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::AdaGrad - AdaGrad optimizer of Duchi et al., 2011
=cut

=head1 DESCRIPTION

    AdaGrad optimizer of Duchi et al., 2011,

    This code follows the version in http://arxiv.org/pdf/1212.5701v1.pdf  Eq(5)
    by Matthew D. Zeiler, 2012. AdaGrad will help the network to converge faster
    in some cases.

    Parameters
    ----------
    learning_rate : Num, optional
        Step size.
        Default value is set to 0.05.

    wd : Num, optional
        L2 regularization coefficient add to all the weights

    rescale_grad : Num, optional
        rescaling factor of gradient. Normally should be 1/batch_size.

    eps: Num, optional
        A small float number to make the updating processing stable
        Default value is set to 1e-7.

    clip_gradient : Num, optional
        clip gradient in range [-clip_gradient, clip_gradient]
=cut
package AI::MXNet::AdaGrad;
use Mouse;

extends 'AI::MXNet::Optimizer';

has 'eps'    => (is => "rw", isa => "Num", default => 1e-7);

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context,
                stype => $weight->stype
    );  # history
}

method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    AI::MXNet::NDArray $state
)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
    my $is_sparse = $grad->stype eq 'row_sparse' ? 1 : 0;
    my $history = $state;
    if($is_sparse)
    {
        my %kwargs = (
            epsilon => $self->eps,
            rescale_grad => $self->rescale_grad
        );
        if($self->clip_gradient)
        {
            $kwargs{clip_gradient} = $self->clip_gradient;
        }
        AI::MXNet::NDArray::Sparse->adagrad_update($weight, $grad, $history, { out=>$weight, lr=>$lr, wd=>$wd, %kwargs });
    }
    else
    {
        $grad *= $self->rescale_grad;
        if(defined $self->clip_gradient)
        {
            $grad = AI::MXNet::NDArray->clip($grad, -$self->clip_gradient, $self->clip_gradient);
        }
        $history += $grad->square;
        my $div = $grad / ($history + $self->eps)->sqrt;
        $weight += ($div + $weight * $wd) * -$lr;
    }
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::RMSProp - RMSProp optimizer of Tieleman & Hinton, 2012.
=cut

=head1 DESCRIPTION

    RMSProp optimizer of Tieleman & Hinton, 2012,

    For centered=False, the code follows the version in
    http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf by
    Tieleman & Hinton, 2012

    For centered=True, the code follows the version in
    http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) by Alex Graves, 2013.

    Parameters
    ----------
    learning_rate : Num, optional
        Step size.
        Default value is set to 0.001.
    gamma1: Num, optional
        decay factor of moving average for gradient^2.
        Default value is set to 0.9.
    gamma2: Num, optional
        "momentum" factor.
        Default value if set to 0.9.
        Only used if centered=True
    epsilon : Num, optional
        Default value is set to 1e-8.
    centered : Bool, optional
        Use Graves or Tielemans & Hintons version of RMSProp
    wd : Num, optional
        L2 regularization coefficient add to all the weights
    rescale_grad : Num, optional
        rescaling factor of gradient.
    clip_gradient : Num, optional
        clip gradient in range [-clip_gradient, clip_gradient]
    clip_weights : Num, optional
        clip weights in range [-clip_weights, clip_weights]
=cut

package AI::MXNet::RMSProp;
use Mouse;

extends 'AI::MXNet::Optimizer';

has '+learning_rate' => (default => 0.001);
has 'gamma1'         => (is => "ro", isa => "Num",  default => 0.9);
has 'gamma2'         => (is => "ro", isa => "Num",  default => 0.9);
has 'epsilon'        => (is => "ro", isa => "Num",  default => 1e-8);
has 'centered'       => (is => "ro", isa => "Bool", default => 0);
has 'clip_weights'   => (is => "ro", isa => "Num");
has 'kwargs'         => (is => "rw", init_arg => undef);

sub BUILD
{
    my $self = shift;
    $self->kwargs({
        gamma1       => $self->gamma1,
        epsilon      => $self->epsilon
    });
    if($self->centered)
    {
        $self->kwargs->{gamma2} = $self->gamma2;
    }
    if($self->clip_gradient)
    {
        $self->kwargs->{clip_gradient} = $self->clip_gradient;
    }
    if($self->clip_weights)
    {
        $self->kwargs->{clip_weights} = $self->clip_weights;
    }
}

# For centered=False: n
# For centered=True: n, g, delta
method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return [
            $self->centered
            ? (
                AI::MXNet::NDArray->zeros(
                    $weight->shape,
                    ctx => $weight->context,
                    stype => $weight->stype
                ),  # n
                AI::MXNet::NDArray->zeros(
                    $weight->shape,
                    ctx => $weight->context,
                    stype => $weight->stype
                ),  # g
                AI::MXNet::NDArray->zeros(
                    $weight->shape,
                    ctx => $weight->context,
                    stype => $weight->stype
                )
            )   # delta
            : (
                AI::MXNet::NDArray->zeros(
                    $weight->shape,
                    ctx => $weight->context,
                    stype => $weight->stype
                ),  # n
            )
    ];
}

method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    ArrayRef[AI::MXNet::NDArray] $state
)
{
    my $lr = $self->_get_lr($index);
    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
    my ($n, $g, $delta) = @{ $state };
    if($self->centered)
    {
        AI::MXNet::NDArray->rmspropalex_update(
            $weight, $grad, $n, $g, $delta,
            {
                out => $weight,
                lr  => $lr,
                wd  => $wd,
                rescale_grad => $self->rescale_grad,
                %{ $self->kwargs }
            }
        );
    }
    else
    {
        AI::MXNet::NDArray->rmsprop_update(
            $weight, $grad, $n,
            {
                out => $weight,
                lr  => $lr,
                wd  => $wd,
                rescale_grad => $self->rescale_grad,
                %{ $self->kwargs }
            }
        );
    }
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::AdaDelta - AdaDelta optimizer.
=cut

=head1 DESCRIPTION

    AdaDelta optimizer as described in
    Zeiler, M. D. (2012).
    *ADADELTA: An adaptive learning rate method.*

    http://arxiv.org/abs/1212.5701

    Parameters
    ----------
    rho: Num
        Decay rate for both squared gradients and delta x
    epsilon : Num
        The constant as described in the thesis
    wd : Num
        L2 regularization coefficient add to all the weights
    rescale_grad : Num, optional
        rescaling factor of gradient. Normally should be 1/batch_size.
    clip_gradient : Num, optional
        clip gradient in range [-clip_gradient, clip_gradient]
=cut
package AI::MXNet::AdaDelta;
use Mouse;

extends 'AI::MXNet::Optimizer';

has 'rho'    => (is => "rw", isa => "Num", default => 0.9);
has 'epsilon'    => (is => "rw", isa => "Num", default => 1e-5);

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return [
            AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context
            ),  # accumulated g
            AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context
            )   # accumulated delta
    ];
}

method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    ArrayRef[AI::MXNet::NDArray] $state
)
{
    my $wd = $self->_get_wd($index);
    $self->_update_count($index);
    $grad *= $self->rescale_grad;
    if($self->clip_gradient)
    {
        $grad = AI::MXNet::NDArray->clip(
            $grad,
            -$self->clip_gradient,
             $self->clip_gradient
        );
    }
    my ($acc_g, $acc_delta) = @{ $state };
    $acc_g .= $self->rho * $acc_g + (1 - $self->rho) * $grad * $grad;
    my $current_delta = ($acc_delta + $self->epsilon)->sqrt
                            /
                        ($acc_g + $self->epsilon)->sqrt
                            *
                        $grad;
    $acc_delta .= $self->rho * $acc_delta + (1 - $self->rho) * $current_delta * $current_delta;
    $weight -= $current_delta + $wd * $weight;
}

__PACKAGE__->register;

# For test use
package AI::MXNet::Test;
use Mouse;

extends 'AI::MXNet::Optimizer';

# Create a state to duplicate weight
method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context
    );
}

# performs w += rescale_grad * grad
method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    AI::MXNet::NDArray $state
)
{
    $weight += $grad * $self->rescale_grad;
    $state .= $weight;
}

__PACKAGE__->register;

package AI::MXNet::Ftrl;


=head1 NAME

    AI::MXNet::Ftrl
=cut

=head1 DESCRIPTION

    Referenced from *Ad Click Prediction: a View from the Trenches*, available at
    http://dl.acm.org/citation.cfm?id=2488200.

    The optimizer updates the weight by:

        rescaled_grad = clip(grad * rescale_grad, clip_gradient)
        z += rescaled_grad - (sqrt(n + rescaled_grad**2) - sqrt(n)) * weight / learning_rate
        n += rescaled_grad**2
        w = (sign(z) * lamda1 - z) / ((beta + sqrt(n)) / learning_rate + wd) * (abs(z) > lamda1)

    If the storage types of weight, state and grad are all row_sparse,
    **sparse updates** are applied by::

        for row in grad.indices:
            rescaled_grad[row] = clip(grad[row] * rescale_grad, clip_gradient)
            z[row] += rescaled_grad[row] - (sqrt(n[row] + rescaled_grad[row]**2) - sqrt(n[row])) * weight[row] / learning_rate
            n[row] += rescaled_grad[row]**2
            w[row] = (sign(z[row]) * lamda1 - z[row]) / ((beta + sqrt(n[row])) / learning_rate + wd) * (abs(z[row]) > lamda1)

    The sparse update only updates the z and n for the weights whose row_sparse
    gradient indices appear in the current batch, rather than updating it for all
    indices. Compared with the original update, it can provide large
    improvements in model training throughput for some applications. However, it
    provides slightly different semantics than the original update, and
    may lead to different empirical results.

    This optimizer accepts the following parameters in addition to those accepted
    by AI::MXNet::Optimizer

    Parameters
    ----------
    lamda1 : Num, optional
        L1 regularization coefficient.
    learning_rate : Num, optional
        The initial learning rate.
    beta : Num, optional
        Per-coordinate learning rate correlation parameter.
=cut

use Mouse;
extends 'AI::MXNet::Optimizer';
has '+learning_rate' => (default => 0.1);
has 'beta'           => (is => "ro", isa => "Num",  default => 1);
has 'lamda1'         => (is => "ro", isa => "Num",  default => 0.01);

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return [
            AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context,
                stype => $weight->stype
            ),  # z
            AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context,
                stype => $weight->stype
            )   # n
    ];
}

method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    ArrayRef[AI::MXNet::NDArray] $state
)
{
    $self->_update_count($index);
    my $wd = $self->_get_wd($index);
    my $lr = $self->_get_lr($index);
    my %kwargs = (lamda1 => $self->lamda1, beta => $self->beta, rescale_grad => $self->rescale_grad);
    if($self->clip_gradient)
    {
        $kwargs{clip_gradient} = $self->clip_gradient;
    }
    # accumulated g and delta initialization
    my ($z, $n) = @{ $state };
    AI::MXNet::NDArray->ftrl_update(
        $weight, $grad, $z, $n,
        { lr => $lr, wd => $wd, %kwargs, out => $weight }
    );
}

__PACKAGE__->register;

package AI::MXNet::Adamax;

=head1 NAME

    AI::MXNet::Adamax
=cut

=head1 DESCRIPTION

    It is a variant of Adam based on the infinity norm
    available at http://arxiv.org/abs/1412.6980 Section 7.

    This optimizer accepts the following parameters in addition to those accepted
    AI::MXNet::Optimizer.

    Parameters
    ----------
    beta1 : Num, optional
        Exponential decay rate for the first moment estimates.
    beta2 : Num, optional
        Exponential decay rate for the second moment estimates.
=cut

use Mouse;
extends 'AI::MXNet::Optimizer';
has '+learning_rate' => (default => 0.002);
has 'beta1'          => (is => "ro", isa => "Num",  default => 0.9);
has 'beta2'          => (is => "ro", isa => "Num",  default => 0.999);

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return [
            AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context,
                dtype => $weight->dtype
            ),  # mean
            AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context,
                dtype => $weight->dtype
            )   # variance
    ];
}

method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    ArrayRef[AI::MXNet::NDArray] $state
)
{
    my $wd = $self->_get_wd($index);
    my $lr = $self->_get_lr($index);
    $self->_update_count($index);
    my $t = $self->_index_update_count->{$index};
    $lr /= (1 - $self->beta1**$t);

    $grad = $grad * $self->rescale_grad + $wd * $weight;
    if($self->clip_gradient)
    {
        $grad = AI::MXNet::NDArray->clip(
            $grad,
            -$self->clip_gradient,
             $self->clip_gradient
        );
    }

    # update m_t and u_t
    my($m_t, $u_t) = @{ $state };
    $m_t .= $self->beta1 * $m_t + (1 - $self->beta1) * $grad;
    $u_t .= AI::MXNet::NDArray->maximum($self->beta2 * $u_t, $grad->abs);

    # update weight
    $weight -= $lr * $m_t / $u_t;
}

__PACKAGE__->register;

package AI::MXNet::Nadam;

=head1 NAME

    AI::MXNet::Nadam
=cut

=head1 DESCRIPTION

    The Nesterov Adam optimizer.

    Much like Adam is essentially RMSprop with momentum,
    Nadam is Adam RMSprop with Nesterov momentum available
    at http://cs229.stanford.edu/proj2015/054_report.pdf.

    This optimizer accepts the following parameters in addition to those accepted
    by AI::MXNet::Optimizer.

    Parameters
    ----------
    beta1 : Num, optional
        Exponential decay rate for the first moment estimates.
    beta2 : Num, optional
        Exponential decay rate for the second moment estimates.
    epsilon : Num, optional
        Small value to avoid division by 0.
    schedule_decay : Num, optional
        Exponential decay rate for the momentum schedule
=cut

use Mouse;
extends 'AI::MXNet::Optimizer';
has '+learning_rate' => (default => 0.001);
has 'beta1'          => (is => "ro", isa => "Num",  default => 0.9);
has 'beta2'          => (is => "ro", isa => "Num",  default => 0.999);
has 'epsilon'        => (is => "ro", isa => "Num",  default => 1e-8);
has 'schedule_decay' => (is => "ro", isa => "Num",  default => 0.004);
has 'm_schedule'     => (is => "rw", default => 1, init_arg => undef);

method create_state(Index $index, AI::MXNet::NDArray $weight)
{
    return [
            AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context,
                dtype => $weight->dtype
            ),  # mean
            AI::MXNet::NDArray->zeros(
                $weight->shape,
                ctx => $weight->context,
                dtype => $weight->dtype
            )   # variance
    ];
}

method update(
    Index $index,
    AI::MXNet::NDArray $weight,
    AI::MXNet::NDArray $grad,
    ArrayRef[AI::MXNet::NDArray] $state
)
{
    my $wd = $self->_get_wd($index);
    my $lr = $self->_get_lr($index);
    $self->_update_count($index);
    my $t = $self->_index_update_count->{$index};
    $grad = $grad * $self->rescale_grad + $wd * $weight;
    if($self->clip_gradient)
    {
        $grad = AI::MXNet::NDArray->clip(
            $grad,
            -$self->clip_gradient,
             $self->clip_gradient
        );
    }
    # warming momentum schedule
    my $momentum_t    = $self->beta1 * (1 - 0.5 * (0.96**($t * $self->schedule_decay)));
    my $momentum_t_1  = $self->beta1 * (1 - 0.5 * (0.96**(($t + 1) * $self->schedule_decay)));
    $self->m_schedule = $self->m_schedule * $momentum_t;
    my $m_schedule_next  = $self->m_schedule * $momentum_t_1;

    # update m_t and v_t
    my ($m_t, $v_t) = @{ $state };
    $m_t .= $self->beta1 * $m_t + (1 - $self->beta1) * $grad;
    $v_t .= $self->beta2 * $v_t + (1 - $self->beta2) * $grad * $grad;

    my $grad_prime = $grad / (1 - $self->m_schedule);
    my $m_t_prime  = $m_t  / (1 - $m_schedule_next);
    my $v_t_prime  = $v_t  / (1 - $self->beta2**$t);
    my $m_t_bar    = (1 - $momentum_t) * $grad_prime + $momentum_t_1 * $m_t_prime;

    # update weight
    $weight -= $lr * $m_t_bar / (sqrt($v_t_prime) + $self->epsilon);
}

__PACKAGE__->register;

=head1 NAME

    AI::MXNet::Updater - Updater for kvstore
=cut

package AI::MXNet::Updater;
use Mouse;
use Storable qw(thaw freeze);
use overload "&{}" => sub { my $self = shift; sub { $self->call(@_) } },
             fallback => 1;

has "optimizer"     => (is => "rw", isa => "AI::MXNet::Optimizer");
has "states"        => (is => "rw", isa => "HashRef", default => sub { +{} });
has "states_synced" => (is => "rw", isa => "HashRef", default => sub { +{} });

method call(Index $index, AI::MXNet::NDArray $grad, AI::MXNet::NDArray $weight)
{
    if(not exists $self->states->{ $index })
    {
        $self->states->{ $index } = $self->optimizer->create_state($index, $weight);
        $self->states_synced->{ $index } = 1;
    }
    elsif(not $self->states_synced->{ $index })
    {
        $self->states->{ $index } = $self->sync_state_context($self->states->{ $index }, $weight->context);
        $self->states_synced->{ $index } = 1;
    }
    $self->optimizer->update($index, $weight, $grad, $self->states->{ $index });
}
*slice = *call;

method sync_state_context(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $state, AI::MXNet::Context $context)
{
    if(blessed $state)
    {
        return $state->as_in_context($context);
    }
    elsif(ref $state)
    {
        return [map { $self->sync_state_context($_, $context) } @{ $state }];
    }
    return $state;
}

=head2 set_states

    Sets updater states.
=cut

method set_states($states)
{
    my $thawed_states = thaw($states);
    my ($optimizer);
    if(ref $thawed_states eq 'ARRAY')
    {
        ($thawed_states, $optimizer) = @{ $thawed_states };
        $self->optimizer($optimizer);
    }
    $self->states($thawed_states);
    %{ $self->states_synced } = map { $_ => 0 } keys %{ $thawed_states };
}

=head2 get_states

        Gets updater states.

        Parameters
        ----------
        dump_optimizer : bool, default False
            Whether to also save the optimizer itself. This would also save optimizer
            information such as learning rate and weight decay schedules.
=cut

method get_states(Bool $dump_optimizer=0)
{
    if($dump_optimizer)
    {
        my $param_dict = $self->optimizer->param_dict;
        $self->optimizer->param_dict({});
        my $freezed = freeze([$self->states, $self->optimizer]);
        $self->optimizer->param_dict($param_dict);
        return $freezed;
    }
    return freeze($self->states);
}

package AI::MXNet::Optimizer;

method get_updater(AI::MXNet::Optimizer $optimizer)
{
    return AI::MXNet::Updater->new(optimizer => $optimizer);
}

1;