Coverage Summary for Class: NamedParameterUtils (com.kotlinorm.beans.parser)

Class Method, % Branch, % Line, % Instruction, %
NamedParameterUtils 100% (13/13) 69.1% (121/175) 89.2% (173/194) 80.5% (883/1097)
NamedParameterUtils$ParameterHolder 100% (1/1) 100% (1/1) 100% (17/17)
Total 100% (14/14) 69.1% (121/175) 89.2% (174/195) 80.8% (900/1114)


 /**
  * Copyright 2022-2025 kronos-orm
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
  *     http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
 
 package com.kotlinorm.beans.parser
 
 import com.kotlinorm.cache.namedSqlCache
 import com.kotlinorm.exceptions.InvalidDataAccessApiUsageException
 import com.kotlinorm.exceptions.InvalidParameterException
 import com.kotlinorm.interfaces.KPojo
 
 
 /**
  * Created by OUSC on 2022/11/4 11:32
  *
  * Codes based on <a href="https://github.com/spring-projects/spring-framework/blob/main/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java">NamedParameterUtils</a>
  *
  * Add path parse support for NamedParameterParameterSource
  *
  * Such as `:array[0].list[1].map[KPojo].id`
  *
  * All rights reserved.
  */
 object NamedParameterUtils {
     private val START_SKIP = arrayOf("'", "\"", "--", "/*", "`")
 
     /**
      * Set of characters that at are the corresponding comment or quotes ending characters.
      */
     private val STOP_SKIP = arrayOf("'", "\"", "\n", "*/", "`")
 
     /**
      * Set of characters that qualify as parameter separators,
      * indicating that a parameter name in an SQL String has ended.
      */
     private const val PARAMETER_SEPARATORS = "\"':&,;()|=+-*%/\\<>^"
 
     private const val SEPARATOR_INDEX_SIZE = 128
 
     /**
      * An index with separator flags per character code.
      * Technically only needed between 34 and 124 at this point.
      */
     private val separatorIndex = BooleanArray(SEPARATOR_INDEX_SIZE)
 
     init {
         for (i in PARAMETER_SEPARATORS.indices) {
             separatorIndex[PARAMETER_SEPARATORS[i].code] = true
         }
     }
 
     //-------------------------------------------------------------------------
     // Core methods used by NamedParameterJdbcTemplate and SqlQuery/SqlUpdate
     //-------------------------------------------------------------------------
     /**
      * Parse the SQL statement and locate any placeholders or named parameters.
      * Named parameters are substituted for a JDBC placeholder.
      * @param sql the SQL statement
      * @return the parsed statement, represented as com.kotlinorm.beans.parser.ParsedSql instance
      */
     fun parseSqlStatement(sql: String, paramMap: Map<String, Any?> = mapOf()): ParsedSql {
         val original = namedSqlCache[sql]
         if (original != null) {
             return ParsedSql(
                 sql,
                 paramMap,
                 original.parameterNames,
                 original.parameterIndexes,
                 original.namedParameterCount,
                 original.unnamedParameterCount,
                 original.totalParameterCount,
                 original.jdbcSql
             )
         }
         val namedParameters: MutableSet<String> = HashSet()
         val sqlToUse = StringBuilder(sql)
         val parameterList: MutableList<ParameterHolder> = ArrayList()
         val statement = sql.toCharArray()
         var namedParameterCount = 0
         var unnamedParameterCount = 0
         var totalParameterCount = 0
         var escapes = 0
         var i = 0
         while (i < statement.size) {
             var skipToPosition: Int
             while (i < statement.size) {
                 skipToPosition = skipCommentsAndQuotes(statement, i)
                 if (i == skipToPosition) {
                     break
                 } else {
                     i = skipToPosition
                 }
             }
             if (i >= statement.size) {
                 break
             }
             var c = statement[i]
             if (c == ':' || c == '&') {
                 var j = i + 1
                 if ((c == ':') && j < statement.size && statement[j] == ':') {
                     // Postgres-style "::" casting operator should be skipped
                     i += 2
                     continue
                 }
                 var parameter: String?
                 if ((c == ':' && j < statement.size) && statement[j] == '{') {
                     // :{x} style parameter
                     while (statement[j] != '}') {
                         j++
                         if (j >= statement.size) {
                             throw InvalidParameterException(
                                 "Non-terminated named parameter declaration " +
                                         "at position " + i + " in statement: " + sql
                             )
                         }
                         if (statement[j] == ':' || statement[j] == '{') {
                             throw InvalidParameterException(
                                 ("Parameter name contains invalid character '" +
                                         statement[j] + "' at position " + i + " in statement: " + sql)
                             )
                         }
                     }
                     if (j - i > 2) {
                         parameter = sql.substring(i + 2, j)
                         namedParameterCount = addNewNamedParameter(namedParameters, namedParameterCount, parameter)
                         totalParameterCount = addNamedParameter(
                             parameterList, totalParameterCount, escapes, i, j + 1, parameter
                         )
                     }
                     j++
                 } else {
                     var paramWithSquareBrackets = false
                     while (j < statement.size) {
                         c = statement[j]
                         if (isParameterSeparator(c)) {
                             break
                         }
                         if (c == '[') {
                             paramWithSquareBrackets = true
                         } else if (c == ']') {
                             if (!paramWithSquareBrackets) {
                                 break
                             }
                             paramWithSquareBrackets = false
                         }
                         j++
                     }
                     if (j - i > 1) {
                         parameter = sql.substring(i + 1, j)
                         namedParameterCount = addNewNamedParameter(namedParameters, namedParameterCount, parameter)
                         totalParameterCount = addNamedParameter(
                             parameterList, totalParameterCount, escapes, i, j, parameter
                         )
                     }
                 }
                 i = j - 1
             } else {
                 if (c == '\\') {
                     val j = i + 1
                     if (j < statement.size && statement[j] == ':') {
                         // escaped ":" should be skipped
                         sqlToUse.deleteCharAt(i - escapes)
                         escapes++
                         i += 2
                         continue
                     }
                 }
                 if (c == '?') {
                     val j = i + 1
                     if (j < statement.size && ((statement[j] == '?') || (statement[j] == '|') || (statement[j] == '&'))) {
                         // Postgres-style "??", "?|", "?&" operator should be skipped
                         i += 2
                         continue
                     }
                     unnamedParameterCount++
                     totalParameterCount++
                 }
             }
             i++
         }
         val parsedSql = ParsedSql(sqlToUse.toString(), paramMap)
         for (ph: ParameterHolder in parameterList) {
             parsedSql.addNamedParameter(ph.parameterName, ph.startIndex, ph.endIndex)
         }
         parsedSql.namedParameterCount = namedParameterCount
         parsedSql.unnamedParameterCount = unnamedParameterCount
         parsedSql.totalParameterCount = totalParameterCount
         parsedSql.jdbcSql = substituteNamedParameters(parsedSql)
         namedSqlCache[sql] = parsedSql
 
         return parsedSql
     }
 
     private fun addNamedParameter(
         parameterList: MutableList<ParameterHolder>,
         totalParameterCount: Int,
         escapes: Int,
         i: Int,
         j: Int,
         parameter: String
     ): Int {
         parameterList.add(ParameterHolder(parameter, i - escapes, j - escapes))
         return totalParameterCount + 1
     }
 
     private fun addNewNamedParameter(
         namedParameters: MutableSet<String>,
         namedParameterCount: Int,
         parameter: String
     ): Int {
         if (!namedParameters.contains(parameter)) {
             namedParameters.add(parameter);
             return namedParameterCount + 1
         }
         return namedParameterCount
     }
 
     /**
      * Skip over comments and quoted names present in an SQL statement.
      * @param statement character array containing SQL statement
      * @param position current position of statement
      * @return next position to process after any comments or quotes are skipped
      */
     private fun skipCommentsAndQuotes(statement: CharArray, position: Int): Int {
         for (i in START_SKIP.indices) {
             if (statement[position] == START_SKIP[i][0]) {
                 var match = true
                 for (j in 1 until START_SKIP[i].length) {
                     if (statement[position + j] != START_SKIP[i][j]) {
                         match = false
                         break
                     }
                 }
                 if (match) {
                     val offset: Int = START_SKIP[i].length
                     for (m in position + offset until statement.size) {
                         if (statement[m] == STOP_SKIP[i][0]) {
                             var endMatch = true
                             var endPos = m
                             for (n in 1 until STOP_SKIP[i].length) {
                                 if (m + n >= statement.size) {
                                     // last comment not closed properly
                                     return statement.size
                                 }
                                 if (statement[m + n] != STOP_SKIP[i][n]) {
                                     endMatch = false
                                     break
                                 }
                                 endPos = m + n
                             }
                             if (endMatch) {
                                 // found character sequence ending comment or quote
                                 return endPos + 1
                             }
                         }
                     }
                     // character sequence ending comment or quote not found
                     return statement.size
                 }
             }
         }
         return position
     }
 
     /**
      * Determine whether a parameter name ends at the current position,
      * that is, whether the given character qualifies as a separator.
      */
     private fun isParameterSeparator(c: Char): Boolean {
         return (c.code < 128 && separatorIndex[c.code]) || Character.isWhitespace(c)
     }
 
 
     internal class ParameterHolder(val parameterName: String, val startIndex: Int, val endIndex: Int)
 
     /**
      * Parse the SQL statement and locate any placeholders or named parameters. Named
      * parameters are substituted for a JDBC placeholder, and any select list is expanded
      * to the required number of placeholders. Select lists may contain an array of
      * objects, and in that case the placeholders will be grouped and enclosed with
      * parentheses. This allows for the use of "expression lists" in the SQL statement
      * like: <br></br><br></br>
      * `select id, name, state from table where (name, age) in (('John', 35), ('Ann', 50))`
      *
      * The parameter values passed in are used to determine the number of
      * placeholders to be used for a select list. Select lists should not be empty
      * and should be limited to 100 or fewer elements. An empty list or a larger
      * number of elements is not guaranteed to be supported by the database and
      * is strictly vendor-dependent.
      * @param parsedSql the parsed representation of the SQL statement
      * @param paramSource the source for named parameters
      * @return the SQL statement with substituted parameters
      * @see .parseSqlStatement
      */
     fun substituteNamedParameters(parsedSql: ParsedSql, paramSource: Map<String, Any?>? = null): String {
         val originalSql: String = parsedSql.originalSql
         val paramNames: List<String> = parsedSql.parameterNames
         if (paramNames.isEmpty()) {
             return originalSql
         }
 
         val actualSql = java.lang.StringBuilder(originalSql.length)
         var lastIndex = 0
         for (i in paramNames.indices) {
             val paramName = paramNames[i]
             val indexes: IntArray = parsedSql.parameterIndexes[i]
             val startIndex = indexes[0]
             val endIndex = indexes[1]
             actualSql.append(originalSql, lastIndex, startIndex)
             if (paramSource != null && paramSource.containsKey(paramName) && paramSource[paramName] != null) {
                 val value: Any = paramSource[paramName]!!
                 if (value is Iterable<*>) {
                     for ((k, entryItem) in value.withIndex()) {
                         if (k > 0) {
                             actualSql.append(", ")
                         }
                         if (entryItem is Array<*> && entryItem.isArrayOf<Any>()) {
                             actualSql.append('(')
                             for (m in entryItem.indices) {
                                 if (m > 0) {
                                     actualSql.append(", ")
                                 }
                                 actualSql.append('?')
                             }
                             actualSql.append(')')
                         } else {
                             actualSql.append('?')
                         }
                     }
                 } else {
                     actualSql.append('?')
                 }
             } else {
                 actualSql.append('?')
             }
             lastIndex = endIndex
         }
         actualSql.append(originalSql, lastIndex, originalSql.length)
         return actualSql.toString()
     }
 
     /**
      * Convert a Map of named parameter values to a corresponding array.
      * @param parsedSql the parsed SQL statement
      * @param paramSource the source for named parameters
      * @param declaredParams the List of declared SqlParameter objects
      * (may be `null`). If specified, the parameter metadata will
      * be built into the value array in the form of SqlParameterValue objects.
      * @return the array of values
      */
     fun buildValueArray(
         parsedSql: ParsedSql, paramSource: Map<String, Any?>
     ): Array<Any?> {
         val paramArray = arrayOfNulls<Any>(parsedSql.totalParameterCount)
         if (parsedSql.namedParameterCount > 0 && parsedSql.unnamedParameterCount > 0) {
             throw InvalidDataAccessApiUsageException(
                 "Not allowed to mix named and traditional ? placeholders. You have " +
                         parsedSql.namedParameterCount + " named parameter(s) and " +
                         parsedSql.unnamedParameterCount + " traditional placeholder(s) in statement: " +
                         parsedSql.originalSql
             )
         }
         val paramNames: List<String> = parsedSql.parameterNames
         for (i in paramNames.indices) {
             paramArray[i] = getValueFromMap(paramSource, paramNames[i])
         }
         return paramArray
     }
 
     private fun getValueFromMap(map: Map<String, Any?>, path: String): Any? {
         // 解析路径
         val keys = parsePath(path)
 
         // 逐级取值
         var current: Any? = map
         for (key in keys) {
             current = when (current) {
                 is Map<*, *> -> current[key] // 如果当前值是 Map,取出对应的值
                 is KPojo -> current.toDataMap()[key] // 如果当前值是 KPojo,取出对应的值
                 is Iterable<*>, is Array<*>, is IntArray, is LongArray, is ShortArray, is ByteArray, is DoubleArray, is FloatArray, is BooleanArray ->
                     // 如果当前值是 List,取出对应的索引
                     key.toIntOrNull()?.let {
                         when (current) {
                             is IntArray -> (current as IntArray)[it]
                             is LongArray -> (current as LongArray)[it]
                             is ShortArray -> (current as ShortArray)[it]
                             is ByteArray -> (current as ByteArray)[it]
                             is DoubleArray -> (current as DoubleArray)[it]
                             is FloatArray -> (current as FloatArray)[it]
                             is BooleanArray -> (current as BooleanArray)[it]
                             is Array<*> -> (current as Array<*>)[it]
                             is Iterable<*> -> (current as Iterable<*>).elementAt(it)
                             else -> throw InvalidDataAccessApiUsageException(
                                 "Collection named '$key' in parameter source is not an Iterable or Array"
                             )
                         }
                     } // 如果当前值是 List,取出对应的索引
                 else -> null // 其他类型则返回 null
             }
             if (current == null) break // 如果中途遇到 null,停止
         }
 
         return current
     }
 
     private fun parsePath(path: String): List<String> {
         // 使用正则表达式解析路径
         val regex = """\.|(\[([0-9]+)])|(?<key>[^.\[\]]+)""".toRegex()
         return regex.findAll(path).mapNotNull {
             it.groups["key"]?.value ?: it.groups[2]?.value
         }.toList()
     }
 }