MLIR  20.0.0git
XeGPUDialect.cpp
Go to the documentation of this file.
1 //===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/IR/Builders.h"
12 #include "llvm/ADT/TypeSwitch.h"
13 
14 namespace mlir {
15 namespace xegpu {
16 
// Registers every XeGPU type, operation and attribute with this dialect.
// Each add*<...>() call takes its template argument list from a generated
// .cpp.inc file; the GET_*_LIST macro defined right before each #include is
// what makes that file expand to the class list expected here.
17 void XeGPUDialect::initialize() {
18  addTypes<
19 #define GET_TYPEDEF_LIST
20 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
21  >();
22  addOperations<
23 #define GET_OP_LIST
24 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
25  >();
26  addAttributes<
27 #define GET_ATTRDEF_LIST
28 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
29  >();
30 }
31 
32 //===----------------------------------------------------------------------===//
33 // XeGPU_BlockTensorDescAttr
34 //===----------------------------------------------------------------------===//
35 BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
36  xegpu::MemorySpace memory_space,
37  int array_length,
38  bool boundary_check) {
39  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
40  auto lengthAttr =
41  IntegerAttr::get(IntegerType::get(context, 64), array_length);
42  auto boundaryAttr = BoolAttr::get(context, boundary_check);
43  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // XeGPU_ScatterTensorDescAttr
48 //===----------------------------------------------------------------------===//
// Convenience builder for a scatter tensor-descriptor attribute: wraps the
// memory space in a MemorySpaceAttr and the chunk size in an IntegerAttr of a
// 64-bit integer type, then delegates to the storage-level getter.
// NOTE(review): listing line 50 is absent from this extract; `context`, used
// below, is presumably declared there as part of the signature — confirm.
49 ScatterTensorDescAttr
51  xegpu::MemorySpace memory_space, int chunk_size) {
52  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
53  auto chunkSizeAttr =
54  IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
55  return Base::get(context, scopeAttr, chunkSizeAttr);
56 }
57 
58 //===----------------------------------------------------------------------===//
59 // XeGPU_SGMapAttr
60 //===----------------------------------------------------------------------===//
// File-local parsing helpers used by SGMapAttr::parse.
61 namespace {
// Parses one `<fieldName> = <integer list>` field and appends the integers
// to `result`.
//   parser    - attribute parser positioned at the field keyword.
//   fieldName - keyword expected at the current position (e.g. "wi_layout").
// Fails if the keyword or the '=' is missing; otherwise returns the outcome of
// the list parse driven by `elemParser`.
// NOTE(review): listing lines 64 and 84 are missing from this extract; they
// presumably declare the `result` out-parameter and open the
// parseCommaSeparatedList(...) call that consumes `elemParser` — confirm.
62 template <typename T, unsigned N>
63 LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
65  llvm::StringRef fieldName) {
// NOTE(review): parseKeyword/parseEqual may already report their own
// diagnostic, so the emitError calls below may produce a second one — verify.
66  if (failed(parser.parseKeyword(fieldName))) {
67  parser.emitError(parser.getCurrentLocation(),
68  "unexpected field name. Expected " + fieldName + ".");
69  return failure();
70  }
71 
72  if (failed(parser.parseEqual())) {
73  parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
74  return failure();
75  }
76 
// Parses a single list element. It is always read as uint32_t, whatever T is.
// NOTE(review): the element is appended even when parseInteger fails (it stays
// 0); the caller in this file discards the result on failure, so this is
// harmless there, but worth confirming for any other caller.
77  auto elemParser = [&]() -> llvm::ParseResult {
78  uint32_t elem = 0;
79  auto res = parser.parseInteger(elem);
80  result.push_back(elem);
81  return res;
82  };
83 
85  elemParser, fieldName);
86 }
87 } // namespace
88 
// (SGMapAttr::parse, continued — the first signature line, listing line 89,
// is absent from this extract.)
// Parses `<wi_layout = [...], wi_data = [...]>` and builds the attribute with
// getChecked, so SGMapAttr::verify runs and reports errors at the name
// location. Returns a null attribute on any failure.
// NOTE(review): unlike TensorDescType::parse, nothing here calls
// parseGreater() to consume the closing '>' that SGMapAttr::print emits —
// confirm whether the enclosing dialect parser tolerates this.
90  ::mlir::Type attrType) {
91  if (failed(parser.parseLess()))
92  return {};
93 
94  llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
95  if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
96  return {};
97 
98  if (failed(parser.parseComma()))
99  return {};
100 
101  if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
102  return {};
103 
104  return SGMapAttr::getChecked(
105  [&]() { return parser.emitError(parser.getNameLoc()); },
106  parser.getContext(), wi_layout, wi_data);
107 }
108 
109 void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
110  printer << "<";
111  printer.printKeywordOrString("wi_layout");
112  printer << " = [" << getWiLayout() << "], ";
113  printer.printKeywordOrString("wi_data");
114  printer << " = [" << getWiData() << "]";
115  printer << ">";
116 }
117 
// Verifies SGMapAttr parameters: wi_layout and wi_data must each contain
// exactly two entries. Otherwise it reports through `emitError` and fails.
// NOTE(review): listing line 119 (presumably the `SGMapAttr::verify(` name and
// the `emitError` callback parameter used below) is absent from this extract.
118 LogicalResult
120  llvm::ArrayRef<uint32_t> wi_layout,
121  llvm::ArrayRef<uint32_t> wi_data) {
122  if (wi_layout.size() != 2)
123  return emitError() << "expected wi_layout of size 2";
124  if (wi_data.size() != 2)
125  return emitError() << "expected wi_data of size 2";
126  return success();
127 }
128 
129 //===----------------------------------------------------------------------===//
130 // XeGPU_TensorDescType
131 //===----------------------------------------------------------------------===//
132 
// (TensorDescType::parse body — the signature and the `shape` declaration,
// listing lines 133-134, are absent from this extract; `parser` and `shape`
// are presumably declared there — confirm.)
// Accepts `<` dims-with-trailing-x elementType (`,` attr)* `>`. Each optional
// attribute is either an SGMapAttr (stored as sg_map) or a Block/Scatter
// tensor-descriptor attribute (stored as encoding). Returns a null Type on
// failure.
135  mlir::Type elementType;
136  mlir::FailureOr<mlir::Attribute> encoding;
137  mlir::FailureOr<mlir::Attribute> sg_map;
138 
139  // Parse literal '<'
140  if (parser.parseLess())
141  return {};
142 
143  auto shapeLoc = parser.getCurrentLocation();
144  if (mlir::failed(parser.parseDimensionList(shape))) {
145  parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
146  return {};
147  }
148 
149  auto elemTypeLoc = parser.getCurrentLocation();
150  if (mlir::failed(parser.parseType(elementType))) {
151  parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
152  return {};
153  }
154 
155  // parse optional attributes
// A later attribute of the same kind silently overwrites an earlier one.
// Any other attribute kind, or a failed attribute parse, is an error.
156  while (mlir::succeeded(parser.parseOptionalComma())) {
157  mlir::Attribute attr;
158  ParseResult res = parser.parseAttribute(attr);
159  if (mlir::succeeded(res)) {
160  if (mlir::isa<SGMapAttr>(attr)) {
161  sg_map = attr;
162  continue;
163  }
164  if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
165  encoding = attr;
166  continue;
167  }
168  }
// NOTE(review): when parseAttribute itself failed it has likely already
// reported an error, so this may be a duplicate. The message also ends in
// '\n', which diagnostics normally omit — consider rewording.
169  parser.emitError(parser.getCurrentLocation(),
170  "Failed to parse the attribute.\n");
171  return {};
172  }
173 
174  // Parse literal '>'
175  if (parser.parseGreater())
176  return {};
177 
// Attributes that were never seen are passed on as null mlir::Attribute.
178  return TensorDescType::get(parser.getContext(), shape, elementType,
179  encoding.value_or(mlir::Attribute()),
180  sg_map.value_or(mlir::Attribute()));
181 }
182 
183 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
184  printer << "<";
185 
186  auto shape = getShape();
187  for (int64_t dim : shape) {
188  if (mlir::ShapedType::isDynamic(dim))
189  printer << '?';
190  else
191  printer << dim;
192  printer << 'x';
193  }
194 
195  printer << getElementType();
196 
197  if (auto encoding = getEncoding())
198  printer << ", " << encoding;
199 
200  if (auto sg_map = getSgMap())
201  printer << ", " << sg_map;
202 
203  printer << ">";
204 }
205 
206 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
207  mlir::Type elementType, int array_length,
208  bool boundary_check,
209  MemorySpace memory_space,
210  mlir::Attribute sg_map) {
211  auto context = elementType.getContext();
212  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
213  boundary_check);
214  return Base::get(context, shape, elementType, attr, sg_map);
215 }
216 
217 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
218  mlir::Type elementType, int chunk_size,
219  MemorySpace memory_space,
220  mlir::Attribute sg_map) {
221  auto context = elementType.getContext();
222  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
223  return Base::get(context, shape, elementType, attr, sg_map);
224 }
225 
226 } // namespace xegpu
227 } // namespace mlir
228 
229 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
230 #define GET_ATTRDEF_CLASSES
231 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
232 #define GET_TYPEDEF_CLASSES
233 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity, to select an element type.
Definition: SPIRVOps.cpp:215
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printKeywordOrString(StringRef keyword)
Print the given string as a keyword, or as a quoted and escaped string if it contains any special or non-printable characters.
Attributes are known-constant values of operations.
Definition: Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects the context only if needed; this helps unify some of the attribute construction methods.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:425