MLIR  21.0.0git
XeGPUDialect.cpp
Go to the documentation of this file.
1 //===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
13 
14 namespace mlir {
15 namespace xegpu {
16 
/// Registers all XeGPU types, operations, and attributes with the dialect.
/// The actual lists are generated by TableGen and spliced in through the
/// GET_*_LIST-guarded includes below.
void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}
31 
32 //===----------------------------------------------------------------------===//
33 // XeGPU_BlockTensorDescAttr
34 //===----------------------------------------------------------------------===//
35 BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
36  xegpu::MemorySpace memory_space,
37  int array_length,
38  bool boundary_check) {
39  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
40  auto lengthAttr =
41  IntegerAttr::get(IntegerType::get(context, 64), array_length);
42  auto boundaryAttr = BoolAttr::get(context, boundary_check);
43  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
44 }
45 
46 //===----------------------------------------------------------------------===//
47 // XeGPU_ScatterTensorDescAttr
48 //===----------------------------------------------------------------------===//
49 ScatterTensorDescAttr
51  xegpu::MemorySpace memory_space, int chunk_size) {
52  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
53  auto chunkSizeAttr =
54  IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
55  return Base::get(context, scopeAttr, chunkSizeAttr);
56 }
57 
58 LogicalResult ScatterTensorDescAttr::verify(
60  MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
61  int64_t chunkSize = chunk_size.getInt();
62  SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8,
63  16, 32, 64, 128, 256};
64  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
65  return emitError() << "invalid chunk size";
66 
67  return success();
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // XeGPU_SGMapAttr
72 //===----------------------------------------------------------------------===//
73 namespace {
74 template <typename T, unsigned N>
75 LogicalResult parseIntArrayField(::mlir::AsmParser &parser,
77  llvm::StringRef fieldName) {
78  if (failed(parser.parseKeyword(fieldName))) {
79  parser.emitError(parser.getCurrentLocation(),
80  "unexpected field name. Expected " + fieldName + ".");
81  return failure();
82  }
83 
84  if (failed(parser.parseEqual())) {
85  parser.emitError(parser.getCurrentLocation(), "expected '=' sign.");
86  return failure();
87  }
88 
89  auto elemParser = [&]() -> llvm::ParseResult {
90  uint32_t elem = 0;
91  auto res = parser.parseInteger(elem);
92  result.push_back(elem);
93  return res;
94  };
95 
97  elemParser, fieldName);
98 }
99 } // namespace
100 
102  ::mlir::Type attrType) {
103  if (failed(parser.parseLess()))
104  return {};
105 
106  llvm::SmallVector<uint32_t, 2> wi_layout, wi_data;
107  if (failed(parseIntArrayField(parser, wi_layout, "wi_layout")))
108  return {};
109 
110  if (failed(parser.parseComma()))
111  return {};
112 
113  if (failed(parseIntArrayField(parser, wi_data, "wi_data")))
114  return {};
115 
116  return SGMapAttr::getChecked(
117  [&]() { return parser.emitError(parser.getNameLoc()); },
118  parser.getContext(), wi_layout, wi_data);
119 }
120 
121 void SGMapAttr::print(::mlir::AsmPrinter &printer) const {
122  printer << "<";
123  printer.printKeywordOrString("wi_layout");
124  printer << " = [" << getWiLayout() << "], ";
125  printer.printKeywordOrString("wi_data");
126  printer << " = [" << getWiData() << "]";
127  printer << ">";
128 }
129 
130 LogicalResult
132  llvm::ArrayRef<uint32_t> wi_layout,
133  llvm::ArrayRef<uint32_t> wi_data) {
134  if (wi_layout.size() != 2)
135  return emitError() << "expected wi_layout of size 2";
136  if (wi_data.size() != 2)
137  return emitError() << "expected wi_data of size 2";
138  return success();
139 }
140 
141 //===----------------------------------------------------------------------===//
142 // XeGPU_TensorDescType
143 //===----------------------------------------------------------------------===//
144 
147  mlir::Type elementType;
148  mlir::FailureOr<mlir::Attribute> encoding;
149  mlir::FailureOr<mlir::Attribute> sg_map;
150 
151  // Parse literal '<'
152  if (parser.parseLess())
153  return {};
154 
155  auto shapeLoc = parser.getCurrentLocation();
156  if (mlir::failed(parser.parseDimensionList(shape))) {
157  parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
158  return {};
159  }
160 
161  auto elemTypeLoc = parser.getCurrentLocation();
162  if (mlir::failed(parser.parseType(elementType))) {
163  parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
164  return {};
165  }
166 
167  // parse optional attributes
168  while (mlir::succeeded(parser.parseOptionalComma())) {
169  mlir::Attribute attr;
170  ParseResult res = parser.parseAttribute(attr);
171  if (mlir::succeeded(res)) {
172  if (mlir::isa<SGMapAttr>(attr)) {
173  sg_map = attr;
174  continue;
175  }
176  if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
177  encoding = attr;
178  continue;
179  }
180  }
181  return {};
182  }
183 
184  // Parse literal '>'
185  if (parser.parseGreater())
186  return {};
187 
188  return TensorDescType::getChecked(
189  [&]() { return parser.emitError(parser.getNameLoc()); },
190  parser.getContext(), shape, elementType,
191  encoding.value_or(mlir::Attribute()), sg_map.value_or(mlir::Attribute()));
192 }
193 
194 void TensorDescType::print(::mlir::AsmPrinter &printer) const {
195  printer << "<";
196 
197  auto shape = getShape();
198  for (int64_t dim : shape) {
199  if (mlir::ShapedType::isDynamic(dim))
200  printer << '?';
201  else
202  printer << dim;
203  printer << 'x';
204  }
205 
206  printer << getElementType();
207 
208  if (auto encoding = getEncoding())
209  printer << ", " << encoding;
210 
211  if (auto sg_map = getSgMap())
212  printer << ", " << sg_map;
213 
214  printer << ">";
215 }
216 
217 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
218  mlir::Type elementType, int array_length,
219  bool boundary_check,
220  MemorySpace memory_space,
221  mlir::Attribute sg_map) {
222  auto context = elementType.getContext();
223  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
224  boundary_check);
225  return Base::get(context, shape, elementType, attr, sg_map);
226 }
227 
228 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
229  mlir::Type elementType, int chunk_size,
230  MemorySpace memory_space,
231  mlir::Attribute sg_map) {
232  auto context = elementType.getContext();
233  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
234  return Base::get(context, shape, elementType, attr, sg_map);
235 }
236 
237 LogicalResult TensorDescType::verify(
239  llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
240  mlir::Attribute encoding, mlir::Attribute sg_map) {
241  size_t rank = shape.size();
242  // Low-pressure types are packed in 32-bit units.
243  unsigned packingFactor = 32 / elementType.getIntOrFloatBitWidth();
244  if (rank != 1 && rank != 2)
245  return emitError() << "expected 1D or 2D tensor";
246 
247  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
248  if (scatterAttr) {
249  // Expected tensor ranks for scattered data:
250  // - 1D tensor for fully non-contiguous elements (chunk size == 1)
251  // - 2D tensor for scattered blocks (chunk size > 1)
252  unsigned chunkSize = scatterAttr.getChunkSize().getInt();
253  if (rank == 1 && chunkSize != 1)
254  return emitError() << "expected non-contiguous elements for 1D tensor";
255  if (rank == 2 && chunkSize < 2)
256  return emitError() << "expected chunk blocks for 2D tensor";
257  // If chunk size > 1, the second dimension of the tensor shape must be
258  // equal to chunk size and it must be a multiple of the packing factor.
259  if (chunkSize > 1) {
260  if (shape.back() != chunkSize)
261  return emitError() << "expected tensor shape[1] to match chunk size";
262  if (shape.back() % packingFactor != 0)
263  return emitError()
264  << "expected tensor shape[1] to be a multiple of packing factor "
265  << packingFactor;
266  }
267  }
268 
269  if (auto blockAttr =
270  mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding)) {
271  MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
272  if (rank == 2 && memorySpaceAttr &&
273  memorySpaceAttr.getValue() == MemorySpace::SLM)
274  return emitError() << "SLM is not supported for 2D block tensor";
275  }
276 
277  if (auto sgMapAttr = llvm::dyn_cast_if_present<SGMapAttr>(sg_map)) {
278  ArrayRef<uint32_t> wiLayout = sgMapAttr.getWiLayout();
279  ArrayRef<uint32_t> wiData = sgMapAttr.getWiData();
280 
281  if (rank == 1) {
282  if (wiLayout[0] != 1 || wiData[0] != 1)
283  return emitError()
284  << "outer layout distribution and data mapping must be 1 "
285  "for 1D tensor";
286  }
287 
288  if (scatterAttr) {
289  // Validate subgroup mapping rules for scattered tensors.
290  // A work-item's slice of the tensor with shape [sg_size] or
291  // [sg_size, chunk_size] will be [1] or [1, 32/element_ty_bit_width]
292  // respectively, the mapping should reflect that. This is because each
293  // work item access data in 32 bit granularity.
294  if (wiData[0] != 1)
295  return emitError()
296  << "cannot map over non-contiguous scattered row elements";
297  if (wiData[1] != packingFactor)
298  return emitError() << "work item data mapping must match the number of "
299  "contiguous elements";
300  }
301 
302  // For 1D tensor, pad the shape with an outer unit dimension to allow common
303  // validation logic.
304  SmallVector<int64_t> tensorShape(shape.begin(), shape.end());
305  if (rank == 1)
306  tensorShape = {1, tensorShape.back()};
307 
308  size_t dims = tensorShape.size();
309  for (size_t i = 0; i < dims; ++i) {
310  uint32_t numElemPerWi = wiLayout[i] * wiData[i];
311  if (tensorShape[i] < numElemPerWi || tensorShape[i] % numElemPerWi != 0)
312  return emitError() << "cannot distribute " << tensorShape[i] << " over "
313  << wiLayout[i] << " work items with " << wiData[i]
314  << " elements each";
315  }
316  }
317 
318  return success();
319 }
320 
// If tensor descriptor has a sg_map attribute it is used in SIMT mode.
// In this mode, the distributed vector shape is determined as follows:
// Definitions:
//        wi_data_size = wi_data[0] x wi_data[1]
//        subgroup_size = wi_layout[0] x wi_layout[1]
//        distribution_unit_size = subgroup_size x wi_data_size
// ---------------------------------------------------------------------
// Case 1: Regular loads/stores.
// ---------------------------------------------------------------------
// Distributed vector shape must be:
//        [chunk_size / wi_data_size, wi_data_size]
// If the tensor descriptor shape is 1D, first dimension is ignored (set to 1).
//        [wi_data_size]
// ---------------------------------------------------------------------
// Case 2: Block loads/stores
// ---------------------------------------------------------------------
// Additional definitions:
//        tensor_size = tensor_desc[0] * .. * tensor_desc[r-1] * array_length
//        n_distribution_units = tensor_size / distribution_unit_size
// Given above definitions, the following conditions must be met:
//        * tensor_desc[0] % (wi_layout[0] x wi_data[0]) == 0
//        * tensor_desc[1] % (wi_layout[1] x wi_data[1]) == 0
// Distributed vector shape must be:
//        [n_distribution_units, wi_data_size]
FailureOr<VectorType> TensorDescType::getDistributedVectorType() {
  auto sgMap = llvm::dyn_cast_if_present<SGMapAttr>(getSgMap());
  // If no sg_map is provided, tensor desc is not used in SIMT mode.
  if (!sgMap)
    return failure();

  SmallVector<int64_t> wiData(sgMap.getWiData());
  SmallVector<int64_t> wiLayout(sgMap.getWiLayout());
  auto tdescShape = getShape();

  // Accumulate wi_data_size and subgroup_size from the per-dimension mapping.
  auto wiDataSize = 1, sgSize = 1;
  for (auto [wiDim, wiDataDim] : llvm::zip_equal(wiLayout, wiData)) {
    wiDataSize *= wiDataDim;
    sgSize *= wiDim;
  }

  // Case 1: regular loads/stores (scatter-encoded descriptor).
  auto scatterAttr = getEncodingAsScatterTensorDescAttr();
  if (scatterAttr) {
    auto chunkSize = scatterAttr.getChunkSize().getInt();
    // Verify if the first dimension of the tensor descriptor shape is
    // distributable.
    assert(tdescShape[0] % (wiLayout[0]) == 0 &&
           "tensor descriptor shape is not distributable");
    if (chunkSize > 1)
      return VectorType::get({chunkSize / wiDataSize, wiDataSize},
                             getElementType());
    return VectorType::get({wiDataSize}, getElementType());
  }

  // Case 2: block loads/stores
  // Tensor descriptor shape can be 1D. For the 1D case, outer dims of wiData
  // and wiLayout must be 1 (enforced by verify), so drop them here.
  if (tdescShape.size() == 1) {
    assert((wiData[0] == 1 && wiLayout[0] == 1) &&
           "wi_data[0] and wi_layout[0] must be 1 for 1D tensor descriptor");
    wiData = {wiData[1]};
    wiLayout = {wiLayout[1]};
  }
  // Check if the tensor descriptor shape is distributable while computing
  // the total element count.
  int64_t tensorSize = 1;
  for (auto [tdescDim, wiDim, wiDataDim] :
       llvm::zip_equal(tdescShape, wiLayout, wiData)) {
    assert((tdescDim % (wiDim * wiDataDim) == 0) &&
           "tensor descriptor shape is not distributable");
    tensorSize *= tdescDim;
  }
  // tensorSize must be adjusted for array_length.
  tensorSize *= getArrayLength();

  return VectorType::get({tensorSize / (sgSize * wiDataSize), wiDataSize},
                         getElementType());
}
398 
399 } // namespace xegpu
400 } // namespace mlir
401 
402 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
403 #define GET_ATTRDEF_CLASSES
404 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
405 #define GET_TYPEDEF_CLASSES
406 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:187
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
This base class exposes generic asm parser hooks, usable across the various derived parsers.
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printKeywordOrString(StringRef keyword)
Print the given string as a keyword, or a quoted and escaped string if it has any special or non-prin...
Attributes are known-constant values of operations.
Definition: Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:425