import tensorflow

from dataiku.base.utils import package_is_at_least
from dataiku.doctor.deep_learning.tfcompat import \
    _build_and_fit_model,\
    _compute_perf_and_save_best_model_callback as callback,\
    _tf_imports_compat,\
    _list_physical_devices

# Re-export selected attributes of ``_tf_imports_compat`` as module-level
# names of this module. Each name is bound to the attribute of the same name,
# in the order listed below.
# NOTE: ``eval`` intentionally shadows the builtin inside this module's
# namespace, exactly as a plain ``eval = _tf_imports_compat.eval`` would.
for _reexported_name in (
    "set_session",
    "keras_load_model",
    "pad_sequences",
    "text",
    "eval",
    "get_loss",
    "Session",
    "ConfigProto",
    "Sequence",
    "Callback",
    "optimizers",
):
    globals()[_reexported_name] = getattr(_tf_imports_compat, _reexported_name)
# Do not leave the loop variable behind in the module namespace.
del _reexported_name

# Bind each public entry point to its ``_tf2`` implementation when
# ``package_is_at_least(tensorflow, "2.2")`` holds, and to its ``_tf1``
# implementation otherwise. The check is evaluated once; only the attribute
# of the selected variant is ever looked up.
_use_tf2_variants = package_is_at_least(tensorflow, "2.2")

build_and_fit_model = (
    _build_and_fit_model._build_and_fit_model_tf2
    if _use_tf2_variants
    else _build_and_fit_model._build_and_fit_model_tf1
)
compute_perf_and_save_best_model_callback = (
    callback._compute_perf_and_save_best_model_callback_tf2
    if _use_tf2_variants
    else callback._compute_perf_and_save_best_model_callback_tf1
)
list_physical_devices = (
    _list_physical_devices._list_physical_devices_tf2
    if _use_tf2_variants
    else _list_physical_devices._list_physical_devices_tf1
)

# The flag is only needed while the bindings above are made.
del _use_tf2_variants

