flink-python/pyflink/fn_execution/beam/beam_coder_impl_slow.py (403 lines of code) (raw):
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import datetime
import decimal
import pickle
import struct
from typing import Any
from typing import Generator
from typing import List
import pyarrow as pa
from apache_beam.coders.coder_impl import StreamCoderImpl, create_InputStream, create_OutputStream
from pyflink.fn_execution.ResettableIO import ResettableIO
from pyflink.table.types import Row
from pyflink.table.utils import pandas_to_arrow, arrow_to_pandas
class FlattenRowCoderImpl(StreamCoderImpl):
    """
    Coder for a stream of rows whose fields are encoded one after another.

    Each row is written as a null mask (one bit per field, most significant bit of the first
    byte = first field) followed by the encodings of the non-null fields in field order. The
    row is assembled in a scratch buffer and emitted with a var-int64 length prefix.
    """

    def __init__(self, field_coders):
        self._field_coders = field_coders
        self._field_count = len(field_coders)
        # mask bytes in which all 8 bits are used, and bits used in a trailing partial byte
        self._leading_complete_bytes_num, self._remaining_bits_num = divmod(
            self._field_count, 8)
        self.null_mask_search_table = self.generate_null_mask_search_table()
        # bit value that marks field position 0..7 inside one mask byte
        self.null_byte_search_table = (0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01)
        # reusable scratch buffer holding one encoded row
        self.data_out_stream = create_OutputStream()

    @staticmethod
    def generate_null_mask_search_table():
        """
        Build a lookup table mapping every byte value (0..255) to a tuple of 8 booleans.

        Each bit of a byte says whether the column at the corresponding position is None,
        e.g. 0x84 means the first and the sixth columns are None.
        """
        return tuple(
            tuple((byte_value & (0x80 >> shift)) > 0 for shift in range(8))
            for byte_value in range(256))

    def encode_to_stream(self, iter_value, out_stream, nested):
        """Encode every row of ``iter_value`` as a length-prefixed record."""
        coders = self._field_coders
        buffer = self.data_out_stream
        for row in iter_value:
            self._write_null_mask(row, buffer)
            for idx, coder in enumerate(coders):
                field = row[idx]
                if field is not None:
                    coder.encode_to_stream(field, buffer, nested)
            # length prefix, then the buffered row bytes; the buffer is reused for the next row
            out_stream.write_var_int64(buffer.size())
            out_stream.write(buffer.get())
            buffer._clear()

    def decode_from_stream(self, in_stream, nested):
        """Lazily yield each row (as a list of field values) until ``in_stream`` is empty."""
        while in_stream.size() > 0:
            # the length prefix is not needed: fields are decoded by their own coders
            in_stream.read_var_int64()
            yield self._decode_one_row_from_stream(in_stream, nested)

    def _decode_one_row_from_stream(self, in_stream: create_InputStream, nested: bool) -> List:
        """Read one null mask and the non-null fields it announces; nulls become None."""
        null_flags = self._read_null_mask(in_stream)
        coders = self._field_coders
        return [None if is_null else coders[idx].decode_from_stream(in_stream, nested)
                for idx, is_null in enumerate(null_flags)]

    def _write_null_mask(self, value, out_stream):
        """Write ceil(field_count / 8) mask bytes describing which fields of value are None."""
        bit_table = self.null_byte_search_table
        # 8 fields per full byte, then the leftover fields in a final partial byte
        chunk_sizes = [8] * self._leading_complete_bytes_num
        if self._remaining_bits_num:
            chunk_sizes.append(self._remaining_bits_num)
        base = 0
        for size in chunk_sizes:
            mask = 0x00
            for bit in range(size):
                if value[base + bit] is None:
                    mask |= bit_table[bit]
            out_stream.write_byte(mask)
            base += 8

    def _read_null_mask(self, in_stream):
        """Read the mask bytes and return one boolean per field (True = field is None)."""
        table = self.null_mask_search_table
        flags = []
        for _ in range(self._leading_complete_bytes_num):
            flags.extend(table[in_stream.read_byte()])
        tail_bits = self._remaining_bits_num
        if tail_bits:
            # only the first ``tail_bits`` bits of the last byte belong to real fields
            flags.extend(table[in_stream.read_byte()][0:tail_bits])
        return flags

    def __repr__(self):
        return 'FlattenRowCoderImpl[%s]' % ', '.join(str(c) for c in self._field_coders)
