residual_rms_rocm / torch-ext /torch_binding.h
medmekk's picture
medmekk HF Staff
first commit
7587d0b
raw
history blame
264 Bytes
#pragma once
#include <torch/torch.h>
/// @brief Declares the `residual_rms` entry point exposed by this extension.
///
/// Only the declaration is present in this header; the definition lives in
/// another translation unit. The function returns nothing, so any results
/// must be delivered through the tensor arguments.
/// NOTE(review): going by the names, this looks like a residual-add +
/// RMS-normalization op — confirm against the definition.
///
/// Every tensor is taken by non-const lvalue reference, so callers have to
/// pass named tensors (temporaries will not bind).
///
/// @param input         Tensor argument; presumably the activations to
///                      normalize — TODO confirm.
/// @param residual      Tensor argument; presumably the residual stream
///                      combined with `input` — TODO confirm.
/// @param weight        Tensor argument; presumably the normalization
///                      weight — TODO confirm.
/// @param scale_tensor  Tensor argument; role not visible from this header.
/// @param epsilon       Double-precision scalar; presumably the numerical
///                      stability term of the norm — TODO confirm.
/// @param output        Tensor argument; presumably written by the op —
///                      TODO confirm.
/// @param next_buffer   Tensor argument; role not visible from this header.
/// @param num_threads   64-bit signed integer; presumably a launch/threading
///                      parameter — verify against the definition.
/// @param force_scalar  Boolean flag; presumably selects a scalar code path —
///                      verify against the definition.
void residual_rms(
    torch::Tensor& input,
    torch::Tensor& residual,
    torch::Tensor& weight,
    torch::Tensor& scale_tensor,
    double epsilon,
    torch::Tensor& output,
    torch::Tensor& next_buffer,
    int64_t num_threads,
    bool force_scalar);