diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 746d6c1ec..09d73f5a6 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -357,9 +357,6 @@ def backward(ctx, grad_output):


 class MatMul4Bit(torch.autograd.Function):
-    # forward is the same, but we added the fallback for pre-turing GPUs
-    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
-
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # default of pytorch behavior if inputs are empty
@@ -377,7 +374,15 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]

         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        # Use linear function which correctly handles 1D and 2D inputs
+        result = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+
+        # If out is provided, resize it if necessary and copy the result
+        if out is not None:
+            if out.shape != result.shape:
+                out.resize_(result.shape)
+            out.copy_(result)
+            result = out

         # 3. Save state
         ctx.state = quant_state
@@ -388,7 +393,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
         else:
             ctx.tensors = (None, None)

-        return output
+        return result

     @staticmethod
     def backward(ctx, grad_output):
@@ -458,9 +463,14 @@ def matmul_4bit(
             )
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            # For 1D case, we'll use the MatMul4Bit implementation which correctly handles out parameter
+            if out is not None and A.dim() == 1:
+                return MatMul4Bit.apply(A, B, out, bias, quant_state)
+
+            # For other cases, use gemv_4bit
+            result = F.gemv_4bit(A, B.t(), out, state=quant_state)
             if bias is not None:
-                out += bias
-            return out
+                result += bias
+            return result
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index c266f61a0..a32a005e4 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -427,11 +427,18 @@ def _(
     blocksize: int,
     out: torch.Tensor,
 ) -> None:
+    expected_shape = (*A.shape[:-1], shapeB[0])
+
+    if len(A.shape) == 1 and len(out.shape) == 2 and out.shape[0] == 1:
+        out = out.view(shapeB[0])
+        expected_shape = (shapeB[0],)
+
     torch._check(
-        out.shape == (*A.shape[:-1], shapeB[0]),
-        lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
+        out.shape == expected_shape,
+        lambda: f"Expected out.shape == {expected_shape}, got {out.shape}",
     )
     torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
+
     _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)


@@ -446,32 +453,38 @@ def _gemv_4bit_impl(
 ) -> None:
     torch._check_is_size(blocksize)

-    # Note: these checks are not strictly necessary, and cost more than they are worth, so they are commented out for now.
-    # torch._check(
-    #     A.numel() == A.size(-1),
-    #     lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
-    # )
-    # torch._check(
-    #     A.dtype in [torch.float16, torch.bfloat16, torch.float32],
-    #     lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
-    # )
-    # torch._check(
-    #     B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
-    #     lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
-    # )
-    # torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
-    # torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
+    is_1d = A.dim() == 1
+    if is_1d:
+        A_reshaped = A.view(1, -1)
+    else:
+        A_reshaped = A
+
+    torch._check(
+        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+    )
+    torch._check(
+        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+    )
+    torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
+    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

     m = ct.c_int32(shapeB[0])
     n = ct.c_int32(1)
     k = ct.c_int32(shapeB[1])

     lda = m
-    ldb = ct.c_int32((A.shape[-1] + 1) // 2)
+    ldb = ct.c_int32((A_reshaped.shape[-1] + 1) // 2)
     ldc = m

     stream = _get_tensor_stream(A)

+    if is_1d and out.dim() > 1:
+        out_view = out.view(-1)
+    else:
+        out_view = out
+
     with _cuda_device_of(A):
         if A.dtype == torch.float16:
             lib.cgemm_4bit_inference_naive_fp16(