Skip to content

Commit

Permalink
Merge pull request #4 from nenadjakic/improvement/executor-ocr-config
Browse files Browse the repository at this point in the history
Improvement/executor ocr config
  • Loading branch information
nenadjakic authored Sep 4, 2024
2 parents 8ad10c3 + 71952e9 commit c7f33c6
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import com.github.nenadjakic.ocr.studio.entity.Task
import com.github.nenadjakic.ocr.studio.extension.collectionMap
import com.github.nenadjakic.ocr.studio.service.TaskService
import io.swagger.v3.oas.annotations.Operation
import io.swagger.v3.oas.annotations.media.Content
import io.swagger.v3.oas.annotations.media.Encoding
import io.swagger.v3.oas.annotations.responses.ApiResponse
import io.swagger.v3.oas.annotations.responses.ApiResponses
import io.swagger.v3.oas.annotations.tags.Tag
Expand Down Expand Up @@ -71,23 +73,33 @@ open class TaskController(
@Operation(
operationId = "createTask",
summary = "Create task.",
description = "Creates a new task based on the provided model."
description = "Creates a new task based on the provided model.",
requestBody =
io.swagger.v3.oas.annotations.parameters.RequestBody(
content = [Content(encoding = [Encoding(name = "model", contentType = "application/json")]
)]
)
)
@ApiResponses(
value = [
ApiResponse(responseCode = "201", description = "Task created successfully."),
ApiResponse(responseCode = "400", description = "Invalid request data.")
]
)
@PostMapping(consumes = [MediaType.MULTIPART_FORM_DATA_VALUE], produces = [MediaType.APPLICATION_JSON_VALUE])
fun create(@Valid @RequestParam model: TaskAddRequest, files: Collection<MultipartFile>): ResponseEntity<Void> {
@PostMapping(consumes = [MediaType.MULTIPART_FORM_DATA_VALUE])
fun create(
@Valid @RequestPart(name = "model")
//@Schema(implementation = TaskAddRequest::class)
model: TaskAddRequest,
@RequestPart(value = "files", required = false) files: Collection<MultipartFile>?
): ResponseEntity<Void> {
var task = modelMapper.map(model, Task::class.java)
return insert(task)
return insert(task, files)
}

@Operation(
operationId = "createDraftTask",
summary = "Create task.",
summary = "Create draft task.",
description = "Creates a new task based on the provided model."
)
@ApiResponses(
Expand Down Expand Up @@ -172,8 +184,8 @@ open class TaskController(
@RequestPart("files") multipartFiles: Collection<MultipartFile>
): ResponseEntity<List<UploadDocumentResponse>> = ResponseEntity.ok(modelMapper.collectionMap(taskService.upload(id, multipartFiles), UploadDocumentResponse::class.java))

private fun insert(task: Task): ResponseEntity<Void> {
val createdTask = taskService.insert(task)
private fun insert(task: Task, files: Collection<MultipartFile>? = null): ResponseEntity<Void> {
val createdTask = taskService.insert(task, files)

val location = ServletUriComponentsBuilder
.fromCurrentRequest()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
package com.github.nenadjakic.ocr.studio.dto

import com.github.nenadjakic.ocr.studio.entity.OcrConfig
import com.github.nenadjakic.ocr.studio.entity.OcrConfig.FileFormat

/**
 * Request DTO describing the OCR settings supplied by a client.
 *
 * The `lateinit` properties have no default: reading one before it has been
 * assigned throws [UninitializedPropertyAccessException]. The Boolean flags
 * default to `false`.
 *
 * NOTE(review): presumably copied onto the `OcrConfig` entity by the mapper used
 * in the controllers; the mapping itself is not visible here, so confirm.
 */
class OcrConfigRequest {
// Engine mode, expressed with the entity's own enum; must be set before use.
lateinit var ocrEngineMode: OcrConfig.OcrEngineMode
// Page segmentation mode, expressed with the entity's own enum; must be set before use.
lateinit var pageSegmentationMode: OcrConfig.PageSegmentationMode
// OCR language; must be set before use. Presumably a Tesseract language code such as "eng" (TODO confirm).
lateinit var language: String
// Pre-processing flag; off unless the client enables it.
var preProcessing: Boolean = false
// Requested output file format (PDF, HOCR or TEXT); must be set before use.
lateinit var fileFormat: FileFormat
// Merge-documents flag; off unless the client enables it.
var mergeDocuments: Boolean = false
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package com.github.nenadjakic.ocr.studio.entity

import net.sourceforge.tess4j.ITesseract.RenderedFormat

class OcrConfig {
enum class OcrEngineMode(val tesseractValue: Int, val descritpion: String) {
LEGACY(0, "Legacy engine only."),
Expand All @@ -24,9 +26,27 @@ class OcrConfig {
MODE_12(12 ,"Sparse text with OSD."),
MODE_13(13 ,"Raw line. Treat the image as a single text line, bypassing hacks that are Tesseract-specific.")
}
/**
 * Output formats that can be requested for an OCR result.
 *
 * Each constant carries its tess4j [RenderedFormat] counterpart and the file
 * extension (without a leading dot) used for that format.
 */
enum class FileFormat(
    private val tesseractFormat: RenderedFormat,
    private val fileSuffix: String
) {
    PDF(RenderedFormat.PDF, "pdf"),
    HOCR(RenderedFormat.HOCR, "hocr"),
    TEXT(RenderedFormat.TEXT, "txt");

    /** Returns the tess4j [RenderedFormat] that corresponds to this format. */
    fun toRenderedFormat(): RenderedFormat = tesseractFormat

    /** Returns the file extension for this format, without a leading dot. */
    fun getExtension(): String = fileSuffix
}

var language: String = "eng"
var ocrEngineMode: OcrEngineMode = OcrEngineMode.DEFAULT
var pageSegmentationMode: PageSegmentationMode = PageSegmentationMode.MODE_3
var preProcessing: Boolean = false
var fileFormat: FileFormat = FileFormat.TEXT
var mergeDocuments: Boolean = false

}
128 changes: 114 additions & 14 deletions src/main/kotlin/com/github/nenadjakic/ocr/studio/executor/OcrExecutor.kt
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
package com.github.nenadjakic.ocr.studio.executor

import com.github.nenadjakic.ocr.studio.config.OcrProperties
import com.github.nenadjakic.ocr.studio.entity.OcrConfig
import com.github.nenadjakic.ocr.studio.entity.OutDocument
import com.github.nenadjakic.ocr.studio.entity.Status
import com.github.nenadjakic.ocr.studio.extension.toOcrProgress
import com.github.nenadjakic.ocr.studio.repository.TaskRepository
import com.github.nenadjakic.ocr.studio.service.TaskFileSystemService
import net.sourceforge.tess4j.ITesseract
import net.sourceforge.tess4j.ITesseract.RenderedFormat
import net.sourceforge.tess4j.util.ImageHelper
import org.apache.pdfbox.Loader
import org.apache.pdfbox.pdmodel.PDDocument
import org.apache.pdfbox.rendering.ImageType
import org.apache.pdfbox.rendering.PDFRenderer
import org.apache.tika.metadata.PDF
import org.slf4j.LoggerFactory
import java.io.File
import java.io.IOException
import java.nio.file.Path
import java.time.ZonedDateTime
import java.util.*
import javax.imageio.ImageIO
import kotlin.collections.Map

class OcrExecutor(
override val id: UUID,
Expand All @@ -31,23 +41,113 @@ class OcrExecutor(

task.ocrProgress = progressInfo.toOcrProgress()
taskRepository.save(task)
for (document in task.inDocuments) {
val inFile = Path.of(ocrProperties.taskPath, task.id.toString(), "input", document.randomizedFileName).toFile()
if (inFile.exists()) {
val outFile =
Path.of(ocrProperties.taskPath, task.id.toString(), "output", UUID.randomUUID().toString()).toFile()
try {
for (document in task.inDocuments) {
val inFile =
Path.of(ocrProperties.taskPath, task.id.toString(), "input", document.randomizedFileName).toFile()
if (inFile.exists()) {
val outFile =
Path.of(ocrProperties.taskPath, task.id.toString(), "output", UUID.randomUUID().toString())
.toFile()

document.outDocument = OutDocument()
document.outDocument!!.outputFileName = outFile.name
document.outDocument = OutDocument()
document.outDocument!!.outputFileName = outFile.name

tesseract.createDocuments(inFile.absolutePath, outFile.absolutePath, mutableListOf(RenderedFormat.PDF))
val filesToOcr = preProcessDocument(task.ocrConfig.preProcessing, inFile)
if (filesToOcr.size > 1) {
logger.info("Starting ocr of multi paged document.")
PDDocument().use { pdDocument ->
filesToOcr.entries.sortedBy { it.key }.forEach { filesToOcrEntry ->
val index = filesToOcrEntry.key
val fileToOcr = filesToOcrEntry.value

val tempOutFile = File.createTempFile("___", ".pdf")
logger.debug("OCR of pdf page: $index.")
tesseract.createDocuments(
fileToOcr.absolutePath,
tempOutFile.absolutePath.removeSuffix(".pdf"),
mutableListOf(task.ocrConfig.fileFormat.toRenderedFormat())
)

val outDocument = Loader.loadPDF(tempOutFile)
outDocument.pages.forEach { page ->
pdDocument.addPage(page)
}

}
pdDocument.save(outFile)
}
} else {
logger.info("Starting ocr of one paged document.")
filesToOcr.entries.sortedBy { it.key }.forEach {
val index = it.key
val fileToOcr = it.value

val tempOutFile = File.createTempFile("___", ".tmp")
tesseract.createDocuments(
fileToOcr.absolutePath,
outFile.absolutePath,
mutableListOf(task.ocrConfig.fileFormat.toRenderedFormat())
)
}
}
progressInfo.taskDone++
}
task.ocrProgress = progressInfo.toOcrProgress()
taskRepository.save(task)
}
progressInfo.progressInfoStatus = ProgressInfo.ProgressInfoStatus.FINISHED
task.ocrProgress = progressInfo.toOcrProgress()
taskRepository.save(task)
} catch (ex: Exception) {
logger.error("OCR for task id $id failed.", ex)
progressInfo.progressInfoStatus = ProgressInfo.ProgressInfoStatus.FAILED
}
}

/**
 * Pairs a [File] with the [OcrConfig.FileFormat] associated with it.
 *
 * NOTE(review): not referenced anywhere in the visible part of this file;
 * confirm whether it is still needed.
 */
private data class InputData(val fileFormat: OcrConfig.FileFormat, val file: File)

@Throws(IOException::class)
private fun preProcessDocument(preProcess: Boolean, inFile: File): Map<Long, File> {
val files = mutableMapOf<Long, File>()
var order = 1L
val mediaType = TaskFileSystemService.getContentType(inFile)
if (preProcess) {
logger.info("Pre processing of input document ${inFile.name}.")
if (mediaType.type.equals("image")) {
val originalImage = ImageIO.read(inFile)
val grayscaleImage = ImageHelper.convertImageToGrayscale(originalImage)
val tempGrayscaleImage = File.createTempFile("___", "_tmp")
val result = ImageIO.write(grayscaleImage, mediaType.subtype, tempGrayscaleImage)
if (result) {
files[order++] = tempGrayscaleImage
}
} else if (mediaType.toString().equals("application/pdf")) {
Loader.loadPDF(inFile).use {
val pdfRenderer = PDFRenderer(it)
logger.info("Starting of pdf preprocess. Total pages: ${it.numberOfPages}")
for (pageNumber in 0..<it.numberOfPages) {
logger.debug("Preprocessing of pdf page $pageNumber")
val pdfPage = pdfRenderer.renderImageWithDPI(pageNumber, 300F, ImageType.GRAY)
val tempPdfPage = File.createTempFile("___", "_tmp")
val result = ImageIO.write(pdfPage, "png", tempPdfPage)
if (result) {
files[order++] = tempPdfPage
} else {
order++
}
}
}
} else {
logger.warn("File {} for task {} does not exist on filesystem.", inFile.name, id)
files[order++] = inFile
}

progressInfo.taskDone++
} else {
files[order++] = inFile
}
task.ocrProgress = progressInfo.toOcrProgress()
taskRepository.save(task)
return files
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ class ParallelizationManagerImpl(
}

val future: ScheduledFuture<out Any> = if (executor.startDateTime != null) {
taskScheduler.schedule({ executor }, executor.startDateTime!!.toInstant())
taskScheduler.schedule({ executor.run() }, executor.startDateTime!!.toInstant())
} else {
taskScheduler.schedule({ executor }, Instant.now().plusSeconds(30L))
taskScheduler.schedule({ executor.run() }, Instant.now().plusSeconds(30L))
}

runnables[executor.id] = executor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fun ProgressInfo.toOcrProgress(): OcrProgress {
ProgressInfo.ProgressInfoStatus.CANCELED -> Status.CANCELED
ProgressInfo.ProgressInfoStatus.INTERRUPTED -> Status.INTERRUPTED
}
val progress = "$this.taskDone.toString() / $this.totalTasks.toString()"
val progress = "${this.taskDone} / ${this.totalTasks}"
val description = this.description

return OcrProgress(status, progress, description)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class OcrService(
val task = taskRepository.findById(id).orElseThrow { OcrException("Cannot find task with id: $id") }

if (Status.getInProgressStatuses().contains(task.ocrProgress.status)) {
throw OcrException("Task with id: {} is in progress and cannot be scheduled.")
throw OcrException("Task with id: $id is in progress and cannot be scheduled.")
}
val tesseract = tesseractFactory.create(
task.ocrConfig.language,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import com.github.nenadjakic.ocr.studio.config.OcrProperties
import org.apache.tika.config.TikaConfig
import org.apache.tika.detect.Detector
import org.apache.tika.metadata.Metadata
import org.apache.tika.mime.MediaType
import org.springframework.stereotype.Service
import org.springframework.web.multipart.MultipartFile
import java.io.ByteArrayInputStream
Expand All @@ -22,8 +23,29 @@ class TaskFileSystemService(
) {
private val inputDirectoryName: String = "input"
private val outputDirectoryName: String = "output"
private final val tikaConfig = TikaConfig()
private val detector: Detector = tikaConfig.detector

companion object {
    private val tikaConfig = TikaConfig()
    private val detector: Detector = tikaConfig.detector

    /**
     * Detects the media type of [file] from its content using the Tika detector.
     *
     * The file stream is closed as soon as its content has been buffered, so no
     * file handle is left open after the call.
     *
     * @param file file whose content is inspected.
     * @return the media type reported by the detector.
     * @throws IOException if the file cannot be opened or read.
     */
    fun getContentType(file: File): MediaType =
        file.inputStream().use { detector.detect(cloneInputStream(it), Metadata()) }

    /**
     * Resolves the content type of an uploaded file.
     *
     * The content type declared by the client is returned as-is unless it is
     * missing or the generic `application/octet-stream` (compared
     * case-insensitively); in those cases the type is detected from the
     * uploaded bytes instead.
     *
     * @param multiPartFile uploaded file.
     * @return the declared or detected content type.
     * @throws IOException if the upload content cannot be read.
     */
    fun getContentType(multiPartFile: MultipartFile): String {
        // Explicitly nullable: the declared content type may be absent.
        val declaredType: String? = multiPartFile.contentType
        if (declaredType != null && !"application/octet-stream".equals(declaredType, ignoreCase = true)) {
            return declaredType
        }
        // Close the upload stream once it has been buffered for detection.
        return multiPartFile.inputStream
            .use { detector.detect(cloneInputStream(it), Metadata()) }
            .toString()
    }

    /**
     * Copies everything readable from [inputStream] into memory and returns a
     * [ByteArrayInputStream] over the copy.
     *
     * The in-memory stream supports mark/reset, which the detector is expected
     * to need. The source stream is NOT closed here; callers own it.
     */
    private fun cloneInputStream(inputStream: InputStream): InputStream {
        val buffer = ByteArrayOutputStream()
        inputStream.transferTo(buffer)

        return ByteArrayInputStream(buffer.toByteArray())
    }
}

@Throws(IOException::class)
fun createTaskDirectories(id: UUID) {
Expand All @@ -49,15 +71,6 @@ class TaskFileSystemService(
multiPartFile.transferTo(targetFile.absoluteFile)
}

fun getContentType(multiPartFile: MultipartFile): String {
var contentType = multiPartFile.contentType

if ("application/octet-stream".equals(contentType, true)) {
contentType = detector.detect(cloneInputStream(multiPartFile.inputStream), Metadata()).toString()
}
return contentType
}

fun cleanUp(id: UUID) {
deleteDirectoryRecursively(Path.of(ocrProperties.taskPath))
}
Expand All @@ -73,11 +86,4 @@ class TaskFileSystemService(
.map(Path::toFile)
.forEach(File::delete)
}

private fun cloneInputStream (inputStream: InputStream): InputStream {
val byteArrayOutputStream = ByteArrayOutputStream()
inputStream.transferTo(byteArrayOutputStream)

return ByteArrayInputStream(byteArrayOutputStream.toByteArray())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ class TaskService(
return taskRepository.insert(entity)
}

fun insert(entity: Task, files: Collection<MultipartFile> = emptyList()): Task {
fun insert(entity: Task, files: Collection<MultipartFile>? = emptyList()): Task {
val createdEntity = insert(entity)
if (!files.isEmpty()) {
if (files != null && !files!!.isEmpty()) {
upload(createdEntity.id!!, files)
}
return createdEntity;
Expand All @@ -57,7 +57,7 @@ class TaskService(
val document = Document()
document.originalFileName = multiPartFile.originalFilename!!
document.randomizedFileName = UUID.randomUUID().toString()
document.type = taskFileSystemService.getContentType(multiPartFile)
document.type = TaskFileSystemService.getContentType(multiPartFile)

taskFileSystemService.uploadFile(multiPartFile, id, document.randomizedFileName)
task.addInDocument(document)
Expand Down
Loading

0 comments on commit c7f33c6

Please sign in to comment.