My Project
src.model.TextToParams Class Reference
[Inheritance diagram for src.model.TextToParams (image not reproduced)]

Public Member Functions

 __init__ (self, num_plugin_parameters=9)
 forward (self, input_ids, attention_mask)

Public Attributes

 bert = AutoModel.from_pretrained("prajjwal1/bert-tiny")
 head

Constructor & Destructor Documentation

◆ __init__()

src.model.TextToParams.__init__(self, num_plugin_parameters=9)

num_plugin_parameters: number of values produced by the final nn.Linear layer of head (default 9).

Member Function Documentation

◆ forward()

src.model.TextToParams.forward(self, input_ids, attention_mask)
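This page does not describe what forward computes. The sketch below is a minimal illustration under stated assumptions, not the project's source. It assumes forward runs the encoder, takes the hidden state of the first ([CLS]) token and passes it through head. The 128 input features of head match the hidden size of prajjwal1/bert-tiny. To keep the example runnable offline, the hypothetical StubEncoder (a single embedding layer returning a last_hidden_state tensor) stands in for the pretrained encoder, and TextToParamsSketch is likewise a hypothetical name. Passing AutoModel.from_pretrained("prajjwal1/bert-tiny") as the encoder instead would require downloading the checkpoint.

```python
import torch
import torch.nn as nn
from types import SimpleNamespace


class StubEncoder(nn.Module):
    """Hypothetical stand-in for the bert-tiny encoder (hidden size 128).

    Like a Hugging Face encoder, it returns an object with a
    `last_hidden_state` tensor of shape (batch, seq_len, hidden).
    """

    def __init__(self, vocab_size=30522, hidden=128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask):
        # Zero out padded positions so they carry no signal.
        hidden = self.embed(input_ids) * attention_mask.unsqueeze(-1)
        return SimpleNamespace(last_hidden_state=hidden)


class TextToParamsSketch(nn.Module):
    """Hypothetical re-creation of TextToParams with an injectable encoder."""

    def __init__(self, encoder, num_plugin_parameters=9):
        super().__init__()
        self.bert = encoder
        # Same layers as documented for TextToParams.head.
        self.head = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_plugin_parameters),
            nn.Sigmoid(),  # squashes each output into [0, 1]
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Assumption: the [CLS] token (position 0) summarizes the whole text.
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        return self.head(cls_embedding)


model = TextToParamsSketch(StubEncoder()).eval()  # eval() disables dropout
input_ids = torch.randint(0, 30522, (2, 8))       # batch of 2, 8 tokens each
attention_mask = torch.ones(2, 8, dtype=torch.long)
params = model(input_ids, attention_mask)
print(params.shape)  # torch.Size([2, 9])
```

Whether the real forward uses the [CLS] hidden state, the encoder's pooler_output, or some other pooling cannot be determined from this page; only the head's input width (128) and its sigmoid output are documented.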

Member Data Documentation

◆ bert

src.model.TextToParams.bert = AutoModel.from_pretrained("prajjwal1/bert-tiny")

◆ head

src.model.TextToParams.head
Initial value:
= nn.Sequential(
    nn.Linear(128, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, num_plugin_parameters),
    nn.Sigmoid()  # Output range: 0.0 to 1.0
)
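As a worked example of the head's size: only the two nn.Linear layers carry trainable parameters (a weight matrix plus a bias vector each), while ReLU, Dropout and Sigmoid have none. With the default num_plugin_parameters = 9 the head therefore holds (128 × 256 + 256) + (256 × 9 + 9) = 33024 + 2313 = 35337 parameters. The helper names below are hypothetical and only reproduce this arithmetic.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of an nn.Linear layer created with bias=True."""
    return in_features * out_features + out_features


def head_param_count(num_plugin_parameters=9, hidden=128, inner=256):
    """Trainable parameters in the head; activations and dropout add none."""
    return linear_params(hidden, inner) + linear_params(inner, num_plugin_parameters)


print(head_param_count())   # 35337
print(head_param_count(4))  # 34052
```

Only the second term depends on num_plugin_parameters, so each extra plugin parameter adds 257 weights (256 inputs plus one bias) to the head.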

The documentation for this class was generated from the following file: