File size: 264 Bytes
7587d0b |
1 2 3 4 5 |
#pragma once
#include <torch/torch.h>
/// @brief Declaration of the `residual_rms` operation; the definition is not
///        part of this header and lives in another translation unit.
///
/// NOTE(review): judging only by the name and parameter list, this looks like a
/// fused residual-add + RMS-normalization routine. Tensor shapes, dtypes,
/// devices, and which arguments are read versus written cannot be determined
/// from this header -- confirm against the implementation before relying on
/// any of the per-parameter notes below.
///
/// All tensor arguments are taken by non-const lvalue reference, so callers
/// must pass named `torch::Tensor` objects; temporaries will not bind.
///
/// @param input         Tensor argument; presumably the values to normalize -- TODO confirm.
/// @param residual      Tensor argument; presumably the residual stream combined with `input` -- TODO confirm.
/// @param weight        Tensor argument; presumably the normalization gain -- TODO confirm.
/// @param scale_tensor  Tensor argument; presumably an additional scaling factor -- TODO confirm.
/// @param epsilon       Floating-point scalar; presumably the numerical-stability term of the RMS denominator -- TODO confirm.
/// @param output        Tensor argument; presumably receives the result -- TODO confirm.
/// @param next_buffer   Tensor argument; purpose not visible here -- verify against the implementation.
/// @param num_threads   64-bit integer; presumably the worker-thread count -- TODO confirm valid range.
/// @param force_scalar  Boolean flag; presumably selects a non-vectorized code path -- TODO confirm.
void residual_rms(
    torch::Tensor& input,
    torch::Tensor& residual,
    torch::Tensor& weight,
    torch::Tensor& scale_tensor,
    double epsilon,
    torch::Tensor& output,
    torch::Tensor& next_buffer,
    int64_t num_threads,
    bool force_scalar);