class RowCoderImpl(FlattenRowCoderImpl):
    """
    Coder for a single Row value.

    Unlike the flattened stream coder, the null mask and field encodings are written
    directly into ``out_stream`` with no length prefix, and decoding returns a ``Row``.
    """

    def __init__(self, field_coders):
        super().__init__(field_coders)

    def encode_to_stream(self, value, out_stream, nested):
        """Write the null mask of ``value`` followed by each non-null field."""
        self._write_null_mask(value, out_stream)
        for idx, coder in enumerate(self._field_coders):
            field = value[idx]
            if field is not None:
                coder.encode_to_stream(field, out_stream, nested)

    def decode_from_stream(self, in_stream, nested):
        """Decode one row and wrap its field values in a ``Row``."""
        fields = self._decode_one_row_from_stream(in_stream, nested)
        return Row(*fields)

    def __repr__(self):
        return 'RowCoderImpl[%s]' % ', '.join(str(c) for c in self._field_coders)
class TableFunctionRowCoderImpl(StreamCoderImpl):
    """
    Coder for table-function output.

    For every element of the input iterable, its result rows (if any) are encoded with the
    wrapped flatten row coder, followed by an end marker: a var-int64 ``1`` and a ``0x00``
    byte.
    """

    def __init__(self, flatten_row_coder):
        self._flatten_row_coder = flatten_row_coder
        self._field_count = flatten_row_coder._field_count

    def encode_to_stream(self, iter_value, out_stream, nested):
        row_coder = self._flatten_row_coder
        wrap_single_field = self._field_count == 1
        for results in iter_value:
            if results:
                # with exactly one output field, wrap every result into a 1-tuple row so the
                # row coder can index it
                rows = self._create_tuple_result(results) if wrap_single_field else results
                row_coder.encode_to_stream(rows, out_stream, nested)
            # end marker terminating the results of this input element
            out_stream.write_var_int64(1)
            out_stream.write_byte(0x00)

    def decode_from_stream(self, in_stream, nested):
        """Delegate decoding to the wrapped flatten row coder."""
        return self._flatten_row_coder.decode_from_stream(in_stream, nested)

    @staticmethod
    def _create_tuple_result(value: List) -> Generator:
        """Lazily yield each element of ``value`` wrapped in a 1-tuple."""
        for item in value:
            yield (item,)

    def __repr__(self):
        return 'TableFunctionRowCoderImpl[%s]' % repr(self._flatten_row_coder)
class ArrayCoderImpl(StreamCoderImpl):
    """
    Coder for arrays: a big-endian int32 element count, then for every element a presence
    byte (True = present, False = None) followed by the element encoding when present.
    """

    def __init__(self, elem_coder):
        self._elem_coder = elem_coder

    def encode_to_stream(self, value, out_stream, nested):
        out_stream.write_bigendian_int32(len(value))
        elem_coder = self._elem_coder
        for element in value:
            present = element is not None
            out_stream.write_byte(present)
            if present:
                elem_coder.encode_to_stream(element, out_stream, nested)

    def decode_from_stream(self, in_stream, nested):
        count = in_stream.read_bigendian_int32()
        decoded = []
        for _ in range(count):
            if in_stream.read_byte():
                decoded.append(self._elem_coder.decode_from_stream(in_stream, nested))
            else:
                decoded.append(None)
        return decoded

    def __repr__(self):
        return 'ArrayCoderImpl[%s]' % repr(self._elem_coder)
