Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions EasyChemML/Utilities/DataUtilities/BatchDatatypHolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,15 @@ def checkAll_numbers(self) -> bool:
return False
return True
else:
return True
return BatchDatatypClass.get_dtype_lvl(self[self.getColumns()[0]]) > 0

def check_containsObjects(self) -> bool:
if len(self) > 1:
first = self[self.getColumns()[0]]

for item in self:
item: BatchDatatyp = self[item]
if BatchDatatypClass.get_dtype_lvl(item) < 0:
if BatchDatatypClass.get_dtype_lvl(item) < -1:
return True
return False
else:
Expand Down
16 changes: 7 additions & 9 deletions EasyChemML/Utilities/DataUtilities/BatchTable.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def __setitem__(self, key, value):
if not (isinstance(value, np.ndarray) or isinstance(value, np.void)):
raise Exception(f'the passed batch is not a numpy array or numpy array entry| {type(value)}')

if not self.getDatatypes().check_containsObjects() and not value.dtype == self[0].dtype:
my_datatypes: BatchDatatypHolder = self.getDatatypes()

if my_datatypes.check_containsObjects() or not (value.dtype == self[0].dtype or (BatchDatatypClass.NUMPY_STRING in my_datatypes and len(value.shape)==1)):
raise Exception(f'the dtype of the numpy array is different to the batchtable | {type(value)}')
value: np.array = value

my_datatypes: BatchDatatypHolder = self.getDatatypes()

if BatchDatatypClass.PYTHON_OBJECT in my_datatypes:
my_datatypes.removeAllnoneObject()
# notObjects = [item for item in self.getDatatypes().getColumns() if item not in my_datatypes.getColumns()]
Expand Down Expand Up @@ -270,7 +270,7 @@ def getWith(self, columns: List[str] = None):
else:
return 0

def convert_2_ndarray(self, indicies: List[int] = None, columns: List[str] = None) -> np.ndarray:
def convert_2_ndarray(self, indicies: List[int] = None, columns: List[str] = None):
if columns is None:
columns = self.getColumns()

