From f177d537b3239bc75d2817dfd8e4182cf382bb1a Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 19 Apr 2024 17:25:17 +0000
Subject: [PATCH] Make TPU v4 kernel drivers compatible with linux kernel 6.2.

PiperOrigin-RevId: 626403733
---
 tools/driver/drivers/accel/accel.c            |  4 +
 .../asic_sw/asic_fw_device_owner_accessor.h   | 10 +--
 .../asic_sw/asic_fw_indirect_accessor.h       | 81 ++++++++++---------
 tools/driver/drivers/gasket/gasket_core.c     | 17 ++--
 tools/driver/drivers/gasket/gasket_dmabuf.c   |  6 --
 .../driver/drivers/gasket/gasket_page_table.c | 16 ----
 6 files changed, 63 insertions(+), 71 deletions(-)

diff --git a/tools/driver/drivers/accel/accel.c b/tools/driver/drivers/accel/accel.c
index de86f9f2..6d521a49 100644
--- a/tools/driver/drivers/accel/accel.c
+++ b/tools/driver/drivers/accel/accel.c
@@ -299,7 +299,11 @@ static struct attribute *accel_dev_attrs[] = {
  NULL,
 };
 ATTRIBUTE_GROUPS(accel_dev);
+#if LINUX_VERSION_CODE < KERNEL_VERSION(6, 2, 0)
 static int accel_dev_uevent(struct device *dev, struct kobj_uevent_env *env)