class PickledBytesCoderImpl(StreamCoderImpl):
    """Coder that pickles arbitrary Python objects and transfers them as length-prefixed bytes."""

    def __init__(self):
        self.field_coder = BinaryCoderImpl()

    def encode_to_stream(self, value, out_stream, nested):
        self.field_coder.encode_to_stream(pickle.dumps(value), out_stream, nested)

    def decode_from_stream(self, in_stream, nested):
        return self._decode_one_value_from_stream(in_stream, nested)

    def _decode_one_value_from_stream(self, in_stream: create_InputStream, nested):
        """Read one length-prefixed byte string and unpickle it."""
        payload = self.field_coder.decode_from_stream(in_stream, nested)
        # NOTE(review): pickle.loads can execute arbitrary code; this assumes the bytes only
        # ever come from a trusted peer — confirm.
        return pickle.loads(payload)

    def __repr__(self) -> str:
        return 'PickledBytesCoderImpl[%s]' % str(self.field_coder)
class DataStreamStatelessMapCoderImpl(StreamCoderImpl):
    """
    Coder for a stream of values, each encoded by ``field_coder`` into a scratch buffer and
    emitted with a var-int64 length prefix.
    """

    def __init__(self, field_coder):
        self._field_coder = field_coder
        # reusable scratch buffer holding one encoded value
        self.data_out_stream = create_OutputStream()

    def encode_to_stream(self, iter_value, stream,
                         nested):  # type: (Any, create_OutputStream, bool) -> None
        coder = self._field_coder
        buffer = self.data_out_stream
        for item in iter_value:
            coder.encode_to_stream(item, buffer, nested)
            stream.write_var_int64(buffer.size())
            stream.write(buffer.get())
            buffer._clear()

    def decode_from_stream(self, stream, nested):  # type: (create_InputStream, bool) -> Any
        """Lazily yield decoded values until ``stream`` is empty; length prefixes are skipped."""
        coder = self._field_coder
        while stream.size() > 0:
            stream.read_var_int64()
            yield coder.decode_from_stream(stream, nested)

    def __repr__(self):
        return 'DataStreamStatelessMapCoderImpl[%s]' % repr(self._field_coder)
class DataStreamStatelessFlatMapCoderImpl(StreamCoderImpl):
    """
    Coder that writes every value of an iterable back to back with ``field_coder`` (no
    length prefix), and decodes a single value with the same coder.
    """

    def __init__(self, field_coder):
        self._field_coder = field_coder

    def encode_to_stream(self, iter_value, stream,
                         nested):  # type: (Any, create_OutputStream, bool) -> None
        for value in iter_value:
            self._field_coder.encode_to_stream(value, stream, nested)

    def decode_from_stream(self, stream, nested):
        return self._field_coder.decode_from_stream(stream, nested)

    # Defined as __repr__ (like every other coder in this module) rather than __str__, so
    # repr() is informative too; str() still returns the same text because object.__str__
    # falls back to __repr__.
    def __repr__(self) -> str:
        return 'DataStreamStatelessFlatMapCoderImpl[%s]' % repr(self._field_coder)
class MapCoderImpl(StreamCoderImpl):
    """
    Coder for maps: a big-endian int32 entry count, then per entry the encoded key, a null
    flag byte (True = value is None) and, for non-null values, the encoded value.
    """

    def __init__(self, key_coder, value_coder):
        self._key_coder = key_coder
        self._value_coder = value_coder

    def encode_to_stream(self, map_value, out_stream, nested):
        key_coder = self._key_coder
        value_coder = self._value_coder
        out_stream.write_bigendian_int32(len(map_value))
        # iterate key/value pairs once instead of looking every key up again
        for key, value in map_value.items():
            key_coder.encode_to_stream(key, out_stream, nested)
            if value is None:
                out_stream.write_byte(True)
            else:
                out_stream.write_byte(False)
                value_coder.encode_to_stream(value, out_stream, nested)

    def decode_from_stream(self, in_stream, nested):
        size = in_stream.read_bigendian_int32()
        map_value = {}
        for _ in range(size):
            key = self._key_coder.decode_from_stream(in_stream, nested)
            is_null = in_stream.read_byte()
            map_value[key] = (None if is_null
                              else self._value_coder.decode_from_stream(in_stream, nested))
        return map_value

    def __repr__(self):
        return 'MapCoderImpl[%s]' % ' : '.join([repr(self._key_coder), repr(self._value_coder)])
