Skip to content

Commit

Permalink
Merge pull request #9 from DaGaiBa/rc4main
Browse files Browse the repository at this point in the history
Fix precision error of FocalLoss op
  • Loading branch information
momo609 authored May 30, 2024
2 parents 2d6978e + ce612f7 commit eafcede
Showing 1 changed file with 66 additions and 12 deletions.
78 changes: 66 additions & 12 deletions mmcv/ops/csrc/pytorch/npu/focal_loss_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ using namespace std;

// Forward pass of sigmoid focal loss on NPU.
//
// Runs the "SigmoidFocalLoss" operator with reduction "none" and copies the
// result into `output`. When `input` is half precision, float copies of the
// input, weight and output buffer are handed to the operator and the result is
// cast back to half before the final copy.
//
// NOTE(review): this listing appears to be a rendered diff without +/- markers;
// several statements below show up twice (old form followed by new form) and
// part of the body is collapsed. Confirm against the actual file.
void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
// Start out aliasing the caller's tensors; replaced by float copies for half.
at::Tensor input_y = input;
at::Tensor output_y = output;
bool is_half = input.scalar_type() == at::kHalf;
if (is_half) {
input_y = input.to(at::kFloat);
output_y = output.to(at::kFloat);
}
// Dimension 1 of `input` is used as the class count.
int64_t n_class = input.size(1);
at::Tensor target_y = at::ones_like(input);
// NOTE(review): the body of this branch (and whatever follows it up to the
// closing brace below) is collapsed in this view, so how target_y is filled
// in is not visible here.
if (n_class == 1) {
Expand All @@ -15,21 +22,28 @@ void sigmoid_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
}
// The operator is given the target as an int tensor.
target_y = target_y.to(at::kInt);
// An empty weight tensor (size(0) == 0) leaves the all-ones default in place.
int64_t weight_size = weight.size(0);
// NOTE(review): the next two lines look like the pre-change and post-change
// forms of a single declaration; only the second matches the float32 path.
at::Tensor weight_y = at::ones_like(input);
at::Tensor weight_y = at::ones_like(input_y);
if (weight_size > 0) {
// Expand the weight to the full shape of `input`.
weight_y = at::broadcast_to(weight, input.sizes());
if (is_half) {
weight_y = weight_y.to(at::kFloat);
}
}
OpCommand cmd;
string reduction = "none";
cmd.Name("SigmoidFocalLoss")
// NOTE(review): `.Input(input)` / `.Input(input_y)` and `.Output(output)` /
// `.Output(output_y)` below also look like old/new pairs from the diff.
.Input(input)
.Input(input_y)
.Input(target_y)
.Input(weight_y)
.Output(output)
.Output(output_y)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", reduction)
.Run();
// Cast the float result back to half before writing it to the caller's tensor.
if (is_half) {
output_y = output_y.to(at::kHalf);
}
// For non-half input, output_y still refers to `output` itself at this point.
output.copy_(output_y);
}