+#else
+static int accel_dev_uevent(const struct device *dev, struct kobj_uevent_env *env)
+#endif
 {
  struct accel_dev *adev = to_accel_dev(dev);
  int retval = 0;
diff --git a/tools/driver/drivers/asic_sw/asic_fw_device_owner_accessor.h b/tools/driver/drivers/asic_sw/asic_fw_device_owner_accessor.h
index e0e2d0b3..c85d6f3f 100644
--- a/tools/driver/drivers/asic_sw/asic_fw_device_owner_accessor.h
+++ b/tools/driver/drivers/asic_sw/asic_fw_device_owner_accessor.h
@@ -5,14 +5,14 @@
 #ifndef _DRIVERS_ASIC_SW_ASIC_FW_DEVICE_OWNER_ACCESSOR_H_
 #define _DRIVERS_ASIC_SW_ASIC_FW_DEVICE_OWNER_ACCESSOR_H_ 
 #include "drivers/gasket/gasket_types.h"
-static inline uint64 asic_fw_device_owner_value(const uint64 reg_value)
+static inline uint64_t asic_fw_device_owner_value(const uint64_t reg_value)
 {
- return (uint64)((((reg_value >> 0) & 0xffffffffffffffffULL) << 0));
+ return (uint64_t)((((reg_value >> 0) & 0xffffffffffffffffULL) << 0));
 }
-static inline int set_asic_fw_device_owner_value(uint64 *reg_value,
-       uint64 value)
+static inline int set_asic_fw_device_owner_value(uint64_t *reg_value,
+       uint64_t value)
 {
- if (value & ~(0xffffffffffffffffULL))
+ if ((uint64_t)value < 0x0ULL || (uint64_t)value > 0xffffffffffffffffULL)
   return 1;
  (*reg_value) = ((*reg_value) & ~((0xffffffffffffffffULL) << 0)) |
          (((value >> 0) & (0xffffffffffffffffULL)) << 0);
diff --git a/tools/driver/drivers/asic_sw/asic_fw_indirect_accessor.h b/tools/driver/drivers/asic_sw/asic_fw_indirect_accessor.h
index 2c7730a7..43d8d3d7 100644
--- a/tools/driver/drivers/asic_sw/asic_fw_indirect_accessor.h
+++ b/tools/driver/drivers/asic_sw/asic_fw_indirect_accessor.h
@@ -12,57 +12,59 @@ enum asic_fw_indirect_accessor_status_status_value {
 };
 typedef enum asic_fw_indirect_accessor_status_status_value
  asic_fw_indirect_accessor_status_status_value;
-static inline uint64
-asic_fw_indirect_accessor_version_version(const uint64 reg_value)
+static inline uint64_t
+asic_fw_indirect_accessor_version_version(const uint64_t reg_value)
 {
- return (uint64)((((reg_value >> 0) & 0xffffffffffffffffULL) << 0));
+ return (uint64_t)((((reg_value >> 0) & 0xffffffffffffffffULL) << 0));
 }
 static inline int
-set_asic_fw_indirect_accessor_version_version(uint64 *reg_value, uint64 value)
+set_asic_fw_indirect_accessor_version_version(uint64_t *reg_value,
+           uint64_t value)
 {
- if (value & ~(0xffffffffffffffffULL))
+ if ((uint64_t)value < 0x0ULL || (uint64_t)value > 0xffffffffffffffffULL)
   return 1;
  (*reg_value) = ((*reg_value) & ~((0xffffffffffffffffULL) << 0)) |
          (((value >> 0) & (0xffffffffffffffffULL)) << 0);
  return 0;
 }
-static inline uint64
-asic_fw_indirect_accessor_address_address(const uint64 reg_value)
+static inline uint64_t
+asic_fw_indirect_accessor_address_address(const uint64_t reg_value)
 {
- return (uint64)((((reg_value >> 0) & 0xffffffffffffffffULL) << 0));
+ return (uint64_t)((((reg_value >> 0) & 0xffffffffffffffffULL) << 0));
 }
 static inline int
-set_asic_fw_indirect_accessor_address_address(uint64 *reg_value, uint64 value)
+set_asic_fw_indirect_accessor_address_address(uint64_t *reg_value,
+           uint64_t value)
 {
- if (value & ~(0xffffffffffffffffULL))
+ if ((uint64_t)value < 0x0ULL || (uint64_t)value > 0xffffffffffffffffULL)
   return 1;
  (*reg_value) = ((*reg_value) & ~((0xffffffffffffffffULL) << 0)) |
          (((value >> 0) & (0xffffffffffffffffULL)) << 0);
  return 0;
 }
-static inline uint8
-asic_fw_indirect_accessor_control_write(const uint64 reg_value)
+static inline uint8_t
+asic_fw_indirect_accessor_control_write(const uint64_t reg_value)
 {
- return (uint8)((((reg_value >> 0) & 0x1ULL) << 0));
+ return (uint8_t)((((reg_value >> 0) & 0x1ULL) << 0));
 }
-static inline int set_asic_fw_indirect_accessor_control_write(uint64 *reg_value,
-             uint8 value)
+static inline int
+set_asic_fw_indirect_accessor_control_write(uint64_t *reg_value, uint8_t value)
 {
- if (value & ~(0x1ULL))
+ if ((uint64_t)value < 0x0ULL || (uint64_t)value > 0x1ULL)
   return 1;
  (*reg_value) = ((*reg_value) & ~((0x1ULL) << 0)) |
          (((value >> 0) & (0x1ULL)) << 0);
  return 0;
 }
-static inline uint8
-asic_fw_indirect_accessor_control_read(const uint64 reg_value)
+static inline uint8_t
+asic_fw_indirect_accessor_control_read(const uint64_t reg_value)
 {
- return (uint8)((((reg_value >> 1) & 0x1ULL) << 0));
+ return (uint8_t)((((reg_value >> 1) & 0x1ULL) << 0));
 }
-static inline int set_asic_fw_indirect_accessor_control_read(uint64 *reg_value,
-            uint8 value)
+static inline int
+set_asic_fw_indirect_accessor_control_read(uint64_t *reg_value, uint8_t value)
 {
- if (value & ~(0x1ULL))
+ if ((uint64_t)value < 0x0ULL || (uint64_t)value > 0x1ULL)
   return 1;
  (*reg_value) = ((*reg_value) & ~((0x1ULL) << 1)) |
          (((value >> 0) & (0x1ULL)) << 1);
@@ -97,44 +99,45 @@ static inline const char *asic_fw_indirect_accessor_status_status_value_name(
  return "UNKNOWN VALUE";
 }
 static inline asic_fw_indirect_accessor_status_status_value
-asic_fw_indirect_accessor_status_status(const uint64 reg_value)
+asic_fw_indirect_accessor_status_status(const uint64_t reg_value)
 {
- return (asic_fw_indirect_accessor_status_status_value)(
-  (((reg_value >> 0) & 0xffULL) << 0));
+ return (asic_fw_indirect_accessor_status_status_value)((
+  ((reg_value >> 0) & 0xffULL) << 0));
 }
 static inline int set_asic_fw_indirect_accessor_status_status(
- uint64 *reg_value, asic_fw_indirect_accessor_status_status_value value)
+ uint64_t *reg_value,
+ asic_fw_indirect_accessor_status_status_value value)
 {
- if (value & ~(0xffULL))
+ if ((uint64_t)value < 0x0ULL || (uint64_t)value > 0xffULL)
   return 1;
  (*reg_value) = ((*reg_value) & ~((0xffULL) << 0)) |
          (((value >> 0) & (0xffULL)) << 0);
  return 0;
 }
-static inline uint8
-asic_fw_indirect_accessor_status_chip_specific_status(const uint64 reg_value)
+static inline uint8_t
+asic_fw_indirect_accessor_status_chip_specific_status(const uint64_t reg_value)
 {
- return (uint8)((((reg_value >> 8) & 0xffULL) << 0));
+ return (uint8_t)((((reg_value >> 8) & 0xffULL) << 0));
 }
 static inline int
-set_asic_fw_indirect_accessor_status_chip_specific_status(uint64 *reg_value,
-         uint8 value)
+set_asic_fw_indirect_accessor_status_chip_specific_status(uint64_t *reg_value,
+         uint8_t value)
 {
- if (value & ~(0xffULL))
+ if ((uint64_t)value < 0x0ULL || (uint64_t)value > 0xffULL)
   return 1;
  (*reg_value) = ((*reg_value) & ~((0xffULL) << 8)) |
          (((value >> 0) & (0xffULL)) << 8);
  return 0;
 }
-static inline uint64
-asic_fw_indirect_accessor_value_value(const uint64 reg_value)
+static inline uint64_t
+asic_fw_indirect_accessor_value_value(const uint64_t reg_value)
 {
- return (uint64)((((reg_value >> 0) & 0xffffffffffffffffULL) << 0));
+ return (uint64_t)((((reg_value >> 0) & 0xffffffffffffffffULL) << 0));
 }
-static inline int set_asic_fw_indirect_accessor_value_value(uint64 *reg_value,
-           uint64 value)
+static inline int set_asic_fw_indirect_accessor_value_value(uint64_t *reg_value,
+           uint64_t value)
 {
- if (value & ~(0xffffffffffffffffULL))
+ if ((uint64_t)value < 0x0ULL || (uint64_t)value > 0xffffffffffffffffULL)
   return 1;
  (*reg_value) = ((*reg_value) & ~((0xffffffffffffffffULL) << 0)) |
          (((value >> 0) & (0xffffffffffffffffULL)) << 0);
diff --git a/tools/driver/drivers/gasket/gasket_core.c b/tools/driver/drivers/gasket/gasket_core.c
index c4b36aa5..b471917e 100644
--- a/tools/driver/drivers/gasket/gasket_core.c
+++ b/tools/driver/drivers/gasket/gasket_core.c
@@ -552,8 +552,8 @@ static void gasket_setup_pci_iommu(struct gasket_dev *gasket_dev)
  } else if (driver_desc->iommu_mappings == GASKET_IOMMU_PREFER) {
 #if 0
 #else
-  gasket_log_warn(gasket_dev,
-   "IOMMU Mappings: Cannot enable");
+                gasket_log_warn(gasket_dev,
+                                "IOMMU Mappings: Cannot enable");
 #endif
  }
 }
@@ -1255,6 +1255,8 @@ int gasket_mm_unmap_region(
  int bar_index;
  ulong bar_offset;
  ulong virt_offset;
+ ulong address;
+ ulong size;
  struct gasket_mappable_region mappable_region;
  if (vma->vm_private_data != gasket_dev)
   return -EINVAL;
@@ -1279,9 +1281,14 @@ int gasket_mm_unmap_region(
  if (!gasket_mm_get_mapping_addrs(map_region, bar_offset,
   vma->vm_end - vma->vm_start, &mappable_region, &virt_offset))
   return 1;
- zap_vma_ptes(vma, vma->vm_start + virt_offset,
-  DIV_ROUND_UP(mappable_region.length_bytes, PAGE_SIZE) *
-   PAGE_SIZE); return 0;
+ address = vma->vm_start + virt_offset;
+ size = DIV_ROUND_UP(mappable_region.length_bytes, PAGE_SIZE) *
+   PAGE_SIZE;
+ if (address < vma->vm_start || address + size > vma->vm_end ||
+   !(vma->vm_flags & VM_PFNMAP))
+  return -1;
+ zap_vma_ptes(vma, address, size);
+ return 0;
 }
 EXPORT_SYMBOL(gasket_mm_unmap_region);
 static enum do_map_region_status do_map_region(
diff --git a/tools/driver/drivers/gasket/gasket_dmabuf.c b/tools/driver/drivers/gasket/gasket_dmabuf.c
index 6130d0b9..64858db2 100644
--- a/tools/driver/drivers/gasket/gasket_dmabuf.c
+++ b/tools/driver/drivers/gasket/gasket_dmabuf.c
@@ -133,7 +133,6 @@ static struct sg_table *gasket_dma_buf_ops_map(
  phys_addr = gasket_dbuf->mmap_offset -
       gasket_dev->driver_desc->bar_descriptions[bar_index].base +
       gasket_dev->bar_data[bar_index].phys_base;
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 1, 0)
  addr = dma_map_resource(attachment->dev, phys_addr, dbuf->size,
   direction, DMA_ATTR_SKIP_CPU_SYNC);
  ret = dma_mapping_error(attachment->dev, addr);
@@ -143,9 +142,6 @@ static struct sg_table *gasket_dma_buf_ops_map(
   sg_free_table(sgt);
   goto error_alloc_table;
  }
-#else
- addr = phys_addr;
-#endif
  sg_set_page(sgt->sgl, NULL, dbuf->size, 0);
  sg_dma_address(sgt->sgl) = addr;
  sg_dma_len(sgt->sgl) = dbuf->size;
@@ -157,7 +153,6 @@ static struct sg_table *gasket_dma_buf_ops_map(
 static void gasket_dma_buf_ops_unmap(struct dma_buf_attachment *attachment,
  struct sg_table *sgt, enum dma_data_direction direction)
 {
-#if LINUX_VERSION_CODE >= KERNEL_VERSION(5, 1, 0)
  struct scatterlist *sg;
  int i;
  for_each_sg((sgt)->sgl, sg, (sgt)->orig_nents, i)
@@ -165,7 +160,6 @@ static void gasket_dma_buf_ops_unmap(struct dma_buf_attachment *attachment,
   dma_unmap_resource(attachment->dev, sg_dma_address(sg),
    sg_dma_len(sg), direction, DMA_ATTR_SKIP_CPU_SYNC);
  }
-#endif
  sg_free_table(sgt);
  kfree(sgt);
 }
diff --git a/tools/driver/drivers/gasket/gasket_page_table.c b/tools/driver/drivers/gasket/gasket_page_table.c
index c83b0b3b..ed3f3da2 100644
--- a/tools/driver/drivers/gasket/gasket_page_table.c
+++ b/tools/driver/drivers/gasket/gasket_page_table.c
@@ -12,9 +12,7 @@
 #include <linux/pagemap.h>
 #include <linux/version.h>
 #include <linux/vmalloc.h>
-#if LINUX_VERSION_CODE > KERNEL_VERSION(5, 7, 19)
 #include <linux/dma-resv.h>
-#endif
 #if LINUX_VERSION_CODE > KERNEL_VERSION(5, 16, 0)
 MODULE_IMPORT_NS(DMA_BUF);
 #endif
@@ -915,7 +913,6 @@ static size_t gasket_sgt_get_contiguous_size(
  }
  return sz;
 }
-#if LINUX_VERSION_CODE > KERNEL_VERSION(5, 7, 19)
 static void gasket_page_table_dma_buf_move_notify(
  struct dma_buf_attachment *attachment)
 {
@@ -941,7 +938,6 @@ static const struct dma_buf_attach_ops gasket_dma_buf_attach_ops = {
  .allow_peer2peer = true,
  .move_notify = gasket_page_table_dma_buf_move_notify,
 };
-#endif
 static struct gasket_sgt_mapping *gasket_page_table_import_dma_buf(
  struct gasket_page_table *pg_tbl, int dma_buf_fd)
 {
@@ -971,12 +967,8 @@ static struct gasket_sgt_mapping *gasket_page_table_import_dma_buf(
   ret = -ENOMEM;
   goto failed_mapping_alloc;
  }
-#if LINUX_VERSION_CODE > KERNEL_VERSION(5, 7, 19)
  mapping->dbuf_attach = dma_buf_dynamic_attach(dbuf,
   &gasket_dev->pci_dev->dev, &gasket_dma_buf_attach_ops, pg_tbl);
-#else
- mapping->dbuf_attach = dma_buf_attach(dbuf, &gasket_dev->pci_dev->dev);
-#endif
  if (IS_ERR(mapping->dbuf_attach)) {
   ret = PTR_ERR(mapping->dbuf_attach);
   gasket_log_error(
@@ -984,14 +976,10 @@ static struct gasket_sgt_mapping *gasket_page_table_import_dma_buf(
   goto failed_attach;
  }
  mapping->size = dbuf->size;
-#if LINUX_VERSION_CODE > KERNEL_VERSION(5, 7, 19)
  dma_resv_lock(dbuf->resv, NULL);
-#endif
  mapping->sgt =
   dma_buf_map_attachment(mapping->dbuf_attach, DMA_BIDIRECTIONAL);
-#if LINUX_VERSION_CODE > KERNEL_VERSION(5, 7, 19)
  dma_resv_unlock(dbuf->resv);
-#endif
  if (IS_ERR(mapping->sgt)) {
   ret = PTR_ERR(mapping->sgt);
   gasket_log_error(gasket_dev,
@@ -1016,14 +1004,10 @@ static void gasket_page_table_detach_sgt_mapping(
   list_del_init(&mapping->entry);
  if (mapping->dbuf_attach) {
   dbuf = mapping->dbuf_attach->dmabuf;
-#if LINUX_VERSION_CODE > KERNEL_VERSION(5, 7, 19)
   dma_resv_lock(dbuf->resv, NULL);
-#endif
   dma_buf_unmap_attachment(
    mapping->dbuf_attach, mapping->sgt, DMA_BIDIRECTIONAL);
-#if LINUX_VERSION_CODE > KERNEL_VERSION(5, 7, 19)
   dma_resv_unlock(dbuf->resv);
-#endif
   dma_buf_detach(dbuf, mapping->dbuf_attach);
   dma_buf_put(dbuf);
  }