为了账号安全,请及时绑定邮箱和手机立即绑定

如何使用 neuraxle 实现延迟数据加载的存储库?

如何使用 neuraxle 实现延迟数据加载的存储库?

拉丁的传说 2023-12-26 15:57:37
在neuraxle 文档中显示了一个示例,使用存储库在管道中延迟加载数据,请参阅以下代码:from neuraxle.pipeline import Pipeline, MiniBatchSequentialPipelinefrom neuraxle.base import ExecutionContextfrom neuraxle.steps.column_transformer import ColumnTransformerfrom neuraxle.steps.flow import TrainOnlyWrappertraining_data_ids = training_data_repository.get_all_ids()context = ExecutionContext('caching_folder').set_service_locator({    BaseRepository: training_data_repository})pipeline = Pipeline([    ConvertIDsToLoadedData().assert_has_services(BaseRepository),    ColumnTransformer([        (range(0, 2), DateToCosineEncoder()),        (3, CategoricalEnum(categeories_count=5, starts_at_zero=True)),    ]),    Normalizer(),    TrainOnlyWrapper(DataShuffler()),    MiniBatchSequentialPipeline([        Model()    ], batch_size=128)]).with_context(context)但是,它没有显示如何实现BaseRepository和ConvertIDsToLoadedData类。实施这些课程的最佳方式是什么?谁能举个例子吗?
查看完整描述

1 回答

?
慕勒3428872

TA贡献1848条经验 获得超6个赞

我没有检查以下是否编译,但它应该如下所示。如果您发现需要更改的内容并尝试编译它,请有人编辑此答案:


class BaseDataRepository(ABC): 


    @abstractmethod

    def get_all_ids(self) -> List[int]: 

        pass


    @abstractmethod

    def get_data_from_id(self, _id: int) -> object: 

        pass


class InMemoryDataRepository(BaseDataRepository): 

    def __init__(self, ids, data): 

        self.ids: List[int] = ids

        self.data: Dict[int, object] = data


    def get_all_ids(self) -> List[int]: 

        return list(self.ids)


    def get_data_from_id(self, _id: int) -> object: 

        return self.data[_id]


class ConvertIDsToLoadedData(BaseStep): 

    def _transform_data_container(self, data_container: DataContainer, context: ExecutionContext): 

        repo: BaseDataRepository = context.get_service(BaseDataRepository)

        ids = data_container.data_inputs


        # Replace data ids by their loaded object counterpart: 

        data_container.data_inputs = [repo.get_data_from_id(_id) for _id in ids]


        return data_container, context


context = ExecutionContext('caching_folder').set_service_locator({

    BaseDataRepository: InMemoryDataRepository(ids, data)  # or insert here any other replacement class that inherits from `BaseDataRepository` when you'll change the database to a real one (e.g.: SQL) rather than a cheap "InMemory" stub. 

})

查看完整回答
反对 回复 2023-12-26
  • 1 回答
  • 0 关注
  • 50 浏览
慕课专栏
更多

添加回答

举报

0/150
提交
取消
意见反馈 帮助中心 APP下载
官方微信