void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
Expand All @@ -38,6 +52,13 @@ void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
// Backward pass of sigmoid focal loss on NPU.
//
// Runs the "SigmoidFocalLossGrad" operator with reduction "none" and copies
// the result into `grad_input`. When `input` is half precision, float copies
// of the input, weight and output buffer are handed to the operator and the
// result is cast back to half before the final copy.
//
// NOTE(review): this listing appears to be a rendered diff without +/- markers;
// several statements below show up twice (old form followed by new form) and
// part of the body is collapsed. Confirm against the actual file.
void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
Tensor grad_input, float gamma,
float alpha) {
// Start out aliasing the caller's tensors; replaced by float copies for half.
at::Tensor input_y = input;
at::Tensor grad_input_y = grad_input;
bool is_half = input.scalar_type() == at::kHalf;
if (is_half) {
input_y = input.to(at::kFloat);
grad_input_y = grad_input.to(at::kFloat);
}
// Dimension 1 of `input` is used as the class count.
int64_t n_class = input.size(1);
at::Tensor target_y = at::ones_like(input);
// NOTE(review): the body of this branch, including its closing brace, is
// collapsed in this view, so how target_y is filled in is not visible here.
if (n_class == 1) {
Expand All @@ -50,22 +71,29 @@ void sigmoid_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
// The operator is given the target as an int tensor.
target_y = target_y.to(at::kInt);
// Upstream gradient passed to the operator is all ones.
// NOTE(review): grad_up keeps the dtype of `input` (half on the half path)
// while input_y / weight_y are float there; confirm the operator accepts the
// mixed dtypes.
at::Tensor grad_up = at::ones_like(input);
// An empty weight tensor (size(0) == 0) leaves the all-ones default in place.
int64_t weight_size = weight.size(0);
// NOTE(review): the next two lines look like the pre-change and post-change
// forms of a single declaration; only the second matches the float32 path.
at::Tensor weight_y = at::ones_like(input);
at::Tensor weight_y = at::ones_like(input_y);
if (weight_size > 0) {
// Expand the weight to the full shape of `input`.
weight_y = at::broadcast_to(weight, input.sizes());
if (is_half) {
weight_y = weight_y.to(at::kFloat);
}
}
OpCommand cmd;
string reduction = "none";
cmd.Name("SigmoidFocalLossGrad")
// NOTE(review): `.Input(input)` / `.Input(input_y)` and `.Output(grad_input)` /
// `.Output(grad_input_y)` below also look like old/new pairs from the diff.
.Input(input)
.Input(input_y)
.Input(target_y)
.Input(grad_up)
.Input(weight_y)
.Output(grad_input)
.Output(grad_input_y)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", reduction)
.Run();
// Cast the float result back to half before writing it to the caller's tensor.
if (is_half) {
grad_input_y = grad_input_y.to(at::kHalf);
}
// For non-half input, grad_input_y still refers to `grad_input` itself here.
grad_input.copy_(grad_input_y);
}

void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
Expand All @@ -74,26 +102,38 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,

void softmax_focal_loss_forward_npu(Tensor input, Tensor target, Tensor weight,
Tensor output, float gamma, float alpha) {
at::Tensor input_y = input;
bool is_half = input.scalar_type() == at::kHalf;
if (is_half) {
input_y = input.to(at::kFloat);
}
int64_t n_class = input.size(1);
at::Tensor target_y = at::one_hot(target, n_class);
target_y = target_y.to(at::kInt);
int64_t weight_size = weight.size(0);
at::Tensor weight_y = at::ones_like(input);
at::Tensor weight_y = at::ones_like(input_y);
if (weight_size > 0) {
weight_y = at::broadcast_to(weight, input.sizes());
if (is_half) {
weight_y = weight_y.to(at::kFloat);
}
}
at::Tensor op_output = at::ones_like(input);

at::Tensor op_output = at::ones_like(input_y);
OpCommand cmd;
string reduction = "none";
cmd.Name("SoftmaxFocalLoss")
.Input(input)
.Input(input_y)
.Input(target_y)
.Input(weight_y)
.Output(op_output)
.Attr("gamma", gamma)
.Attr("alpha", alpha)
.Attr("reduction", reduction)
.Run();
if (is_half) {
op_output = op_output.to(at::kHalf);
}
int64_t n_batch = input.size(0);
c10::SmallVector<int64_t, 2> offsets = {0, 0};
c10::SmallVector<int64_t, 2> sizes = {n_batch, 1};
Expand Down Expand Up @@ -124,27 +164,41 @@ void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
// Backward pass of softmax focal loss on NPU.
//
// Runs the "SoftmaxFocalLossGrad" operator with reduction "none" and copies
// the result into `grad_input`. When `input` is half precision, the operator
// is given float copies of the input, weight and output buffer, and the result
// is cast back to half before being copied into `grad_input`.
//
//   input       input tensor; input.size(1) is used as the number of classes.
//   target      one-hot encoded to n_class entries and cast to int for the op.
//   weight      broadcast to input's shape; an empty tensor (size(0) == 0)
//               means "all ones".
//   buff        not read by this function.
//   grad_input  destination tensor, overwritten in place.
//   gamma,alpha forwarded unchanged as operator attributes.
void softmax_focal_loss_backward_npu(Tensor input, Tensor target, Tensor weight,
                                     Tensor buff, Tensor grad_input,
                                     float gamma, float alpha) {
  // Start out aliasing the caller's tensors; for half input both are replaced
  // by float copies so the operator works in float.
  at::Tensor input_y = input;
  at::Tensor grad_input_y = grad_input;
  const bool is_half = input.scalar_type() == at::kHalf;
  if (is_half) {
    input_y = input.to(at::kFloat);
    grad_input_y = grad_input.to(at::kFloat);
  }

  const int64_t n_class = input.size(1);
  at::Tensor target_y = at::one_hot(target, n_class);
  target_y = target_y.to(at::kInt);

  // Upstream gradient passed to the operator is all ones.
  // NOTE(review): grad_up keeps the dtype of `input` (half on the half path)
  // while the other floating inputs are float there; confirm the operator
  // accepts the mixed dtypes.
  at::Tensor grad_up = at::ones_like(input);

  // Default weight is all ones in the compute dtype; a non-empty `weight` is
  // expanded to input's shape and promoted to float on the half path.
  const int64_t weight_size = weight.size(0);
  at::Tensor weight_y = at::ones_like(input_y);
  if (weight_size > 0) {
    weight_y = at::broadcast_to(weight, input.sizes());
    if (is_half) {
      weight_y = weight_y.to(at::kFloat);
    }
  }

  // Each tensor is registered exactly once: four inputs, one output.
  OpCommand cmd;
  std::string reduction = "none";
  cmd.Name("SoftmaxFocalLossGrad")
      .Input(input_y)
      .Input(target_y)
      .Input(grad_up)
      .Input(weight_y)
      .Output(grad_input_y)
      .Attr("gamma", gamma)
      .Attr("alpha", alpha)
      .Attr("reduction", reduction)
      .Run();

  // Cast the float result back to half, then write it to the caller's tensor.
  // For non-half input grad_input_y still refers to `grad_input` itself.
  if (is_half) {
    grad_input_y = grad_input_y.to(at::kHalf);
  }
  grad_input.copy_(grad_input_y);
}

void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
Expand Down

0 comments on commit eafcede

Please sign in to comment.