Flink UDF自动注册实践

  1. 云栖社区>
  2. 博客>
  3. 正文

Flink UDF自动注册实践

王知无 2019-08-17 20:29:07 浏览237
展开阅读全文

5万人关注的大数据成神之路,不来了解一下吗?
5万人关注的大数据成神之路,真的不来了解一下吗?
5万人关注的大数据成神之路,确定真的不来了解一下吗?

欢迎您关注《大数据成神之路》

1.注册UDF函数
1.1 注册相关方法
此处,我们使用的udf函数为标量函数,它继承的是ScalarFunction,该类在我们的使用中,发现它继承自UserDefinedFunction这个类,该处的udf函数由用户自己定义,而函数的注册此处我们自己实现;

函数注册时,使用flink的tableEnv上下文对象注册该函数,此处注册时使用的方法是TableEnvironment类里面的重载方法registerFunction,这个函数不涉及参数和泛型的问题,具体方法如下:

    * Registers a [[ScalarFunction]] under a unique name. Replaces already existing
    * user-defined functions under this name.
    */
  def registerFunction(name: String, function: ScalarFunction): Unit = {
// check if class could be instantiated
    checkForInstantiation(function.getClass)

    // register in Table API

functionCatalog.registerFunction(name, function.getClass)

    // register in SQL API
functionCatalog.registerSqlFunction(
      createScalarSqlFunction(name, name, function, typeFactory)
)
  }

通过上面得方法,发现在检查完类的实例化之后,便是对该类进行注册使用,分别针对Table API和SQL API两种不同形式去进行注册。

下面是我们注册的小案例:

日常模式

tableEnv.registerFunction("hashCode",new HashCode())
myTable.select("item,item.hashCode(),hashCode(item)")
val hcTest = tableEnv.sqlQuery("select item,hashCode(item) from myTable")

假日模式

tableEnv.registerFunction(m("name").toString, ReflectHelper.newInstanceByClsName[ScalarFunction]
    (m("className").toString,this.getClass.getClassLoader))

1.2 函数示例

