multithreaded chunk scoring & use max cpu w/ multiprocessing

pull/15/head
wlski 2024-06-23 22:12:57 -04:00
parent 4e863a3a9c
commit 50ee4f6826
1 changed file with 76 additions and 22 deletions


@@ -4,6 +4,7 @@
# (C) 2023 Thinkst Applied Research, PTY
# Author: Jacob Torrey <jacob@thinkst.com>
import time
import lzma, argparse, os, itertools
from zlib import compressobj, Z_FINISH
from brotli import compress as brotli_compress, MODE_TEXT
@@ -24,17 +25,20 @@ class CompressionEngine(Enum):
ZLIB = 2
BROTLI = 3
# Precompile regex patterns for reuse
WHITESPACE_PATTERN = re.compile(r'[ \t]+')
NEWLINE_PATTERN = re.compile(r'\n+')
LEADING_TRAILING_NEWLINE_PATTERN = re.compile(r'(\n )|( \n)')
NON_ALNUM_PATTERN = re.compile(r'[^0-9A-Za-z,\.\(\) \n]')
def clean_text(s : str) -> str:
'''
Removes formatting and other non-content data that may skew compression ratios (e.g., duplicate spaces)
'''
# Remove extra spaces and duplicate newlines.
s = re.sub(' +', ' ', s)
s = re.sub('\t', '', s)
s = re.sub('\n+', '\n', s)
s = re.sub('\n ', '\n', s)
s = re.sub(' \n', '\n', s)
# Remove non-alphanumeric chars
s = re.sub(r'[^0-9A-Za-z,\.\(\) \n]', '', s)#.lower()
def clean_text(s: str) -> str:
s = WHITESPACE_PATTERN.sub(' ', s)
s = NEWLINE_PATTERN.sub('\n', s)
s = LEADING_TRAILING_NEWLINE_PATTERN.sub('\n', s)
s = NON_ALNUM_PATTERN.sub('', s)
return s
# The prelude file is a text file containing only AI-generated text, it is used to 'seed' the LZMA dictionary
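For reference, the precompiled-pattern clean_text() introduced above can be exercised on its own roughly as below; the pattern names match the diff, while the wrapper script, sample input, and printed output are illustrative only.

import re

# Patterns as added in this commit
WHITESPACE_PATTERN = re.compile(r'[ \t]+')
NEWLINE_PATTERN = re.compile(r'\n+')
LEADING_TRAILING_NEWLINE_PATTERN = re.compile(r'(\n )|( \n)')
NON_ALNUM_PATTERN = re.compile(r'[^0-9A-Za-z,\.\(\) \n]')

def clean_text(s: str) -> str:
    # Collapse runs of spaces/tabs, collapse blank lines, trim spaces around
    # newlines, then drop characters outside the allowed set.
    s = WHITESPACE_PATTERN.sub(' ', s)
    s = NEWLINE_PATTERN.sub('\n', s)
    s = LEADING_TRAILING_NEWLINE_PATTERN.sub('\n', s)
    s = NON_ALNUM_PATTERN.sub('', s)
    return s

print(clean_text("Hello,\t  world.\n\n  New   line."))
# prints:
# Hello, world.
# New line.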
@@ -188,6 +192,9 @@ class LzmaLlmDetector(AIDetector):
determination = 'Human'
return (determination, abs(delta * 100))
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional, List, Tuple
class Zippy:
'''
@@ -241,9 +248,45 @@ class Zippy:
'''
with open(filename, 'r', encoding='utf-8') as fp:
contents = fp.read()
return self.run_on_text_chunked(contents, chunk_size, prelude_ratio=prelude_ratio)
return (filename, self.run_on_text_chunked(contents, chunk_size, prelude_ratio=prelude_ratio))
def run_on_text_chunked(self, s : str, chunk_size : int = 1500, prelude_file : Optional[str] = None, prelude_ratio : Optional[float] = None) -> Optional[Score]:
# def run_on_text_chunked(self, s : str, chunk_size : int = 1500, prelude_file : Optional[str] = None, prelude_ratio : Optional[float] = None) -> Optional[Score]:
# '''
# Given a string (and an optional chunk size and number of decimal places to round to) returns the score for the passed string.
# This function chunks the input into at most chunk_size parts to score separately, then returns an average. This prevents a very large input
# being skewed because its compression ratio starts to overwhelm the prelude file.
# '''
# contents = clean_text(s)
# start = 0
# end = 0
# chunks = []
# while start + chunk_size < len(contents) and end != -1:
# end = contents.rfind(' ', start, start + chunk_size + 1)
# chunks.append(contents[start:end])
# start = end + 1
# chunks.append(contents[start:])
# scores = []
# if len(chunks) > 2:
# with Pool(cpu_count()) as pool:
# for r in pool.starmap(self._score_chunk, zip(chunks, itertools.repeat(prelude_file), itertools.repeat(prelude_ratio))):
# scores.append(r)
# else:
# for c in chunks:
# scores.append(self._score_chunk(c, prelude_file=prelude_file, prelude_ratio=prelude_ratio))
# ssum : float = 0.0
# for i, s in enumerate(scores):
# if s[0] == 'AI':
# ssum -= s[1] * (len(chunks[i]) / len(contents))
# else:
# ssum += s[1] * (len(chunks[i]) / len(contents))
# sa : float = ssum
# if sa < 0:
# return ('AI', abs(sa))
# else:
# return ('Human', abs(sa))
def run_on_text_chunked(self, s: str, chunk_size: int = 1500, prelude_file: Optional[str] = None, prelude_ratio: Optional[float] = None):
'''
Given a string (and an optional chunk size and number of decimal places to round to) returns the score for the passed string.
This function splits the input into chunks of at most chunk_size characters to score separately, then returns a length-weighted average. This prevents a very large input
@@ -259,21 +302,26 @@ class Zippy:
chunks.append(contents[start:end])
start = end + 1
chunks.append(contents[start:])
scores = []
if len(chunks) > 2:
with Pool(cpu_count()) as pool:
for r in pool.starmap(self._score_chunk, zip(chunks, itertools.repeat(prelude_file), itertools.repeat(prelude_ratio))):
scores.append(r)
with ThreadPoolExecutor() as executor:
future_to_chunk = {executor.submit(self._score_chunk, chunk, prelude_file, prelude_ratio): chunk for chunk in chunks}
for future in as_completed(future_to_chunk):
result = future.result()
scores.append(result)
else:
for c in chunks:
scores.append(self._score_chunk(c, prelude_file=prelude_file, prelude_ratio=prelude_ratio))
ssum : float = 0.0
ssum: float = 0.0
for i, s in enumerate(scores):
if s[0] == 'AI':
ssum -= s[1] * (len(chunks[i]) / len(contents))
else:
ssum += s[1] * (len(chunks[i]) / len(contents))
sa : float = ssum
sa: float = ssum
if sa < 0:
return ('AI', abs(sa))
else:
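One behavioral note on the ThreadPoolExecutor block above: as_completed() yields futures in completion order, not submission order, so the scores list can end up misaligned with chunks when the later loop weights each score by len(chunks[i]). A sketch that keeps results aligned with their chunks (reusing the names from the diff: self._score_chunk, chunks, prelude_file, prelude_ratio) could use executor.map instead:

from concurrent.futures import ThreadPoolExecutor
import itertools

with ThreadPoolExecutor() as executor:
    # executor.map preserves input order, so scores[i] corresponds to chunks[i]
    # and the per-chunk length weighting stays correct.
    scores = list(executor.map(self._score_chunk, chunks,
                               itertools.repeat(prelude_file),
                               itertools.repeat(prelude_ratio)))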
@@ -326,7 +374,7 @@ class EnsembledZippy:
'''
with open(filename, 'r', encoding='utf-8') as fp:
contents = fp.read()
return self.run_on_text_chunked(contents, chunk_size)
return (filename, self.run_on_text_chunked(contents, chunk_size))
def run_on_text_chunked(self, s : str, chunk_size : int = 1500, prelude_file : Optional[str] = None, prelude_ratio : Optional[float] = None) -> Optional[Score]:
'''
@@ -340,6 +388,7 @@ class EnsembledZippy:
return self._combine_scores(scores)
def main():
start_time = time.perf_counter()
parser = argparse.ArgumentParser()
parser.add_argument("-v", required=False, action='store_true', help="Display the version and exit")
parser.add_argument("-n", required=False, action='store_true', help="Normalize scoring based on length of sample compared to length of PRELUDE_FILE")
@@ -384,10 +433,15 @@ def main():
z = Zippy(engine, normalize=normalize)
else:
z = EnsembledZippy()
for f in args.sample_files:
print(f)
if os.path.isfile(f):
print(str(z.run_on_file_chunked(f)))
with Pool(cpu_count()) as pool:
results = pool.map(z.run_on_file_chunked, args.sample_files)
for r in results:
print(r)
end_time = time.perf_counter()
print(f"total time: {end_time - start_time:0.4f}")
if __name__ == '__main__':
main()
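Taken together, the new main() fans the sample files out across processes with multiprocessing.Pool and reports wall-clock time. A minimal self-contained sketch of that flow is below; score_file is a hypothetical stand-in for Zippy.run_on_file_chunked, and the file names stand in for args.sample_files.

import time
from multiprocessing import Pool, cpu_count

def score_file(path: str):
    # Hypothetical stand-in for Zippy.run_on_file_chunked: returns a
    # (filename, score-like) tuple, mirroring the new return shape in the diff.
    with open(path, 'r', encoding='utf-8') as fp:
        contents = fp.read()
    return (path, len(contents))  # placeholder "score"

if __name__ == '__main__':
    start_time = time.perf_counter()
    sample_files = ['sample1.txt', 'sample2.txt']  # stand-in for args.sample_files
    with Pool(cpu_count()) as pool:
        # One worker process per CPU; results come back in input order.
        for result in pool.map(score_file, sample_files):
            print(result)
    end_time = time.perf_counter()
    print(f"total time: {end_time - start_time:0.4f}")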