class BigIntCoderImpl(StreamCoderImpl):
    """Coder for BIGINT values as 64-bit big-endian integers."""

    def encode_to_stream(self, value, out_stream, nested):
        write = out_stream.write_bigendian_int64
        write(value)

    def decode_from_stream(self, in_stream, nested):
        decoded = in_stream.read_bigendian_int64()
        return decoded
class TinyIntCoderImpl(StreamCoderImpl):
    """Coder for TINYINT values as a single signed byte (struct format 'b')."""

    def encode_to_stream(self, value, out_stream, nested):
        packed = struct.pack('b', value)
        out_stream.write(packed)

    def decode_from_stream(self, in_stream, nested):
        (decoded,) = struct.unpack('b', in_stream.read(1))
        return decoded
class SmallIntCoderImpl(StreamCoderImpl):
    """Coder for SMALLINT values as signed 16-bit big-endian integers (struct format '>h')."""

    def encode_to_stream(self, value, out_stream, nested):
        packed = struct.pack('>h', value)
        out_stream.write(packed)

    def decode_from_stream(self, in_stream, nested):
        (decoded,) = struct.unpack('>h', in_stream.read(2))
        return decoded
class IntCoderImpl(StreamCoderImpl):
    """Coder for INT values as 32-bit big-endian integers."""

    def encode_to_stream(self, value, out_stream, nested):
        write = out_stream.write_bigendian_int32
        write(value)

    def decode_from_stream(self, in_stream, nested):
        decoded = in_stream.read_bigendian_int32()
        return decoded
class BooleanCoderImpl(StreamCoderImpl):
    """Coder for BOOLEAN values as a single byte; any non-zero byte decodes to True."""

    def encode_to_stream(self, value, out_stream, nested):
        out_stream.write_byte(value)

    def decode_from_stream(self, in_stream, nested):
        raw = in_stream.read_byte()
        return bool(raw)
class FloatCoderImpl(StreamCoderImpl):
    """Coder for FLOAT values as 4-byte big-endian single precision (struct format '>f')."""

    def encode_to_stream(self, value, out_stream, nested):
        packed = struct.pack('>f', value)
        out_stream.write(packed)

    def decode_from_stream(self, in_stream, nested):
        (decoded,) = struct.unpack('>f', in_stream.read(4))
        return decoded
class DoubleCoderImpl(StreamCoderImpl):
    """Coder for DOUBLE values as big-endian double precision."""

    def encode_to_stream(self, value, out_stream, nested):
        write = out_stream.write_bigendian_double
        write(value)

    def decode_from_stream(self, in_stream, nested):
        decoded = in_stream.read_bigendian_double()
        return decoded
class DecimalCoderImpl(StreamCoderImpl):
    """
    Coder for DECIMAL(precision, scale) values.

    Values are quantized to ``scale`` fractional digits under a decimal context with the given
    precision and transferred as their UTF-8 string form, prefixed by the byte length as a
    big-endian int32.
    """

    def __init__(self, precision, scale):
        self.context = decimal.Context(prec=precision)
        # quantization exponent, e.g. scale=2 -> Decimal('0.01')
        self.scale_format = decimal.Decimal(10) ** -scale

    def encode_to_stream(self, value, out_stream, nested):
        user_context = decimal.getcontext()
        decimal.setcontext(self.context)
        try:
            # quantize signals InvalidOperation when the result needs more digits than the
            # precision allows; str() is also evaluated under this context
            value = value.quantize(self.scale_format)
            bytes_value = str(value).encode("utf-8")
            out_stream.write_bigendian_int32(len(bytes_value))
            out_stream.write(bytes_value, False)
        finally:
            # always restore the caller's context, even when quantizing or writing fails
            decimal.setcontext(user_context)

    def decode_from_stream(self, in_stream, nested):
        user_context = decimal.getcontext()
        decimal.setcontext(self.context)
        try:
            size = in_stream.read_bigendian_int32()
            return decimal.Decimal(
                in_stream.read(size).decode("utf-8")).quantize(self.scale_format)
        finally:
            decimal.setcontext(user_context)