Expand All @@ -292,7 +292,6 @@ def convert_2_ndarray(self, indicies: List[int] = None, columns: List[str] = Non
else:
raise Exception('different dTypes are not supported')

np_dtype = datatyp.toNUMPY()
# if indicies is not None:
# shape = (len(indicies), self.getWith(columns))
# else:
Expand All @@ -301,8 +300,7 @@ def convert_2_ndarray(self, indicies: List[int] = None, columns: List[str] = Non
if indicies is None:
if datatyp == BatchDatatypClass.NUMPY_STRING:
raw_data = self.getComplexSelection(columns, indicies)
data = rfn.structured_to_unstructured(raw_data, dtype=np_dtype)
return data
return [i[0].decode() for i in raw_data]
else:
raw_data = self.getComplexSelection(columns, indicies)
if len(columns) == 1:
Expand All @@ -313,8 +311,8 @@ def convert_2_ndarray(self, indicies: List[int] = None, columns: List[str] = Non
else:
if datatyp == BatchDatatypClass.NUMPY_STRING:
raw_data = self.getComplexSelection(columns, indicies)
data = rfn.structured_to_unstructured(raw_data, dtype=np_dtype)
return data
return [i[0].decode() for i in raw_data]

else:
raw_data: np.ndarray = self.getComplexSelection(columns, indicies)
if len(columns) == 1:
Expand Down
5 changes: 3 additions & 2 deletions EasyChemML/Utilities/DataUtilities/BatchTableAlgorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from EasyChemML.Environment import Environment
from EasyChemML.Utilities.DataUtilities.BatchTable import BatchTable
from EasyChemML.Utilities.DataUtilities.TableAlgorithms.Sort.MergeSort import MergeSort
# from EasyChemML.Utilities.DataUtilities.TableAlgorithms.Sort.MergeSort import MergeSort
from EasyChemML.Utilities.DataUtilities.TableAlgorithms.ReorderInplace import ReorderInplace


Expand All @@ -15,7 +15,8 @@ def __init__(self, env:Environment):
self.env = env

def mergeSort(self, batchTable: BatchTable):
MergeSort.sort(batchTable, env=self.env)
# MergeSort.sort(batchTable, env=self.env)
print("no merge sort")

def sort(self, batchTable: BatchTable, key_func: Callable[[Any, Any], int] = None):
length = len(batchTable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class RustBatchListFunctions_duplicates:

def __init__(self):
self._batchlist_duplicates = BatchListFunctions_duplicates_py()

# TODO: string support hinzufügen
def count_duplicates(self, rustbatchholder: RustBatchholder, tableName: str,
get_distibution_by_last_col: bool = False) -> RustBatchListFunctions_duplicates_result:
bt = rustbatchholder.getRustBatchTable(tableName)
Expand All @@ -76,7 +76,11 @@ def count_duplicates(self, rustbatchholder: RustBatchholder, tableName: str,
result.duplicates_by_last_col)
elif dtype == dtype.NUMPY_FLOAT32:
raise Exception('float32 is not sortable at the moment')
elif dtype == dtype.NUMPY_STRING:
result = self._batchlist_duplicates.count_duplicates_on_sorted_list_string(bt, get_distibution_by_last_col)
return result
elif dtype == dtype.NUMPY_FLOAT64:
raise Exception('float64 is not sortable at the moment')
result = self._batchlist_duplicates.count_duplicates_on_sorted_list_f64(bt, get_distibution_by_last_col)
return result
else:
raise Exception('datatype is not sortable at the moment')
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class RustBatchSorter_Radix:

def __init__(self, tmp_path: str):
self._sorter = BatchSorter_Radix_py(os.path.join(tmp_path, 'sort_tmp'))

# TODO: string support hinzufügen
def sort(self, rustbatchholder: RustBatchholder, tableName):
bt = rustbatchholder.getRustBatchTable(tableName)
dtype = rustbatchholder.rustBatchtable[tableName][1]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os.path
from typing import Dict, List, Tuple, Any
from enum import Enum

import numpy as np
from numpy.lib import recfunctions as rfn
from tqdm import tqdm
Expand Down Expand Up @@ -35,7 +34,7 @@ def clean(self):
self.rustBatchtable = {}

def transferToRust(self, bp: BatchPartition, table: str, columns: List[str] = None, memMode: MemoryMode = MemoryMode.InMemory):
bt = bp[table]
bt = bp[table] # bt = batchtable,bp = batchpartition

if columns is None:
columns = bt.getColumns()
Expand All @@ -58,8 +57,10 @@ def transferToRust(self, bp: BatchPartition, table: str, columns: List[str] = No
loaded_chunk = bt.convert_2_ndarray(indicies=list(range(i, len(bt))), columns=columns)
else:
loaded_chunk = bt.convert_2_ndarray(indicies=list(range(i, i + chunksize)), columns=columns)

new_rbt.add_chunk(loaded_chunk)
if dataTypHolder[dataTypHolder.getColumns()[0]] == BatchDatatypClass.NUMPY_STRING and len(columns)==1:
new_rbt.add_chunk_wrapper(loaded_chunk)
elif len(columns)==1:
new_rbt.add_chunk(np.expand_dims(loaded_chunk, axis=1))
bar.update(len(loaded_chunk))

def _create_typed_BatchTable(self, dtype: BatchDatatypHolder, tableName:str):
Expand All @@ -77,6 +78,8 @@ def _create_typed_BatchTable(self, dtype: BatchDatatypHolder, tableName:str):
return self.rustBatchholder.get_batchtable_f32(tableName)
elif dtype == dtype.NUMPY_FLOAT64:
return self.rustBatchholder.get_batchtable_f64(tableName)
elif dtype == dtype.NUMPY_STRING:
return self.rustBatchholder.get_batchtable_string(tableName)

def getRustBatchTable(self, tableName: str):
if tableName in self.rustBatchtable:
Expand All @@ -95,8 +98,12 @@ def transferToBatchtable(self, rustTableName: str, bp: BatchPartition, newBatchT
bt = bp[newBatchTableName]

for index, i in enumerate(range(0, len(bt), chunksize)):
loaded_chunk = rustbt.load_chunk(index)
loaded_chunk = rfn.unstructured_to_structured(loaded_chunk, dtype.toNUMPY_dtypes())
if dtype in BatchDatatypClass.NUMPY_STRING and len(shape) == 1:
loaded_chunk = rustbt.get_loaded_string_chunk(index)
loaded_chunk = np.array(loaded_chunk)
else:
loaded_chunk = rustbt.load_chunk(index)
loaded_chunk = rfn.unstructured_to_structured(loaded_chunk, dtype.toNUMPY_dtypes())
if i + chunksize > len(bt):
bt[i:len(bt)] = loaded_chunk
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,105 +1,115 @@
use crate::BatchSystem::BatchTablesImplementation::BatchTable::{
BatchTable, BatchTablesTypWrapper,
};
use crate::Utilities::array_helper;
use ndarray::{Array, Axis, Ix1, Ix2};
use ndarray_npy::ReadableElement;
use serde::de::DeserializeOwned;
use std::borrow::Borrow;
use std::collections::HashMap;
use std::fmt::Display;
use std::hash::Hash;
use std::sync::{Arc, RwLock};
use ndarray::{Array, Array2, ArrayView2, ArrayViewMut, Axis, Ix1, Ix2, IxDyn, s};
use ndarray_npy::{ReadableElement};
use serde::de::DeserializeOwned;
use serde::Deserialize;
use crate::BatchSystem::BatchTablesImplementation::BatchTable::{BatchTable, BatchTablesTypWrapper};
use crate::Utilities::array_helper;

pub struct duplicat_result<T: ReadableElement + Ord + Clone + Hash + Display+ Copy> {
pub counted_entries:usize,
pub counted_duplicates:usize,
pub entry_most_duplicates:usize,
pub struct duplicat_result<T: ReadableElement + Ord + Clone + Hash + Display + Copy> {
pub counted_entries: usize,
pub counted_duplicates: usize,
pub entry_most_duplicates: usize,
pub duplicates_dist: HashMap<usize, usize>,
pub duplicates_by_last_col: HashMap<T, usize>
pub duplicates_by_last_col: HashMap<T, usize>,
}

pub fn count_duplicates_on_sorted_list<'de,T: ReadableElement + Ord + Clone + Hash + Display+ Copy + DeserializeOwned>(batchtable: Arc<RwLock<BatchTablesTypWrapper>>, calc_duplicates_by_last_col:bool) -> duplicat_result<T> {
pub fn count_duplicates_on_sorted_list<
'de,
T: ReadableElement + Ord + Clone + Hash + Display + Copy + DeserializeOwned,
>(
batchtable: Arc<RwLock<BatchTablesTypWrapper>>,
calc_duplicates_by_last_col: bool,
) -> duplicat_result<T> {
let mut batchtable = batchtable.read().unwrap();
let chunk_count = batchtable.get_table_chunk_count();

let mut counted_entries:usize = 0;
let mut counted_duplicates:usize = 0;
let mut entry_most_duplicates:usize = 0;
let mut counted_entries: usize = 0;
let mut counted_duplicates: usize = 0;
let mut entry_most_duplicates: usize = 0;
let mut duplicates_dist: HashMap<usize, usize> = HashMap::new();
let mut duplicates_by_last_col: HashMap<T,usize> = HashMap::new();
let mut duplicates_by_last_col: HashMap<T, usize> = HashMap::new();

let mut current_entry_duplicated:usize = 0;
let mut current_entry_duplicated: usize = 0;

let mut last_entry:Option<Array<T,Ix1>> = None;
for chunk_index in 0..chunk_count{
let mut last_entry: Option<Array<T, Ix1>> = None;
for chunk_index in 0..chunk_count {
let loaded_chunk: Array<T, Ix2> = batchtable.load_chunk(chunk_index);


for entry in loaded_chunk.axis_iter(Axis(0)){
for entry in loaded_chunk.axis_iter(Axis(0)) {
counted_entries += 1;

if last_entry.is_none(){
if last_entry.is_none() {
last_entry = Option::from(entry.to_owned());
}else{
} else {
let last_clone = last_entry.clone().unwrap();

if array_helper::merge_compare_arrayView1(last_clone.view(), entry, calc_duplicates_by_last_col) == 0{
if current_entry_duplicated == 0{
if array_helper::merge_compare_arrayView1(
last_clone.view(),
entry,
calc_duplicates_by_last_col,
) == 0
{
if current_entry_duplicated == 0 {
counted_duplicates += 2;
current_entry_duplicated = 2;
}else {
} else {
counted_duplicates += 1;
current_entry_duplicated += 1;
}

if calc_duplicates_by_last_col{
if current_entry_duplicated == 2{

if calc_duplicates_by_last_col {
if current_entry_duplicated == 2 {
let val_last = last_clone.last().unwrap();
let val_current = entry.view();
let val_current = val_current.last().unwrap();

if duplicates_by_last_col.contains_key(val_last.borrow()){
let hash_pointer = duplicates_by_last_col.get_mut(val_last.borrow()).unwrap();
if duplicates_by_last_col.contains_key(val_last.borrow()) {
let hash_pointer =
duplicates_by_last_col.get_mut(val_last.borrow()).unwrap();
*hash_pointer += 1;
}else {
} else {
duplicates_by_last_col.insert(*val_last, 1);
}

if duplicates_by_last_col.contains_key(val_current.borrow()){
let hash_pointer = duplicates_by_last_col.get_mut(val_current.borrow()).unwrap();
if duplicates_by_last_col.contains_key(val_current.borrow()) {
let hash_pointer = duplicates_by_last_col
.get_mut(val_current.borrow())
.unwrap();
*hash_pointer += 1;
}else {
} else {
duplicates_by_last_col.insert(*val_current, 1);
}
}else if current_entry_duplicated > 2 {
} else if current_entry_duplicated > 2 {
let val_current = entry.view();
let val_current = val_current.last().unwrap();

if duplicates_by_last_col.contains_key(val_current){
let hash_pointer = duplicates_by_last_col.get_mut(val_current).unwrap();
if duplicates_by_last_col.contains_key(val_current) {
let hash_pointer =
duplicates_by_last_col.get_mut(val_current).unwrap();
*hash_pointer += 1;
}else {
} else {
duplicates_by_last_col.insert(*val_current, 1);
}
}
}

}else{

} else {
last_entry = Option::from(entry.to_owned());

if current_entry_duplicated >= 2{
if duplicates_dist.contains_key(&current_entry_duplicated){
if current_entry_duplicated >= 2 {
if duplicates_dist.contains_key(&current_entry_duplicated) {
*duplicates_dist.get_mut(&current_entry_duplicated).unwrap() += 1;
}else {
} else {
duplicates_dist.insert(current_entry_duplicated, 1);
}
}



if current_entry_duplicated > entry_most_duplicates{
if current_entry_duplicated > entry_most_duplicates {
entry_most_duplicates = current_entry_duplicated;
}

Expand All @@ -109,13 +119,19 @@ pub fn count_duplicates_on_sorted_list<'de,T: ReadableElement + Ord + Clone + Ha
}
}

if current_entry_duplicated >= 2{
if duplicates_dist.contains_key(&current_entry_duplicated){
if current_entry_duplicated >= 2 {
if duplicates_dist.contains_key(&current_entry_duplicated) {
*duplicates_dist.get_mut(&current_entry_duplicated).unwrap() += 1;
}else {
} else {
duplicates_dist.insert(current_entry_duplicated, 1);
}
}

duplicat_result{counted_entries, counted_duplicates, entry_most_duplicates, duplicates_dist, duplicates_by_last_col }
}
duplicat_result {
counted_entries,
counted_duplicates,
entry_most_duplicates,
duplicates_dist,
duplicates_by_last_col,
}
}
Loading