class HashCode extends ScalarFunction {var hashcode_factor = 12  override def open(context: FunctionContext): Unit = {// access "hashcode_factor" parameter// "12" would be the default value if parameter does not exist    hashcode_factor = context.getJobParameter("hashcode_factor", "12").toInt  }  def eval(s: String): Int = {    s.hashCode()+hashcode_factor  }}

2.注册UDTF函数
2.1 注册相关方法
在UDTF和UDAF中,我们发现,注册使用的具体函数是包含有一定的格式限制,比如此时我们需要注册的UDTF函数,Split类继承自TableFunction[(String,Int)],那么我们的函数注册中,在java程序编译时会去检查该泛型,后续实际运行时,解析我们的UDTF函数时,对泛型内的类型进行序列化和反序列化时会和我们规定的泛型进行对比,如果此时我们的数据schema或者说我们的数据本身格式不匹配抑或是我们给出了数据的泛型,编译过了擦除掉之后,在实际运行中却发现并没有该字段信息,那么同样也会出错,所以此时,我们更加要去注意产生该问题的根源,那么根源究竟是什么呢,话不多说,接着看代码。

我们需要注册函数的registerFunction方法,来自于StreamTableEnvironment中的registerFunction方法,此处的类请大家和之前区别一下,注意,此处这个类在后续我们使用UDAF时也会使用,那么原因在于这两个函数加入了泛型的约束,所以兜兜转转,会有中间的一个检查判断过程,接着,同样是在TableEnvironment这个类中的registerTableFunctionInternal方法,下来,我会分别给出两个方法,请看代码。

StreamTableEnvironment

/**
    * Registers a [[TableFunction]] under a unique name in the TableEnvironment's catalog.
    * Registered functions can be referenced in SQL queries.
    *
    * @param name The name under which the function is registered.
    * @param tf The TableFunction to register
    */
  def registerFunction[T: TypeInformation](name: String, tf: TableFunction[T]): Unit = {
    registerTableFunctionInternal(name, tf)
  }

TableEnvironment

/**
    * Registers a [[TableFunction]] under a unique name. Replaces already existing
    * user-defined functions under this name.
    */
private[flink] def registerTableFunctionInternal[T: TypeInformation](
    name: String, function: TableFunction[T]): Unit = {
// check if class not Scala object
    checkNotSingleton(function.getClass)
    // check if class could be instantiated
checkForInstantiation(function.getClass)

val typeInfo: TypeInformation[_] = if (function.getResultType != null) {
function.getResultType
    } else {
      implicitly[TypeInformation[T]]
    }

// register in Table API
functionCatalog.registerFunction(name, function.getClass)

    // register in SQL API
val sqlFunction = createTableSqlFunction(name, name, function, typeInfo, typeFactory)
functionCatalog.registerSqlFunction(sqlFunction)
  }

看到了吧,这个T类型规范了我们注册的这个函数的类型,这个在自定义注册时一定要小心;注意我们返回类型是否和我们注册时规定的泛型一致,要让注册能过编译,也要让函数能顺利运行。

2.2 函数示例

class Split(separator: String) extends TableFunction[(String, Int)] {  def eval(str: String): Unit = {    str.split(separator).foreach(x => collect(x, x.length))  }}

这个里面的返回即是(String, Int),因为我们注册时,已经获取了该类的泛型,所以此时,只需要我们在注册前引入隐式转换即可。

2.3 注册部分

//register table schema: [a: String]
    tableEnv.registerDataStream("mySplit", textFiles,'a)
    val mySplit: Table = tableEnv.sqlQuery("select * from mySplit")
    mySplit.printSchema()

    //register udtf
    val split = new Split(",")
    val dslTable =mySplit.join(split('a) as ('word,'length)).select('a,'word,'length)
    val dslLeftTable = mySplit.leftOuterJoin(split('a) as  ('word,'length)).select('a,'word,'length)

    tableEnv.registerFunction("split",split)
    val sqlJoin =  tableEnv.sqlQuery("select a,item,counts from mySplit,LATERAL TABLE(split(a)) as T(item,counts)")
    val sqlLeftJoin =tableEnv.sqlQuery("select a, item, counts from mySplit 
    LEFT JOIN LATERAL TABLE(split(a)) as T(item, counts) ON TRUE")

3.注册UDAF函数
3.1 注册函数
看了上面两种,其实无非是,UDF函数直接注册就可以,UDTF在注册时需要我们规范下类的泛型,而UDAF则不止是这些,不过,take it easy放轻松,趟过的坑马上列出来给你看,哈哈,这里提前说,多了一个返回的类,而此处这个类你们就可要小心啦~~~呜啦啦,开始吧,骚年。。。

此时,我们的具体思路是,要先给出一个类,比如有几个成员变量作为后续AggregateFunction的一个辅助类,然后UDAF函数中用到了它还有它其中的成员变量,下来,改变下思路,先看注册的函数吧:

WeightedAvgAccum

import java.lang.{Long => JLong, Integer => JInteger}
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}

class WeightedAvgAccum extends JTuple2[JLong, JInteger]  {
var   sum = 0L
var   count =0
}

WeightedAvg

/** * Weighted Average user-defined aggregate function. */class WeightedAvg extends AggregateFunction[JLong, CountAccumulator] {override def createAccumulator(): WeightedAvgAccum = {    new WeightedAvgAccum  }  override def getValue(acc: WeightedAvgAccum): JLong = {if (acc.count == 0) {        null    } else {        acc.sum / acc.count    }  }    def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {    acc.sum += iValue * iWeight    acc.count += iWeight  }  def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {    acc.sum -= iValue * iWeight    acc.count -= iWeight  }      def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = {    val iter = it.iterator()while (iter.hasNext) {      val a = iter.next()      acc.count += a.count      acc.sum += a.sum    }  }  def resetAccumulator(acc: WeightedAvgAccum): Unit = {    acc.count = 0    acc.sum = 0L  }override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = {    new TupleTypeInfo(classOf[WeightedAvgAccum], Types.LONG, Types.INT)  }override def getResultType: TypeInformation[JLong] = Types.LONG}

3.2 注册方法及问题详解
大家伙受累看完这两个类,没什么问题的话我们接着往下讲,官网例子,如下包换,在我们使用flink注册时,没什么问题啊,那么你凭什么说要注意呢?此处我们的前提是用户上传到我们的系统,我们通过反射来拿到该类的实例然后再去注册,那么,问题就来了,如果平时使用没有任何问题,而我们自动让flink识别注册时,flink却做不到,原因为何,请先看看,平时使用和我们自动注册时的一些区别;

日常玩法:

tableEnv.registerFunction("wAvg",new WeightedAvg())
val  weightAvgTable = tableEnv.sqlQuery("select item,wAvg(points,counts) AS avgPoints FROM myTable GROUP BY item")

假日玩法:

implicit  val infoTypes = TypeInformation.of(classOf[Object])
tableEnv.registerFunction[Object,Object](m("name").toString, 
ReflectHelper.newInstanceByClsName(m("className").toString,this.getClass.getClassLoader,sm.dac))

亲爱的们,看到问题了吗?其实原因就是我们的程序不是你,它无法推断具体的类的类型,这个需要是我们给出一定的范围或者说我们规范这个流程,即便是这样,引入了对object的隐式转换,过得了编译,但是运行时还会报错,不信,你看:

2019-07-10 19:49:40,872 INFO  org.apache.flink.runtime.executiongraph.ExecutionGraph        
- groupBy: (item), window: (TumblingGroupWindow('w$, 'proctime, 5000.millis)), 
select: (item, udafWeightAvg(counts, points) AS c) -> select: (c, item) -> job_1542187919994 -> Sink: Print to Std. Out (2/2) 
(3a9098c7ddb0a115349f6d89aba606ff) switched from RUNNING to FAILED.
org.apache.flink.types.NullFieldException: Field 0 is null, but expected to hold a value.
at org.apache.flink.api.java.typeutils.runtime.TupleSerializer.serialize(TupleSerializer.java:127)
at org.apache.flink.api.java.typeutils.runtime.TupleSerializer.serialize(TupleSerializer.java:30)
at org.apache.flink.api.java.typeutils.runtime.RowSerializer.serialize(RowSerializer.java:160)
at org.apache.flink.api.java.typeutils.runtime.RowSerializer.serialize(RowSerializer.java:46)
at org.apache.flink.contrib.streaming.state.AbstractRocksDBState.getValueBytes(AbstractRocksDBState.java:171)
at org.apache.flink.contrib.streaming.state.AbstractRocksDBAppendingState.updateInternal
    (AbstractRocksDBAppendingState.java:80)
at org.apache.flink.contrib.streaming.state.RocksDBAggregatingState.add(RocksDBAggregatingState.java:105)
at org.apache.flink.streaming.runtime.operators.windowing.WindowOperator.processElement(WindowOperator.java:391)
at org.apache.flink.streaming.runtime.io.StreamInputProcessor.processInput(StreamInputProcessor.java:202)
at org.apache.flink.streaming.runtime.tasks.OneInputStreamTask.run(OneInputStreamTask.java:105)
at org.apache.flink.streaming.runtime.tasks.StreamTask.invoke(StreamTask.java:300)
at org.apache.flink.runtime.taskmanager.Task.run(Task.java:711)
at java.lang.Thread.run(Thread.java:745)
Caused by: java.lang.NullPointerException
at org.apache.flink.api.common.typeutils.base.LongSerializer.serialize(LongSerializer.java:63)
at org.apache.flink.api.common.typeutils.base.LongSerializer.serialize(LongSerializer.java:27)
at org.apache.flink.api.java.typeutils.runtime.TupleSerializer.serialize(TupleSerializer.java:125)
    ... 12 more
2019-07-10 19:49:40,873 INFO  org.apache.flink.runtime.executiongraph.ExecutionGraph        
 - Job csv_test csv_test_udaf_acc (ec1e56648123905f7ffd85ba884e89ca) switched from state RUNNING to FAILING.
org.apache.flink.types.NullFieldException: Field 0 is null, but expected to hold a value.

报错很显然,flink大群里大佬翻译告诉我说这是tuple里面第一位的数据为空,序列化long为空导致,而我傻傻的找了大半天,发现我的程序没问题啊,那这到底问题出在哪呢?其实有时候,并不是我们的错,只是我们太高看了别人,经过苦苦寻找,才找到原来官网的例子中有猫腻。
下来,同样的,看一下StreamTableEnvironment和TableEnvironment这两个类中的注册方法。

StreamTableEnvironment

/**
    * Registers an [[AggregateFunction]] under a unique name in the TableEnvironment's catalog.
    * Registered functions can be referenced in Table API and SQL queries.
    *
    * @param name The name under which the function is registered.
    * @param f The AggregateFunction to register.
    * @tparam T The type of the output value.
    * @tparam ACC The type of aggregate accumulator.
    */
  def registerFunction[T: TypeInformation, ACC: TypeInformation](
      name: String,
      f: AggregateFunction[T, ACC])
  : Unit = {
    registerAggregateFunctionInternal[T, ACC](name, f)
  }

TableEnvironment

/**
    * Registers an [[AggregateFunction]] under a unique name. Replaces already existing
    * user-defined functions under this name.
    */
private[flink] def registerAggregateFunctionInternal[T: TypeInformation, ACC: TypeInformation](
      name: String, function: AggregateFunction[T, ACC]): Unit = {
// check if class not Scala object
    checkNotSingleton(function.getClass)
    // check if class could be instantiated
checkForInstantiation(function.getClass)

val resultTypeInfo: TypeInformation[_] = getResultTypeOfAggregateFunction(
function,
      implicitly[TypeInformation[T]])

val accTypeInfo: TypeInformation[_] = getAccumulatorTypeOfAggregateFunction(
function,
      implicitly[TypeInformation[ACC]])

    // register in Table API
functionCatalog.registerFunction(name, function.getClass)

    // register in SQL API
val sqlFunctions = createAggregateSqlFunction(
      name,
      name,
function,
      resultTypeInfo,
      accTypeInfo,
      typeFactory)

functionCatalog.registerSqlFunction(sqlFunctions)
  }

具体呢如下:

一是,我们此处规范的是一个[Object,Object]的泛型,对于Accum类里的Jlong,JInteger没法起到限定,进而在解析时无法找到对应的类型,这个反应在TableEnvironment里面的T和ACC,T对应上了,而ACC却没有;

二是,我们的Avg类里面,返回的却是一个 new TupleTypeInfo(classOf[WeightedAvgAccum], Types.LONG, Types.INT)和Types.LONG,那么不难发现,这个tuple里三个元素,我们其实只需要把第一个解析了,而另外两个都是套在它里面的,所以Object只有一个,而WeightedAvgAccum里却有三个,完全不对应,所以我们需要更改这两个类,改完后具体代码如下:

WeightedAvgAccum

class WeightedAvgAccum {
var   sum = 1L
var   count = 2
}

WeightedAvg

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.typeutils.TupleTypeInfo
import org.apache.flink.table.api.Types
import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction, utils}
import java.lang.{Integer => JInteger, Long => JLong}
import org.apache.flink.table.functions.utils.UserDefinedFunctionUtils

class WeightedAvg extends  AggregateFunction[JLong, WeightedAvgAccum]{

override def createAccumulator(): WeightedAvgAccum = {
    new WeightedAvgAccum
  }

override def getValue(acc: WeightedAvgAccum): JLong = {
if (acc.count == 0) {
      null
    } else {
      acc.sum / acc.count
    }
  }

  def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
    acc.sum += iValue * iWeight
    acc.count += iWeight
  }

  def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = {
    acc.sum -= iValue * iWeight
    acc.count -= iWeight
  }

  def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = {
    val iter = it.iterator()
while (iter.hasNext) {
      val a = iter.next()
      acc.count += a.count
      acc.sum += a.sum
    }
  }

  def resetAccumulator(acc: WeightedAvgAccum): Unit = {
    acc.count = 0
    acc.sum = 0L
  }

override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = TypeInformation.of(classOf[WeightedAvgAccum])


override def getResultType: TypeInformation[JLong] = Types.LONG

}

那么具体运行如何呢,具体如下:

image

4.问题总结
巴拉巴拉说了这么多,可能对于大神来说并不是什么新鲜问题,但是我相信初次接触的小白来讲还是或多或少有一些帮助的,所以希望后续在码代码的过程中要多方面去思考,也希望还是要加强对底层知识的深度认识,也能够更快的触类旁通,好啦,今天就到这啦,后面持续为大家带来我自己的一些见解和认知,撒由那拉~~~

网友评论

登录后评论
0/500
评论
王知无
+ 关注