class BigDecimalCoderImpl(StreamCoderImpl):
    """
    Coder for unbounded decimals: the value's ``str`` form as UTF-8, prefixed by its byte
    length as a big-endian int32. No quantization is applied.
    """

    def encode_to_stream(self, value, stream, nested):
        encoded = str(value).encode("utf-8")
        stream.write_bigendian_int32(len(encoded))
        stream.write(encoded, False)

    def decode_from_stream(self, stream, nested):
        length = stream.read_bigendian_int32()
        text = stream.read(length).decode("utf-8")
        return decimal.Decimal(text)
class TupleCoderImpl(StreamCoderImpl):
    """Coder for fixed-arity tuples: each position is encoded by its own coder, in order."""

    def __init__(self, field_coders):
        self._field_coders = field_coders
        self._field_count = len(field_coders)

    def encode_to_stream(self, value, out_stream, nested):
        for idx, coder in enumerate(self._field_coders):
            coder.encode_to_stream(value[idx], out_stream, nested)

    def decode_from_stream(self, stream, nested):
        return tuple(coder.decode_from_stream(stream, nested) for coder in self._field_coders)

    def __repr__(self) -> str:
        return 'TupleCoderImpl[%s]' % ', '.join(str(c) for c in self._field_coders)
class BinaryCoderImpl(StreamCoderImpl):
    """Coder for byte strings, prefixed by their length as a big-endian int32."""

    def encode_to_stream(self, value, out_stream, nested):
        length = len(value)
        out_stream.write_bigendian_int32(length)
        out_stream.write(value, False)

    def decode_from_stream(self, in_stream, nested):
        length = in_stream.read_bigendian_int32()
        payload = in_stream.read(length)
        return payload
class CharCoderImpl(StreamCoderImpl):
    """Coder for strings as UTF-8 bytes, prefixed by the byte length as a big-endian int32."""

    def encode_to_stream(self, value, out_stream, nested):
        encoded = value.encode("utf-8")
        out_stream.write_bigendian_int32(len(encoded))
        out_stream.write(encoded, False)

    def decode_from_stream(self, in_stream, nested):
        length = in_stream.read_bigendian_int32()
        raw = in_stream.read(length)
        return raw.decode("utf-8")
class DateCoderImpl(StreamCoderImpl):
    """Coder for DATE values as the number of days since 1970-01-01 (big-endian int32)."""

    # proleptic Gregorian ordinal of 1970-01-01
    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()

    def encode_to_stream(self, value, out_stream, nested):
        days = self.date_to_internal(value)
        out_stream.write_bigendian_int32(days)

    def decode_from_stream(self, in_stream, nested):
        days = in_stream.read_bigendian_int32()
        return self.internal_to_date(days)

    def date_to_internal(self, d):
        """Return the signed day offset of ``d`` from the epoch."""
        ordinal = d.toordinal()
        return ordinal - self.EPOCH_ORDINAL

    def internal_to_date(self, v):
        """Return the ``datetime.date`` that lies ``v`` days after the epoch."""
        ordinal = v + self.EPOCH_ORDINAL
        return datetime.date.fromordinal(ordinal)
