Skip to content

process_raw_pred方法计算逻辑 #47

@zhilinfan

Description

@zhilinfan

首先非常感谢贵校提供的代码,我在看这段的时候有点疑问,不知道是不是代码逻辑问题呢🙋

def process_raw_pred(raw_question_matrix, raw_pred, num_questions: int) -> tuple:
questions = torch.nonzero(raw_question_matrix)[1:, 1] % num_questions ##torch.nonzero(raw_question_matrix)[1:, 1]表示raw_question_matrix中非0的位置,即用户的有效回答
length = questions.shape[0]
pred = raw_pred[: length]
pred = pred.gather(1, questions.view(-1, 1)).flatten()
truth = torch.nonzero(raw_question_matrix)[1:, 1] // num_questions #truth表示真实作答情况,0表示回答正确,1表示回答错误,与原始数据相反?
# truth = 1 - truth#这里逻辑是不是写错了!! 0表示回答正确,1表示回答错误,与原始数据相反?

return pred, truth

关键在这里➡️
truth = torch.nonzero(raw_question_matrix)[1:, 1] // num_questions
比如one -hot encode后 125错了 numofq=100 那就是encode_q[125]=1 ,但是还原时候,125//100 =1 ,这代表的是回答正确呀?

希望您的解答,谢谢

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions