diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..ee5082a --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,25 @@ +# PR Title + +## Summary +- What: Briefly describe the change +- Why: Problem it solves / value +- How: Key changes (bullets) + +## Testing +- [ ] Unit tests added/updated +- [ ] Manual/Integration verified +- [ ] CI passes locally (tests + lint) + +## Risk/Impact +- Breaking changes: yes/no (explain) +- Migration notes (if any) + +## Checklist +- [ ] Small, logical commits +- [ ] Docs/README updated (if user-facing) +- [ ] Backwards compatible (or documented) +- [ ] Security/Privacy reviewed (if applicable) +- [ ] Linked issue(s) / references + +## Notes +- Screenshots/logs (if relevant) diff --git a/zerfoo.pb.go b/zerfoo.pb.go index 7206e1e..7dca33e 100644 --- a/zerfoo.pb.go +++ b/zerfoo.pb.go @@ -25,34 +25,55 @@ const ( type Tensor_DataType int32 const ( - Tensor_FLOAT32 Tensor_DataType = 0 - Tensor_FLOAT16 Tensor_DataType = 1 - Tensor_BFLOAT16 Tensor_DataType = 2 - Tensor_FLOAT8 Tensor_DataType = 3 - Tensor_INT32 Tensor_DataType = 4 - Tensor_INT64 Tensor_DataType = 5 - Tensor_FLOAT64 Tensor_DataType = 6 + Tensor_BFLOAT16 Tensor_DataType = 0 + Tensor_BOOL Tensor_DataType = 1 + Tensor_FLOAT16 Tensor_DataType = 2 + Tensor_FLOAT32 Tensor_DataType = 3 + Tensor_FLOAT64 Tensor_DataType = 4 + Tensor_FLOAT8 Tensor_DataType = 5 + Tensor_INT16 Tensor_DataType = 6 + Tensor_INT32 Tensor_DataType = 7 + Tensor_INT64 Tensor_DataType = 8 + Tensor_INT8 Tensor_DataType = 9 + Tensor_STRING Tensor_DataType = 10 + Tensor_UINT32 Tensor_DataType = 11 + Tensor_UINT64 Tensor_DataType = 12 + Tensor_UINT8 Tensor_DataType = 13 ) // Enum value maps for Tensor_DataType. var ( Tensor_DataType_name = map[int32]string{ - 0: "FLOAT32", - 1: "FLOAT16", - 2: "BFLOAT16", - 3: "FLOAT8", - 4: "INT32", - 5: "INT64", - 6: "FLOAT64", + 0: "BFLOAT16", + 1: "BOOL", + 2: "FLOAT16", + 3: "FLOAT32", + 4: "FLOAT64", + 5: "FLOAT8", + 6: "INT16", + 7: "INT32", + 8: "INT64", + 9: "INT8", + 10: "STRING", + 11: "UINT32", + 12: "UINT64", + 13: "UINT8", } Tensor_DataType_value = map[string]int32{ - "FLOAT32": 0, - "FLOAT16": 1, - "BFLOAT16": 2, - "FLOAT8": 3, - "INT32": 4, - "INT64": 5, - "FLOAT64": 6, + "BFLOAT16": 0, + "BOOL": 1, + "FLOAT16": 2, + "FLOAT32": 3, + "FLOAT64": 4, + "FLOAT8": 5, + "INT16": 6, + "INT32": 7, + "INT64": 8, + "INT8": 9, + "STRING": 10, + "UINT32": 11, + "UINT64": 12, + "UINT8": 13, } ) @@ -80,34 +101,40 @@ func (x Tensor_DataType) Number() protoreflect.EnumNumber { // Deprecated: Use Tensor_DataType.Descriptor instead. func (Tensor_DataType) EnumDescriptor() ([]byte, []int) { - return file_zerfoo_proto_rawDescGZIP(), []int{5, 0} + return file_zerfoo_proto_rawDescGZIP(), []int{9, 0} } -// Model is the top-level container for a serialized Zerfoo model. -type Model struct { +// Attribute represents a named, non-tensor parameter for a node. +type Attribute struct { state protoimpl.MessageState `protogen:"open.v1"` - // The computation graph defining the model's architecture. - Graph *Graph `protobuf:"bytes,1,opt,name=graph,proto3" json:"graph,omitempty"` - // Metadata about the model, such as its producer. - Metadata *Metadata `protobuf:"bytes,2,opt,name=metadata,proto3" json:"metadata,omitempty"` + // Types that are valid to be assigned to Value: + // + // *Attribute_F + // *Attribute_I + // *Attribute_S + // *Attribute_Floats + // *Attribute_Ints + // *Attribute_Strings + // *Attribute_B + Value isAttribute_Value `protobuf_oneof:"value"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Model) Reset() { - *x = Model{} +func (x *Attribute) Reset() { + *x = Attribute{} mi := &file_zerfoo_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Model) String() string { +func (x *Attribute) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Model) ProtoMessage() {} +func (*Attribute) ProtoMessage() {} -func (x *Model) ProtoReflect() protoreflect.Message { +func (x *Attribute) ProtoReflect() protoreflect.Message { mi := &file_zerfoo_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -119,52 +146,149 @@ func (x *Model) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Model.ProtoReflect.Descriptor instead. -func (*Model) Descriptor() ([]byte, []int) { +// Deprecated: Use Attribute.ProtoReflect.Descriptor instead. +func (*Attribute) Descriptor() ([]byte, []int) { return file_zerfoo_proto_rawDescGZIP(), []int{0} } -func (x *Model) GetGraph() *Graph { +func (x *Attribute) GetValue() isAttribute_Value { if x != nil { - return x.Graph + return x.Value } return nil } -func (x *Model) GetMetadata() *Metadata { +func (x *Attribute) GetF() float32 { if x != nil { - return x.Metadata + if x, ok := x.Value.(*Attribute_F); ok { + return x.F + } + } + return 0 +} + +func (x *Attribute) GetI() int64 { + if x != nil { + if x, ok := x.Value.(*Attribute_I); ok { + return x.I + } + } + return 0 +} + +func (x *Attribute) GetS() string { + if x != nil { + if x, ok := x.Value.(*Attribute_S); ok { + return x.S + } + } + return "" +} + +func (x *Attribute) GetFloats() *Floats { + if x != nil { + if x, ok := x.Value.(*Attribute_Floats); ok { + return x.Floats + } } return nil } -// Metadata stores information about the model's origin and versioning. -type Metadata struct { - state protoimpl.MessageState `protogen:"open.v1"` - // The name of the tool or framework that produced this model. - ProducerName string `protobuf:"bytes,1,opt,name=producer_name,json=producerName,proto3" json:"producer_name,omitempty"` - // The version of the producer. - ProducerVersion string `protobuf:"bytes,2,opt,name=producer_version,json=producerVersion,proto3" json:"producer_version,omitempty"` - // The version of the ZMF operator set this model conforms to. - OpsetVersion int64 `protobuf:"varint,3,opt,name=opset_version,json=opsetVersion,proto3" json:"opset_version,omitempty"` +func (x *Attribute) GetInts() *Ints { + if x != nil { + if x, ok := x.Value.(*Attribute_Ints); ok { + return x.Ints + } + } + return nil +} + +func (x *Attribute) GetStrings() *Strings { + if x != nil { + if x, ok := x.Value.(*Attribute_Strings); ok { + return x.Strings + } + } + return nil +} + +func (x *Attribute) GetB() bool { + if x != nil { + if x, ok := x.Value.(*Attribute_B); ok { + return x.B + } + } + return false +} + +type isAttribute_Value interface { + isAttribute_Value() +} + +type Attribute_F struct { + F float32 `protobuf:"fixed32,1,opt,name=f,proto3,oneof"` +} + +type Attribute_I struct { + I int64 `protobuf:"varint,2,opt,name=i,proto3,oneof"` +} + +type Attribute_S struct { + S string `protobuf:"bytes,3,opt,name=s,proto3,oneof"` +} + +type Attribute_Floats struct { + Floats *Floats `protobuf:"bytes,4,opt,name=floats,proto3,oneof"` +} + +type Attribute_Ints struct { + Ints *Ints `protobuf:"bytes,5,opt,name=ints,proto3,oneof"` +} + +type Attribute_Strings struct { + Strings *Strings `protobuf:"bytes,6,opt,name=strings,proto3,oneof"` +} + +type Attribute_B struct { + B bool `protobuf:"varint,7,opt,name=b,proto3,oneof"` // Added boolean support +} + +func (*Attribute_F) isAttribute_Value() {} + +func (*Attribute_I) isAttribute_Value() {} + +func (*Attribute_S) isAttribute_Value() {} + +func (*Attribute_Floats) isAttribute_Value() {} + +func (*Attribute_Ints) isAttribute_Value() {} + +func (*Attribute_Strings) isAttribute_Value() {} + +func (*Attribute_B) isAttribute_Value() {} + +// Floats is a wrapper for repeated float values in attributes. +type Floats struct { + state protoimpl.MessageState `protogen:"open.v1"` + Val []float32 `protobuf:"fixed32,1,rep,packed,name=val,proto3" json:"val,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Metadata) Reset() { - *x = Metadata{} +func (x *Floats) Reset() { + *x = Floats{} mi := &file_zerfoo_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Metadata) String() string { +func (x *Floats) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Metadata) ProtoMessage() {} +func (*Floats) ProtoMessage() {} -func (x *Metadata) ProtoReflect() protoreflect.Message { +func (x *Floats) ProtoReflect() protoreflect.Message { mi := &file_zerfoo_proto_msgTypes[1] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -176,30 +300,16 @@ func (x *Metadata) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Metadata.ProtoReflect.Descriptor instead. -func (*Metadata) Descriptor() ([]byte, []int) { +// Deprecated: Use Floats.ProtoReflect.Descriptor instead. +func (*Floats) Descriptor() ([]byte, []int) { return file_zerfoo_proto_rawDescGZIP(), []int{1} } -func (x *Metadata) GetProducerName() string { - if x != nil { - return x.ProducerName - } - return "" -} - -func (x *Metadata) GetProducerVersion() string { - if x != nil { - return x.ProducerVersion - } - return "" -} - -func (x *Metadata) GetOpsetVersion() int64 { +func (x *Floats) GetVal() []float32 { if x != nil { - return x.OpsetVersion + return x.Val } - return 0 + return nil } // Graph represents the computation graph of the model. @@ -277,33 +387,28 @@ func (x *Graph) GetOutputs() []*ValueInfo { return nil } -// ValueInfo describes a tensor, including its name, type, and shape. -type ValueInfo struct { - state protoimpl.MessageState `protogen:"open.v1"` - // The name of the tensor. - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - // The data type of the tensor. - Dtype Tensor_DataType `protobuf:"varint,2,opt,name=dtype,proto3,enum=zmf.Tensor_DataType" json:"dtype,omitempty"` - // The shape of the tensor. A value of -1 can be used for dynamic dimensions. - Shape []int64 `protobuf:"varint,3,rep,packed,name=shape,proto3" json:"shape,omitempty"` +// Ints is a wrapper for repeated int64 values in attributes. +type Ints struct { + state protoimpl.MessageState `protogen:"open.v1"` + Val []int64 `protobuf:"varint,1,rep,packed,name=val,proto3" json:"val,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *ValueInfo) Reset() { - *x = ValueInfo{} +func (x *Ints) Reset() { + *x = Ints{} mi := &file_zerfoo_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *ValueInfo) String() string { +func (x *Ints) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ValueInfo) ProtoMessage() {} +func (*Ints) ProtoMessage() {} -func (x *ValueInfo) ProtoReflect() protoreflect.Message { +func (x *Ints) ProtoReflect() protoreflect.Message { mi := &file_zerfoo_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -315,66 +420,45 @@ func (x *ValueInfo) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ValueInfo.ProtoReflect.Descriptor instead. -func (*ValueInfo) Descriptor() ([]byte, []int) { +// Deprecated: Use Ints.ProtoReflect.Descriptor instead. +func (*Ints) Descriptor() ([]byte, []int) { return file_zerfoo_proto_rawDescGZIP(), []int{3} } -func (x *ValueInfo) GetName() string { - if x != nil { - return x.Name - } - return "" -} - -func (x *ValueInfo) GetDtype() Tensor_DataType { - if x != nil { - return x.Dtype - } - return Tensor_FLOAT32 -} - -func (x *ValueInfo) GetShape() []int64 { +func (x *Ints) GetVal() []int64 { if x != nil { - return x.Shape + return x.Val } return nil } -// Node represents a single layer or operation in the computation graph. -type Node struct { +// Metadata stores information about the model's origin and versioning. +type Metadata struct { state protoimpl.MessageState `protogen:"open.v1"` - // A unique name for this node, e.g., "transformer_block_0". - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - // The type of operation this node performs, e.g., "RMSNorm", "GlobalAttention". - // This must map to a registered constructor in the importer. - OpType string `protobuf:"bytes,2,opt,name=op_type,json=opType,proto3" json:"op_type,omitempty"` - // A list of names of the tensors that are inputs to this node. - // These can be names of parameters or outputs of other nodes. - Inputs []string `protobuf:"bytes,3,rep,name=inputs,proto3" json:"inputs,omitempty"` - // A list of names for the output tensors of this node. - Outputs []string `protobuf:"bytes,4,rep,name=outputs,proto3" json:"outputs,omitempty"` - // A map of attributes for this node, for values that are not tensors. - // e.g., "epsilon" for RMSNorm, "hidden_dim" for FFN. - Attributes map[string]*Attribute `protobuf:"bytes,5,rep,name=attributes,proto3" json:"attributes,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // The name of the tool or framework that produced this model. + ProducerName string `protobuf:"bytes,1,opt,name=producer_name,json=producerName,proto3" json:"producer_name,omitempty"` + // The version of the producer. + ProducerVersion string `protobuf:"bytes,2,opt,name=producer_version,json=producerVersion,proto3" json:"producer_version,omitempty"` + // The version of the ZMF operator set this model conforms to. + OpsetVersion int64 `protobuf:"varint,3,opt,name=opset_version,json=opsetVersion,proto3" json:"opset_version,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Node) Reset() { - *x = Node{} +func (x *Metadata) Reset() { + *x = Metadata{} mi := &file_zerfoo_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Node) String() string { +func (x *Metadata) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Node) ProtoMessage() {} +func (*Metadata) ProtoMessage() {} -func (x *Node) ProtoReflect() protoreflect.Message { +func (x *Metadata) ProtoReflect() protoreflect.Message { mi := &file_zerfoo_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -386,72 +470,59 @@ func (x *Node) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Node.ProtoReflect.Descriptor instead. -func (*Node) Descriptor() ([]byte, []int) { +// Deprecated: Use Metadata.ProtoReflect.Descriptor instead. +func (*Metadata) Descriptor() ([]byte, []int) { return file_zerfoo_proto_rawDescGZIP(), []int{4} } -func (x *Node) GetName() string { +func (x *Metadata) GetProducerName() string { if x != nil { - return x.Name + return x.ProducerName } return "" } -func (x *Node) GetOpType() string { +func (x *Metadata) GetProducerVersion() string { if x != nil { - return x.OpType + return x.ProducerVersion } return "" } -func (x *Node) GetInputs() []string { - if x != nil { - return x.Inputs - } - return nil -} - -func (x *Node) GetOutputs() []string { - if x != nil { - return x.Outputs - } - return nil -} - -func (x *Node) GetAttributes() map[string]*Attribute { +func (x *Metadata) GetOpsetVersion() int64 { if x != nil { - return x.Attributes + return x.OpsetVersion } - return nil + return 0 } -// Tensor represents a multi-dimensional array of data (e.g., a weight matrix). -type Tensor struct { +// Model is the top-level container for a serialized Zerfoo model. +type Model struct { state protoimpl.MessageState `protogen:"open.v1"` - Dtype Tensor_DataType `protobuf:"varint,1,opt,name=dtype,proto3,enum=zmf.Tensor_DataType" json:"dtype,omitempty"` - // The shape (dimensions) of the tensor. - Shape []int64 `protobuf:"varint,2,rep,packed,name=shape,proto3" json:"shape,omitempty"` - // The raw tensor data, stored as bytes. - Data []byte `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"` + // The computation graph defining the model's architecture. + Graph *Graph `protobuf:"bytes,1,opt,name=graph,proto3" json:"graph,omitempty"` + // Metadata about the model, such as its producer. + Metadata *Metadata `protobuf:"bytes,2,opt,name=metadata,proto3" json:"metadata,omitempty"` + // Version of the ZMF format itself. + ZmfVersion string `protobuf:"bytes,3,opt,name=zmf_version,json=zmfVersion,proto3" json:"zmf_version,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Tensor) Reset() { - *x = Tensor{} +func (x *Model) Reset() { + *x = Model{} mi := &file_zerfoo_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Tensor) String() string { +func (x *Model) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Tensor) ProtoMessage() {} +func (*Model) ProtoMessage() {} -func (x *Tensor) ProtoReflect() protoreflect.Message { +func (x *Model) ProtoReflect() protoreflect.Message { mi := &file_zerfoo_proto_msgTypes[5] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -463,62 +534,70 @@ func (x *Tensor) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Tensor.ProtoReflect.Descriptor instead. -func (*Tensor) Descriptor() ([]byte, []int) { +// Deprecated: Use Model.ProtoReflect.Descriptor instead. +func (*Model) Descriptor() ([]byte, []int) { return file_zerfoo_proto_rawDescGZIP(), []int{5} } -func (x *Tensor) GetDtype() Tensor_DataType { +func (x *Model) GetGraph() *Graph { if x != nil { - return x.Dtype + return x.Graph } - return Tensor_FLOAT32 + return nil } -func (x *Tensor) GetShape() []int64 { +func (x *Model) GetMetadata() *Metadata { if x != nil { - return x.Shape + return x.Metadata } return nil } -func (x *Tensor) GetData() []byte { +func (x *Model) GetZmfVersion() string { if x != nil { - return x.Data + return x.ZmfVersion } - return nil + return "" } -// Attribute represents a named, non-tensor parameter for a node. -type Attribute struct { +// Node represents a single layer or operation in the computation graph. +type Node struct { state protoimpl.MessageState `protogen:"open.v1"` - // Types that are valid to be assigned to Value: - // - // *Attribute_F - // *Attribute_I - // *Attribute_S - // *Attribute_Floats - // *Attribute_Ints - // *Attribute_Strings - Value isAttribute_Value `protobuf_oneof:"value"` + // A unique name for this node, e.g., "transformer_block_0". + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + // The type of operation this node performs, e.g., "RMSNorm", "GlobalAttention". + // This must map to a registered constructor in the importer. + OpType string `protobuf:"bytes,2,opt,name=op_type,json=opType,proto3" json:"op_type,omitempty"` + // A list of names of the tensors that are inputs to this node. + // These can be names of parameters or outputs of other nodes. + Inputs []string `protobuf:"bytes,3,rep,name=inputs,proto3" json:"inputs,omitempty"` + // A list of names for the output tensors of this node. + Outputs []string `protobuf:"bytes,4,rep,name=outputs,proto3" json:"outputs,omitempty"` + // A map of attributes for this node, for values that are not tensors. + // e.g., "epsilon" for RMSNorm, "hidden_dim" for FFN. + Attributes map[string]*Attribute `protobuf:"bytes,5,rep,name=attributes,proto3" json:"attributes,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // Common attributes for explicitness + Epsilon *float32 `protobuf:"fixed32,6,opt,name=epsilon,proto3,oneof" json:"epsilon,omitempty"` // Used in normalization layers + Perm []int64 `protobuf:"varint,7,rep,packed,name=perm,proto3" json:"perm,omitempty"` // Used in transpose operations + Axis *int64 `protobuf:"varint,8,opt,name=axis,proto3,oneof" json:"axis,omitempty"` // Used in reduction operations unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Attribute) Reset() { - *x = Attribute{} +func (x *Node) Reset() { + *x = Node{} mi := &file_zerfoo_proto_msgTypes[6] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Attribute) String() string { +func (x *Node) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Attribute) ProtoMessage() {} +func (*Node) ProtoMessage() {} -func (x *Attribute) ProtoReflect() protoreflect.Message { +func (x *Node) ProtoReflect() protoreflect.Message { mi := &file_zerfoo_proto_msgTypes[6] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -530,135 +609,163 @@ func (x *Attribute) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Attribute.ProtoReflect.Descriptor instead. -func (*Attribute) Descriptor() ([]byte, []int) { +// Deprecated: Use Node.ProtoReflect.Descriptor instead. +func (*Node) Descriptor() ([]byte, []int) { return file_zerfoo_proto_rawDescGZIP(), []int{6} } -func (x *Attribute) GetValue() isAttribute_Value { +func (x *Node) GetName() string { if x != nil { - return x.Value + return x.Name } - return nil + return "" } -func (x *Attribute) GetF() float32 { +func (x *Node) GetOpType() string { if x != nil { - if x, ok := x.Value.(*Attribute_F); ok { - return x.F - } + return x.OpType } - return 0 + return "" } -func (x *Attribute) GetI() int64 { +func (x *Node) GetInputs() []string { if x != nil { - if x, ok := x.Value.(*Attribute_I); ok { - return x.I - } + return x.Inputs } - return 0 + return nil } -func (x *Attribute) GetS() string { +func (x *Node) GetOutputs() []string { if x != nil { - if x, ok := x.Value.(*Attribute_S); ok { - return x.S - } + return x.Outputs } - return "" + return nil } -func (x *Attribute) GetFloats() *Floats { +func (x *Node) GetAttributes() map[string]*Attribute { if x != nil { - if x, ok := x.Value.(*Attribute_Floats); ok { - return x.Floats - } + return x.Attributes } return nil } -func (x *Attribute) GetInts() *Ints { - if x != nil { - if x, ok := x.Value.(*Attribute_Ints); ok { - return x.Ints - } +func (x *Node) GetEpsilon() float32 { + if x != nil && x.Epsilon != nil { + return *x.Epsilon } - return nil + return 0 } -func (x *Attribute) GetStrings() *Strings { +func (x *Node) GetPerm() []int64 { if x != nil { - if x, ok := x.Value.(*Attribute_Strings); ok { - return x.Strings - } + return x.Perm } return nil } -type isAttribute_Value interface { - isAttribute_Value() +func (x *Node) GetAxis() int64 { + if x != nil && x.Axis != nil { + return *x.Axis + } + return 0 } -type Attribute_F struct { - F float32 `protobuf:"fixed32,1,opt,name=f,proto3,oneof"` +// Quantization describes the parameters for affine quantization. +type Quantization struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The scale factor for quantization. + Scale float32 `protobuf:"fixed32,1,opt,name=scale,proto3" json:"scale,omitempty"` + // The zero point for quantization. + ZeroPoint int64 `protobuf:"varint,2,opt,name=zero_point,json=zeroPoint,proto3" json:"zero_point,omitempty"` + // The minimum value of the quantized range. + QMin *int64 `protobuf:"varint,3,opt,name=q_min,json=qMin,proto3,oneof" json:"q_min,omitempty"` + // The maximum value of the quantized range. + QMax *int64 `protobuf:"varint,4,opt,name=q_max,json=qMax,proto3,oneof" json:"q_max,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } -type Attribute_I struct { - I int64 `protobuf:"varint,2,opt,name=i,proto3,oneof"` +func (x *Quantization) Reset() { + *x = Quantization{} + mi := &file_zerfoo_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } -type Attribute_S struct { - S string `protobuf:"bytes,3,opt,name=s,proto3,oneof"` +func (x *Quantization) String() string { + return protoimpl.X.MessageStringOf(x) } -type Attribute_Floats struct { - Floats *Floats `protobuf:"bytes,4,opt,name=floats,proto3,oneof"` -} +func (*Quantization) ProtoMessage() {} -type Attribute_Ints struct { - Ints *Ints `protobuf:"bytes,5,opt,name=ints,proto3,oneof"` +func (x *Quantization) ProtoReflect() protoreflect.Message { + mi := &file_zerfoo_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) } -type Attribute_Strings struct { - Strings *Strings `protobuf:"bytes,6,opt,name=strings,proto3,oneof"` +// Deprecated: Use Quantization.ProtoReflect.Descriptor instead. +func (*Quantization) Descriptor() ([]byte, []int) { + return file_zerfoo_proto_rawDescGZIP(), []int{7} } -func (*Attribute_F) isAttribute_Value() {} - -func (*Attribute_I) isAttribute_Value() {} - -func (*Attribute_S) isAttribute_Value() {} +func (x *Quantization) GetScale() float32 { + if x != nil { + return x.Scale + } + return 0 +} -func (*Attribute_Floats) isAttribute_Value() {} +func (x *Quantization) GetZeroPoint() int64 { + if x != nil { + return x.ZeroPoint + } + return 0 +} -func (*Attribute_Ints) isAttribute_Value() {} +func (x *Quantization) GetQMin() int64 { + if x != nil && x.QMin != nil { + return *x.QMin + } + return 0 +} -func (*Attribute_Strings) isAttribute_Value() {} +func (x *Quantization) GetQMax() int64 { + if x != nil && x.QMax != nil { + return *x.QMax + } + return 0 +} -// Wrapper messages for repeated primitive types in attributes. -type Floats struct { +// Strings is a wrapper for repeated string values in attributes. +type Strings struct { state protoimpl.MessageState `protogen:"open.v1"` - Val []float32 `protobuf:"fixed32,1,rep,packed,name=val,proto3" json:"val,omitempty"` + Val []string `protobuf:"bytes,1,rep,name=val,proto3" json:"val,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Floats) Reset() { - *x = Floats{} - mi := &file_zerfoo_proto_msgTypes[7] +func (x *Strings) Reset() { + *x = Strings{} + mi := &file_zerfoo_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Floats) String() string { +func (x *Strings) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Floats) ProtoMessage() {} +func (*Strings) ProtoMessage() {} -func (x *Floats) ProtoReflect() protoreflect.Message { - mi := &file_zerfoo_proto_msgTypes[7] +func (x *Strings) ProtoReflect() protoreflect.Message { + mi := &file_zerfoo_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -669,40 +776,47 @@ func (x *Floats) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Floats.ProtoReflect.Descriptor instead. -func (*Floats) Descriptor() ([]byte, []int) { - return file_zerfoo_proto_rawDescGZIP(), []int{7} +// Deprecated: Use Strings.ProtoReflect.Descriptor instead. +func (*Strings) Descriptor() ([]byte, []int) { + return file_zerfoo_proto_rawDescGZIP(), []int{8} } -func (x *Floats) GetVal() []float32 { +func (x *Strings) GetVal() []string { if x != nil { return x.Val } return nil } -type Ints struct { - state protoimpl.MessageState `protogen:"open.v1"` - Val []int64 `protobuf:"varint,1,rep,packed,name=val,proto3" json:"val,omitempty"` +// Tensor represents a multi-dimensional array of data (e.g., a weight matrix). +type Tensor struct { + state protoimpl.MessageState `protogen:"open.v1"` + Dtype Tensor_DataType `protobuf:"varint,1,opt,name=dtype,proto3,enum=zmf.Tensor_DataType" json:"dtype,omitempty"` + // The shape (dimensions) of the tensor. + Shape []int64 `protobuf:"varint,2,rep,packed,name=shape,proto3" json:"shape,omitempty"` + // The raw tensor data, stored as bytes. + Data []byte `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"` + // Optional quantization parameters for the tensor. + Quant *Quantization `protobuf:"bytes,4,opt,name=quant,proto3,oneof" json:"quant,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Ints) Reset() { - *x = Ints{} - mi := &file_zerfoo_proto_msgTypes[8] +func (x *Tensor) Reset() { + *x = Tensor{} + mi := &file_zerfoo_proto_msgTypes[9] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Ints) String() string { +func (x *Tensor) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Ints) ProtoMessage() {} +func (*Tensor) ProtoMessage() {} -func (x *Ints) ProtoReflect() protoreflect.Message { - mi := &file_zerfoo_proto_msgTypes[8] +func (x *Tensor) ProtoReflect() protoreflect.Message { + mi := &file_zerfoo_proto_msgTypes[9] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -713,40 +827,67 @@ func (x *Ints) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Ints.ProtoReflect.Descriptor instead. -func (*Ints) Descriptor() ([]byte, []int) { - return file_zerfoo_proto_rawDescGZIP(), []int{8} +// Deprecated: Use Tensor.ProtoReflect.Descriptor instead. +func (*Tensor) Descriptor() ([]byte, []int) { + return file_zerfoo_proto_rawDescGZIP(), []int{9} } -func (x *Ints) GetVal() []int64 { +func (x *Tensor) GetDtype() Tensor_DataType { if x != nil { - return x.Val + return x.Dtype + } + return Tensor_BFLOAT16 +} + +func (x *Tensor) GetShape() []int64 { + if x != nil { + return x.Shape } return nil } -type Strings struct { - state protoimpl.MessageState `protogen:"open.v1"` - Val []string `protobuf:"bytes,1,rep,name=val,proto3" json:"val,omitempty"` +func (x *Tensor) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +func (x *Tensor) GetQuant() *Quantization { + if x != nil { + return x.Quant + } + return nil +} + +// ValueInfo describes a tensor, including its name, type, and shape. +type ValueInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The name of the tensor. + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + // The data type of the tensor. + Dtype Tensor_DataType `protobuf:"varint,2,opt,name=dtype,proto3,enum=zmf.Tensor_DataType" json:"dtype,omitempty"` + // The shape of the tensor. A value of -1 can be used for dynamic dimensions. + Shape []int64 `protobuf:"varint,3,rep,packed,name=shape,proto3" json:"shape,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *Strings) Reset() { - *x = Strings{} - mi := &file_zerfoo_proto_msgTypes[9] +func (x *ValueInfo) Reset() { + *x = ValueInfo{} + mi := &file_zerfoo_proto_msgTypes[10] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *Strings) String() string { +func (x *ValueInfo) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Strings) ProtoMessage() {} +func (*ValueInfo) ProtoMessage() {} -func (x *Strings) ProtoReflect() protoreflect.Message { - mi := &file_zerfoo_proto_msgTypes[9] +func (x *ValueInfo) ProtoReflect() protoreflect.Message { + mi := &file_zerfoo_proto_msgTypes[10] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -757,14 +898,28 @@ func (x *Strings) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Strings.ProtoReflect.Descriptor instead. -func (*Strings) Descriptor() ([]byte, []int) { - return file_zerfoo_proto_rawDescGZIP(), []int{9} +// Deprecated: Use ValueInfo.ProtoReflect.Descriptor instead. +func (*ValueInfo) Descriptor() ([]byte, []int) { + return file_zerfoo_proto_rawDescGZIP(), []int{10} } -func (x *Strings) GetVal() []string { +func (x *ValueInfo) GetName() string { if x != nil { - return x.Val + return x.Name + } + return "" +} + +func (x *ValueInfo) GetDtype() Tensor_DataType { + if x != nil { + return x.Dtype + } + return Tensor_BFLOAT16 +} + +func (x *ValueInfo) GetShape() []int64 { + if x != nil { + return x.Shape } return nil } @@ -773,15 +928,18 @@ var File_zerfoo_proto protoreflect.FileDescriptor const file_zerfoo_proto_rawDesc = "" + "\n" + - "\fzerfoo.proto\x12\x03zmf\"T\n" + - "\x05Model\x12 \n" + - "\x05graph\x18\x01 \x01(\v2\n" + - ".zmf.GraphR\x05graph\x12)\n" + - "\bmetadata\x18\x02 \x01(\v2\r.zmf.MetadataR\bmetadata\"\x7f\n" + - "\bMetadata\x12#\n" + - "\rproducer_name\x18\x01 \x01(\tR\fproducerName\x12)\n" + - "\x10producer_version\x18\x02 \x01(\tR\x0fproducerVersion\x12#\n" + - "\ropset_version\x18\x03 \x01(\x03R\fopsetVersion\"\x82\x02\n" + + "\fzerfoo.proto\x12\x03zmf\"\xc6\x01\n" + + "\tAttribute\x12\x0e\n" + + "\x01f\x18\x01 \x01(\x02H\x00R\x01f\x12\x0e\n" + + "\x01i\x18\x02 \x01(\x03H\x00R\x01i\x12\x0e\n" + + "\x01s\x18\x03 \x01(\tH\x00R\x01s\x12%\n" + + "\x06floats\x18\x04 \x01(\v2\v.zmf.FloatsH\x00R\x06floats\x12\x1f\n" + + "\x04ints\x18\x05 \x01(\v2\t.zmf.IntsH\x00R\x04ints\x12(\n" + + "\astrings\x18\x06 \x01(\v2\f.zmf.StringsH\x00R\astrings\x12\x0e\n" + + "\x01b\x18\a \x01(\bH\x00R\x01bB\a\n" + + "\x05value\"\x1a\n" + + "\x06Floats\x12\x10\n" + + "\x03val\x18\x01 \x03(\x02R\x03val\"\x82\x02\n" + "\x05Graph\x12:\n" + "\n" + "parameters\x18\x01 \x03(\v2\x1a.zmf.Graph.ParametersEntryR\n" + @@ -791,11 +949,19 @@ const file_zerfoo_proto_rawDesc = "" + "\aoutputs\x18\x04 \x03(\v2\x0e.zmf.ValueInfoR\aoutputs\x1aJ\n" + "\x0fParametersEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12!\n" + - "\x05value\x18\x02 \x01(\v2\v.zmf.TensorR\x05value:\x028\x01\"a\n" + - "\tValueInfo\x12\x12\n" + - "\x04name\x18\x01 \x01(\tR\x04name\x12*\n" + - "\x05dtype\x18\x02 \x01(\x0e2\x14.zmf.Tensor.DataTypeR\x05dtype\x12\x14\n" + - "\x05shape\x18\x03 \x03(\x03R\x05shape\"\xef\x01\n" + + "\x05value\x18\x02 \x01(\v2\v.zmf.TensorR\x05value:\x028\x01\"\x18\n" + + "\x04Ints\x12\x10\n" + + "\x03val\x18\x01 \x03(\x03R\x03val\"\x7f\n" + + "\bMetadata\x12#\n" + + "\rproducer_name\x18\x01 \x01(\tR\fproducerName\x12)\n" + + "\x10producer_version\x18\x02 \x01(\tR\x0fproducerVersion\x12#\n" + + "\ropset_version\x18\x03 \x01(\x03R\fopsetVersion\"u\n" + + "\x05Model\x12 \n" + + "\x05graph\x18\x01 \x01(\v2\n" + + ".zmf.GraphR\x05graph\x12)\n" + + "\bmetadata\x18\x02 \x01(\v2\r.zmf.MetadataR\bmetadata\x12\x1f\n" + + "\vzmf_version\x18\x03 \x01(\tR\n" + + "zmfVersion\"\xd0\x02\n" + "\x04Node\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n" + "\aop_type\x18\x02 \x01(\tR\x06opType\x12\x16\n" + @@ -803,37 +969,56 @@ const file_zerfoo_proto_rawDesc = "" + "\aoutputs\x18\x04 \x03(\tR\aoutputs\x129\n" + "\n" + "attributes\x18\x05 \x03(\v2\x19.zmf.Node.AttributesEntryR\n" + - "attributes\x1aM\n" + + "attributes\x12\x1d\n" + + "\aepsilon\x18\x06 \x01(\x02H\x00R\aepsilon\x88\x01\x01\x12\x12\n" + + "\x04perm\x18\a \x03(\x03R\x04perm\x12\x17\n" + + "\x04axis\x18\b \x01(\x03H\x01R\x04axis\x88\x01\x01\x1aM\n" + "\x0fAttributesEntry\x12\x10\n" + "\x03key\x18\x01 \x01(\tR\x03key\x12$\n" + - "\x05value\x18\x02 \x01(\v2\x0e.zmf.AttributeR\x05value:\x028\x01\"\xc1\x01\n" + + "\x05value\x18\x02 \x01(\v2\x0e.zmf.AttributeR\x05value:\x028\x01B\n" + + "\n" + + "\b_epsilonB\a\n" + + "\x05_axis\"\x8b\x01\n" + + "\fQuantization\x12\x14\n" + + "\x05scale\x18\x01 \x01(\x02R\x05scale\x12\x1d\n" + + "\n" + + "zero_point\x18\x02 \x01(\x03R\tzeroPoint\x12\x18\n" + + "\x05q_min\x18\x03 \x01(\x03H\x00R\x04qMin\x88\x01\x01\x12\x18\n" + + "\x05q_max\x18\x04 \x01(\x03H\x01R\x04qMax\x88\x01\x01B\b\n" + + "\x06_q_minB\b\n" + + "\x06_q_max\"\x1b\n" + + "\aStrings\x12\x10\n" + + "\x03val\x18\x01 \x03(\tR\x03val\"\xc8\x02\n" + "\x06Tensor\x12*\n" + "\x05dtype\x18\x01 \x01(\x0e2\x14.zmf.Tensor.DataTypeR\x05dtype\x12\x14\n" + "\x05shape\x18\x02 \x03(\x03R\x05shape\x12\x12\n" + - "\x04data\x18\x03 \x01(\fR\x04data\"a\n" + - "\bDataType\x12\v\n" + - "\aFLOAT32\x10\x00\x12\v\n" + - "\aFLOAT16\x10\x01\x12\f\n" + - "\bBFLOAT16\x10\x02\x12\n" + + "\x04data\x18\x03 \x01(\fR\x04data\x12,\n" + + "\x05quant\x18\x04 \x01(\v2\x11.zmf.QuantizationH\x00R\x05quant\x88\x01\x01\"\xaf\x01\n" + + "\bDataType\x12\f\n" + + "\bBFLOAT16\x10\x00\x12\b\n" + + "\x04BOOL\x10\x01\x12\v\n" + + "\aFLOAT16\x10\x02\x12\v\n" + + "\aFLOAT32\x10\x03\x12\v\n" + + "\aFLOAT64\x10\x04\x12\n" + "\n" + - "\x06FLOAT8\x10\x03\x12\t\n" + - "\x05INT32\x10\x04\x12\t\n" + - "\x05INT64\x10\x05\x12\v\n" + - "\aFLOAT64\x10\x06\"\xb6\x01\n" + - "\tAttribute\x12\x0e\n" + - "\x01f\x18\x01 \x01(\x02H\x00R\x01f\x12\x0e\n" + - "\x01i\x18\x02 \x01(\x03H\x00R\x01i\x12\x0e\n" + - "\x01s\x18\x03 \x01(\tH\x00R\x01s\x12%\n" + - "\x06floats\x18\x04 \x01(\v2\v.zmf.FloatsH\x00R\x06floats\x12\x1f\n" + - "\x04ints\x18\x05 \x01(\v2\t.zmf.IntsH\x00R\x04ints\x12(\n" + - "\astrings\x18\x06 \x01(\v2\f.zmf.StringsH\x00R\astringsB\a\n" + - "\x05value\"\x1a\n" + - "\x06Floats\x12\x10\n" + - "\x03val\x18\x01 \x03(\x02R\x03val\"\x18\n" + - "\x04Ints\x12\x10\n" + - "\x03val\x18\x01 \x03(\x03R\x03val\"\x1b\n" + - "\aStrings\x12\x10\n" + - "\x03val\x18\x01 \x03(\tR\x03valB\x17Z\x15github.com/zerfoo/zmfb\x06proto3" + "\x06FLOAT8\x10\x05\x12\t\n" + + "\x05INT16\x10\x06\x12\t\n" + + "\x05INT32\x10\a\x12\t\n" + + "\x05INT64\x10\b\x12\b\n" + + "\x04INT8\x10\t\x12\n" + + "\n" + + "\x06STRING\x10\n" + + "\x12\n" + + "\n" + + "\x06UINT32\x10\v\x12\n" + + "\n" + + "\x06UINT64\x10\f\x12\t\n" + + "\x05UINT8\x10\rB\b\n" + + "\x06_quant\"a\n" + + "\tValueInfo\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12*\n" + + "\x05dtype\x18\x02 \x01(\x0e2\x14.zmf.Tensor.DataTypeR\x05dtype\x12\x14\n" + + "\x05shape\x18\x03 \x03(\x03R\x05shapeB\x17Z\x15github.com/zerfoo/zmfb\x06proto3" var ( file_zerfoo_proto_rawDescOnce sync.Once @@ -848,42 +1033,44 @@ func file_zerfoo_proto_rawDescGZIP() []byte { } var file_zerfoo_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_zerfoo_proto_msgTypes = make([]protoimpl.MessageInfo, 12) +var file_zerfoo_proto_msgTypes = make([]protoimpl.MessageInfo, 13) var file_zerfoo_proto_goTypes = []any{ (Tensor_DataType)(0), // 0: zmf.Tensor.DataType - (*Model)(nil), // 1: zmf.Model - (*Metadata)(nil), // 2: zmf.Metadata + (*Attribute)(nil), // 1: zmf.Attribute + (*Floats)(nil), // 2: zmf.Floats (*Graph)(nil), // 3: zmf.Graph - (*ValueInfo)(nil), // 4: zmf.ValueInfo - (*Node)(nil), // 5: zmf.Node - (*Tensor)(nil), // 6: zmf.Tensor - (*Attribute)(nil), // 7: zmf.Attribute - (*Floats)(nil), // 8: zmf.Floats - (*Ints)(nil), // 9: zmf.Ints - (*Strings)(nil), // 10: zmf.Strings - nil, // 11: zmf.Graph.ParametersEntry - nil, // 12: zmf.Node.AttributesEntry + (*Ints)(nil), // 4: zmf.Ints + (*Metadata)(nil), // 5: zmf.Metadata + (*Model)(nil), // 6: zmf.Model + (*Node)(nil), // 7: zmf.Node + (*Quantization)(nil), // 8: zmf.Quantization + (*Strings)(nil), // 9: zmf.Strings + (*Tensor)(nil), // 10: zmf.Tensor + (*ValueInfo)(nil), // 11: zmf.ValueInfo + nil, // 12: zmf.Graph.ParametersEntry + nil, // 13: zmf.Node.AttributesEntry } var file_zerfoo_proto_depIdxs = []int32{ - 3, // 0: zmf.Model.graph:type_name -> zmf.Graph - 2, // 1: zmf.Model.metadata:type_name -> zmf.Metadata - 11, // 2: zmf.Graph.parameters:type_name -> zmf.Graph.ParametersEntry - 5, // 3: zmf.Graph.nodes:type_name -> zmf.Node - 4, // 4: zmf.Graph.inputs:type_name -> zmf.ValueInfo - 4, // 5: zmf.Graph.outputs:type_name -> zmf.ValueInfo - 0, // 6: zmf.ValueInfo.dtype:type_name -> zmf.Tensor.DataType - 12, // 7: zmf.Node.attributes:type_name -> zmf.Node.AttributesEntry - 0, // 8: zmf.Tensor.dtype:type_name -> zmf.Tensor.DataType - 8, // 9: zmf.Attribute.floats:type_name -> zmf.Floats - 9, // 10: zmf.Attribute.ints:type_name -> zmf.Ints - 10, // 11: zmf.Attribute.strings:type_name -> zmf.Strings - 6, // 12: zmf.Graph.ParametersEntry.value:type_name -> zmf.Tensor - 7, // 13: zmf.Node.AttributesEntry.value:type_name -> zmf.Attribute - 14, // [14:14] is the sub-list for method output_type - 14, // [14:14] is the sub-list for method input_type - 14, // [14:14] is the sub-list for extension type_name - 14, // [14:14] is the sub-list for extension extendee - 0, // [0:14] is the sub-list for field type_name + 2, // 0: zmf.Attribute.floats:type_name -> zmf.Floats + 4, // 1: zmf.Attribute.ints:type_name -> zmf.Ints + 9, // 2: zmf.Attribute.strings:type_name -> zmf.Strings + 12, // 3: zmf.Graph.parameters:type_name -> zmf.Graph.ParametersEntry + 7, // 4: zmf.Graph.nodes:type_name -> zmf.Node + 11, // 5: zmf.Graph.inputs:type_name -> zmf.ValueInfo + 11, // 6: zmf.Graph.outputs:type_name -> zmf.ValueInfo + 3, // 7: zmf.Model.graph:type_name -> zmf.Graph + 5, // 8: zmf.Model.metadata:type_name -> zmf.Metadata + 13, // 9: zmf.Node.attributes:type_name -> zmf.Node.AttributesEntry + 0, // 10: zmf.Tensor.dtype:type_name -> zmf.Tensor.DataType + 8, // 11: zmf.Tensor.quant:type_name -> zmf.Quantization + 0, // 12: zmf.ValueInfo.dtype:type_name -> zmf.Tensor.DataType + 10, // 13: zmf.Graph.ParametersEntry.value:type_name -> zmf.Tensor + 1, // 14: zmf.Node.AttributesEntry.value:type_name -> zmf.Attribute + 15, // [15:15] is the sub-list for method output_type + 15, // [15:15] is the sub-list for method input_type + 15, // [15:15] is the sub-list for extension type_name + 15, // [15:15] is the sub-list for extension extendee + 0, // [0:15] is the sub-list for field type_name } func init() { file_zerfoo_proto_init() } @@ -891,21 +1078,25 @@ func file_zerfoo_proto_init() { if File_zerfoo_proto != nil { return } - file_zerfoo_proto_msgTypes[6].OneofWrappers = []any{ + file_zerfoo_proto_msgTypes[0].OneofWrappers = []any{ (*Attribute_F)(nil), (*Attribute_I)(nil), (*Attribute_S)(nil), (*Attribute_Floats)(nil), (*Attribute_Ints)(nil), (*Attribute_Strings)(nil), + (*Attribute_B)(nil), } + file_zerfoo_proto_msgTypes[6].OneofWrappers = []any{} + file_zerfoo_proto_msgTypes[7].OneofWrappers = []any{} + file_zerfoo_proto_msgTypes[9].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_zerfoo_proto_rawDesc), len(file_zerfoo_proto_rawDesc)), NumEnums: 1, - NumMessages: 12, + NumMessages: 13, NumExtensions: 0, NumServices: 0, }, diff --git a/zerfoo.proto b/zerfoo.proto index 2a26461..468ef9b 100644 --- a/zerfoo.proto +++ b/zerfoo.proto @@ -4,117 +4,156 @@ package zmf; option go_package = "github.com/zerfoo/zmf"; -// Model is the top-level container for a serialized Zerfoo model. -message Model { - // The computation graph defining the model's architecture. - Graph graph = 1; +// Attribute represents a named, non-tensor parameter for a node. +message Attribute { + oneof value { + float f = 1; + int64 i = 2; + string s = 3; + Floats floats = 4; + Ints ints = 5; + Strings strings = 6; + bool b = 7; // Added boolean support + } +} - // Metadata about the model, such as its producer. - Metadata metadata = 2; +// Floats is a wrapper for repeated float values in attributes. +message Floats { + repeated float val = 1; } -// Metadata stores information about the model's origin and versioning. -message Metadata { - // The name of the tool or framework that produced this model. - string producer_name = 1; +// Graph represents the computation graph of the model. +message Graph { + // A map of all trainable parameters (weights, biases) in the model. + // The key is a unique name for the parameter, e.g., "layers.0.attention.wq.weight". + map parameters = 1; + + // A list of all nodes (layers or operations) in the graph. + // For sequential models, this list should be in execution order. + repeated Node nodes = 2; - // The version of the producer. - string producer_version = 2; + // A list describing the input tensors to the graph. + repeated ValueInfo inputs = 3; - // The version of the ZMF operator set this model conforms to. - int64 opset_version = 3; + // A list describing the output tensors of the graph. + repeated ValueInfo outputs = 4; } -// Graph represents the computation graph of the model. -message Graph { - // A map of all trainable parameters (weights, biases) in the model. - // The key is a unique name for the parameter, e.g., "layers.0.attention.wq.weight". - map parameters = 1; +// Ints is a wrapper for repeated int64 values in attributes. +message Ints { + repeated int64 val = 1; +} - // A list of all nodes (layers or operations) in the graph. - // For sequential models, this list should be in execution order. - repeated Node nodes = 2; +// Metadata stores information about the model's origin and versioning. +message Metadata { + // The name of the tool or framework that produced this model. + string producer_name = 1; - // A list describing the input tensors to the graph. - repeated ValueInfo inputs = 3; + // The version of the producer. + string producer_version = 2; - // A list describing the output tensors of the graph. - repeated ValueInfo outputs = 4; + // The version of the ZMF operator set this model conforms to. + int64 opset_version = 3; } -// ValueInfo describes a tensor, including its name, type, and shape. -message ValueInfo { - // The name of the tensor. - string name = 1; +// Model is the top-level container for a serialized Zerfoo model. +message Model { + // The computation graph defining the model's architecture. + Graph graph = 1; - // The data type of the tensor. - Tensor.DataType dtype = 2; + // Metadata about the model, such as its producer. + Metadata metadata = 2; - // The shape of the tensor. A value of -1 can be used for dynamic dimensions. - repeated int64 shape = 3; + // Version of the ZMF format itself. + string zmf_version = 3; } // Node represents a single layer or operation in the computation graph. message Node { - // A unique name for this node, e.g., "transformer_block_0". - string name = 1; + // A unique name for this node, e.g., "transformer_block_0". + string name = 1; - // The type of operation this node performs, e.g., "RMSNorm", "GlobalAttention". - // This must map to a registered constructor in the importer. - string op_type = 2; + // The type of operation this node performs, e.g., "RMSNorm", "GlobalAttention". + // This must map to a registered constructor in the importer. + string op_type = 2; - // A list of names of the tensors that are inputs to this node. - // These can be names of parameters or outputs of other nodes. - repeated string inputs = 3; + // A list of names of the tensors that are inputs to this node. + // These can be names of parameters or outputs of other nodes. + repeated string inputs = 3; - // A list of names for the output tensors of this node. - repeated string outputs = 4; + // A list of names for the output tensors of this node. + repeated string outputs = 4; - // A map of attributes for this node, for values that are not tensors. - // e.g., "epsilon" for RMSNorm, "hidden_dim" for FFN. - map attributes = 5; -} + // A map of attributes for this node, for values that are not tensors. + // e.g., "epsilon" for RMSNorm, "hidden_dim" for FFN. + map attributes = 5; -// Tensor represents a multi-dimensional array of data (e.g., a weight matrix). -message Tensor { - // The data type of the tensor elements. - enum DataType { - FLOAT32 = 0; - FLOAT16 = 1; - BFLOAT16 = 2; - FLOAT8 = 3; - INT32 = 4; - INT64 = 5; - FLOAT64 = 6; - } - DataType dtype = 1; - - // The shape (dimensions) of the tensor. - repeated int64 shape = 2; - - // The raw tensor data, stored as bytes. - bytes data = 3; + // Common attributes for explicitness + optional float epsilon = 6; // Used in normalization layers + repeated int64 perm = 7; // Used in transpose operations + optional int64 axis = 8; // Used in reduction operations } -// Attribute represents a named, non-tensor parameter for a node. -message Attribute { - oneof value { - float f = 1; - int64 i = 2; - string s = 3; - Floats floats = 4; - Ints ints = 5; - Strings strings = 6; - } +// Quantization describes the parameters for affine quantization. +message Quantization { + // The scale factor for quantization. + float scale = 1; + + // The zero point for quantization. + int64 zero_point = 2; + + // The minimum value of the quantized range. + optional int64 q_min = 3; + + // The maximum value of the quantized range. + optional int64 q_max = 4; } -// Wrapper messages for repeated primitive types in attributes. -message Floats { - repeated float val = 1; +// Strings is a wrapper for repeated string values in attributes. +message Strings { + repeated string val = 1; } -message Ints { - repeated int64 val = 1; + +// Tensor represents a multi-dimensional array of data (e.g., a weight matrix). +message Tensor { + // The data type of the tensor elements. + enum DataType { + BFLOAT16 = 0; + BOOL = 1; + FLOAT16 = 2; + FLOAT32 = 3; + FLOAT64 = 4; + FLOAT8 = 5; + INT16 = 6; + INT32 = 7; + INT64 = 8; + INT8 = 9; + STRING = 10; + UINT32 = 11; + UINT64 = 12; + UINT8 = 13; + } + + DataType dtype = 1; + + // The shape (dimensions) of the tensor. + repeated int64 shape = 2; + + // The raw tensor data, stored as bytes. + bytes data = 3; + + // Optional quantization parameters for the tensor. + optional Quantization quant = 4; } -message Strings { - repeated string val = 1; + +// ValueInfo describes a tensor, including its name, type, and shape. +message ValueInfo { + // The name of the tensor. + string name = 1; + + // The data type of the tensor. + Tensor.DataType dtype = 2; + + // The shape of the tensor. A value of -1 can be used for dynamic dimensions. + repeated int64 shape = 3; }