class TimeCoderImpl(StreamCoderImpl):
    """
    Coder for TIME values as milliseconds since midnight (big-endian int32); sub-millisecond
    digits are dropped on encode.
    """

    def encode_to_stream(self, value, out_stream, nested):
        millis = self.time_to_internal(value)
        out_stream.write_bigendian_int32(millis)

    def decode_from_stream(self, in_stream, nested):
        millis = in_stream.read_bigendian_int32()
        return self.internal_to_time(millis)

    def time_to_internal(self, t):
        """Return ``t`` as whole milliseconds since midnight."""
        total_seconds = (t.hour * 60 + t.minute) * 60 + t.second
        return total_seconds * 1000 + t.microsecond // 1000

    def internal_to_time(self, v):
        """Convert milliseconds since midnight back into a ``datetime.time``."""
        total_seconds, millis = divmod(v, 1000)
        total_minutes, secs = divmod(total_seconds, 60)
        hrs, mins = divmod(total_minutes, 60)
        return datetime.time(hrs, mins, secs, millis * 1000)
class TimestampCoderImpl(StreamCoderImpl):
    """
    Coder for TIMESTAMP(precision) values represented as naive ``datetime.datetime``.

    The wall-clock fields are interpreted as UTC (any tzinfo is replaced, not converted). The
    value is written as milliseconds since the epoch (big-endian int64). When precision > 3,
    the nanoseconds within that millisecond follow as a big-endian int32.
    """

    def __init__(self, precision):
        self.precision = precision

    def is_compact(self):
        """True when precision <= 3, i.e. only the millisecond field is transferred."""
        return self.precision <= 3

    def encode_to_stream(self, value, out_stream, nested):
        milliseconds, nanoseconds = self.timestamp_to_internal(value)
        if self.is_compact():
            assert nanoseconds == 0
            out_stream.write_bigendian_int64(milliseconds)
        else:
            out_stream.write_bigendian_int64(milliseconds)
            out_stream.write_bigendian_int32(nanoseconds)

    def decode_from_stream(self, in_stream, nested):
        if self.is_compact():
            milliseconds = in_stream.read_bigendian_int64()
            nanoseconds = 0
        else:
            milliseconds = in_stream.read_bigendian_int64()
            nanoseconds = in_stream.read_bigendian_int32()
        return self.internal_to_timestamp(milliseconds, nanoseconds)

    def timestamp_to_internal(self, timestamp):
        """
        Split ``timestamp`` into (milliseconds since epoch, nanoseconds of the millisecond).

        Milliseconds are floored, so the nanosecond part is always in [0, 999000].
        """
        # Zero the microseconds before timestamp() so the float is an exact whole number of
        # seconds. int() truncates toward zero, so keeping the fraction made pre-1970 values
        # with a sub-second part land one second too late.
        seconds = int(timestamp.replace(tzinfo=datetime.timezone.utc, microsecond=0).timestamp())
        microseconds_of_second = timestamp.microsecond
        milliseconds = seconds * 1000 + microseconds_of_second // 1000
        nanoseconds = microseconds_of_second % 1000 * 1000
        return milliseconds, nanoseconds

    def internal_to_timestamp(self, milliseconds, nanoseconds):
        """Rebuild a naive (UTC wall-clock) datetime from the internal representation."""
        second, microsecond = (milliseconds // 1000,
                               milliseconds % 1000 * 1000 + nanoseconds // 1000)
        # epoch + timedelta instead of datetime.utcfromtimestamp: it works for negative
        # (pre-1970) values on every platform, and utcfromtimestamp is deprecated since 3.12
        return datetime.datetime(1970, 1, 1) + datetime.timedelta(
            seconds=second, microseconds=microsecond)
class LocalZonedTimestampCoderImpl(TimestampCoderImpl):
    """
    Timestamp coder whose decoded values are made timezone-aware by calling
    ``self.timezone.localize`` on the naive datetime produced by the base class.
    """

    def __init__(self, precision, timezone):
        super().__init__(precision)
        # NOTE(review): ``localize`` suggests a pytz timezone object — confirm with callers
        self.timezone = timezone

    def internal_to_timestamp(self, milliseconds, nanoseconds):
        naive = super().internal_to_timestamp(milliseconds, nanoseconds)
        return self.timezone.localize(naive)
class ArrowCoderImpl(StreamCoderImpl):
    """
    Coder that exchanges pandas data as Arrow record batches in the Arrow IPC stream format.

    Encoding converts each element with ``pandas_to_arrow`` and writes it as one record batch,
    emitted with a var-int64 length prefix. Decoding converts batches back with
    ``arrow_to_pandas``.
    """

    def __init__(self, schema, row_type, timezone):
        self._schema = schema
        self._field_types = row_type.field_types()
        self._timezone = timezone
        # One ResettableIO object is shared by the Arrow reader and writer below.
        # NOTE(review): presumably it lets the same reader/writer be reused across elements
        # by swapping its input bytes / output stream — confirm in ResettableIO.
        self._resettable_io = ResettableIO()
        # _load_from_stream is a generator function, so nothing is read from the IO object
        # until the first next() call in _decode_one_batch_from_stream.
        self._batch_reader = ArrowCoderImpl._load_from_stream(self._resettable_io)
        self._batch_writer = pa.RecordBatchStreamWriter(self._resettable_io, self._schema)
        self.data_out_stream = create_OutputStream()
        # NOTE(review): the output stream is attached only after the writer is constructed;
        # keep this order unless ResettableIO's handling of earlier writes is confirmed.
        self._resettable_io.set_output_stream(self.data_out_stream)

    def encode_to_stream(self, iter_cols, out_stream, nested):
        """Write each element of ``iter_cols`` as one length-prefixed Arrow record batch."""
        data_out_stream = self.data_out_stream
        for cols in iter_cols:
            # the writer targets self._resettable_io; its output is presumably forwarded to
            # data_out_stream (set via set_output_stream in __init__)
            self._batch_writer.write_batch(
                pandas_to_arrow(self._schema, self._timezone, self._field_types, cols))
            out_stream.write_var_int64(data_out_stream.size())
            out_stream.write(data_out_stream.get())
            # reset the scratch buffer for the next batch
            data_out_stream._clear()

    def decode_from_stream(self, in_stream, nested):
        """Lazily yield one decoded batch per chunk until ``in_stream`` is exhausted."""
        while in_stream.size() > 0:
            yield self._decode_one_batch_from_stream(in_stream)

    @staticmethod
    def _load_from_stream(stream):
        """Open an Arrow IPC stream reader over ``stream`` and yield its record batches."""
        reader = pa.ipc.open_stream(stream)
        for batch in reader:
            yield batch

    def _decode_one_batch_from_stream(self, in_stream: create_InputStream) -> List:
        """Load the next chunk of ``in_stream`` into the shared IO and decode one batch."""
        # read_all(True): Beam reads a var-int64 length and then that many bytes
        self._resettable_io.set_input_bytes(in_stream.read_all(True))
        # there is only one arrow batch in the underlying input stream
        return arrow_to_pandas(self._timezone, self._field_types, [next(self._batch_reader)])

    def __repr__(self):
        return 'ArrowCoderImpl[%s]' % self._schema
class PassThroughLengthPrefixCoderImpl(StreamCoderImpl):
    """
    Wrapper that delegates encoding and decoding straight to the wrapped value coder and
    writes no length prefix of its own.
    """

    def __init__(self, value_coder):
        self._value_coder = value_coder

    def encode_to_stream(self, value, out: create_OutputStream, nested: bool) -> Any:
        inner = self._value_coder
        inner.encode_to_stream(value, out, nested)

    def decode_from_stream(self, in_stream: create_InputStream, nested: bool) -> Any:
        inner = self._value_coder
        return inner.decode_from_stream(in_stream, nested)

    def get_estimated_size_and_observables(self, value: Any, nested=False):
        """Always report an estimated size of 0 and no observables."""
        estimated_size, observables = 0, []
        return estimated_size, observables

    def __repr__(self):
        return 'PassThroughLengthPrefixCoderImpl[%s]' % self._value_coder