12 #include "llvm/ADT/TypeSwitch.h"
17 void XeGPUDialect::initialize() {
19 #define GET_TYPEDEF_LIST
20 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
24 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
27 #define GET_ATTRDEF_LIST
28 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
36 xegpu::MemorySpace memory_space,
38 bool boundary_check) {
43 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
51 xegpu::MemorySpace memory_space,
int chunk_size) {
55 return Base::get(context, scopeAttr, chunkSizeAttr);
62 template <
typename T,
unsigned N>
65 llvm::StringRef fieldName) {
68 "unexpected field name. Expected " + fieldName +
".");
77 auto elemParser = [&]() -> llvm::ParseResult {
80 result.push_back(elem);
85 elemParser, fieldName);
95 if (failed(parseIntArrayField(parser, wi_layout,
"wi_layout")))
101 if (failed(parseIntArrayField(parser, wi_data,
"wi_data")))
104 return SGMapAttr::getChecked(
112 printer <<
" = [" << getWiLayout() <<
"], ";
114 printer <<
" = [" << getWiData() <<
"]";
122 if (wi_layout.size() != 2)
123 return emitError() <<
"expected wi_layout of size 2";
124 if (wi_data.size() != 2)
125 return emitError() <<
"expected wi_data of size 2";
136 mlir::FailureOr<mlir::Attribute> encoding;
137 mlir::FailureOr<mlir::Attribute> sg_map;
145 parser.
emitError(shapeLoc,
"failed to parse parameter 'shape'");
150 if (mlir::failed(parser.
parseType(elementType))) {
151 parser.
emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
159 if (mlir::succeeded(res)) {
160 if (mlir::isa<SGMapAttr>(attr)) {
164 if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
170 "Failed to parse the attribute.\n");
187 for (int64_t dim : shape) {
188 if (mlir::ShapedType::isDynamic(dim))
197 if (
auto encoding = getEncoding())
198 printer <<
", " << encoding;
200 if (
auto sg_map = getSgMap())
201 printer <<
", " << sg_map;
209 MemorySpace memory_space,
214 return Base::get(context, shape, elementType, attr, sg_map);
219 MemorySpace memory_space,
223 return Base::get(context, shape, elementType, attr, sg_map);
229 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
230 #define GET_ATTRDEF_CLASSES
231 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
232 #define GET_TYPEDEF_CLASSES
233 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printKeywordOrString(StringRef keyword)
Print the given string as a keyword, or a quoted and escaped string if it has any special or non-prin...
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
This class represents a diagnostic that is inflight and set to be reported.
MLIRContext is the top-level object for a collection of MLIR operations.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...