MLIR  22.0.0git
XeGPUDialect.cpp
Go to the documentation of this file.
1 //===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
16 #include "mlir/IR/Builders.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
20 
21 using std::optional;
22 
23 namespace mlir {
24 namespace xegpu {
25 
/// Registers all ODS-generated types, operations, and attributes of the
/// XeGPU dialect. The *_LIST macros expand to the comma-separated class
/// lists emitted by TableGen into the corresponding .cpp.inc files.
void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}
40 
/// Generates instructions to compute offsets for a subgroup identified by
/// its multidimensional indices (sgId), using the specified subgroup layout
/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
/// dimensions (sizePerWg).
// NOTE(review): the leading lines of this definition (return type, the
// builder/loc/sgId/sgLayout parameters, and the declaration of `offsets`)
// are not visible in this excerpt of the file.
                             ArrayRef<int64_t> sizePerSg,
                             ArrayRef<int64_t> sizePerWg) {

  // nd local offset, localOffset[i] = sgId[i] * sizePerSg[i]
  SmallVector<Value> localOffsets = llvm::map_to_vector(
      llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
        return builder.createOrFold<index::MulOp>(
            loc, std::get<0>(t),
            builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
      });

  // distUnit[i] is the minimum value between sizePerWg[i] and
  // sgLayout[i] * sizePerSg[i]
  SmallVector<int64_t> distUnit = llvm::map_to_vector(
      llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

  // Walk the origin of every distribution-unit tile inside the workgroup
  // data; each iteration yields one offset vector for this subgroup.
  for (SmallVector<int64_t> unitOffs :
       StaticTileOffsetRange(sizePerWg, distUnit)) {
    // Constant base offset of the current distribution unit.
    SmallVector<Value> base =
        llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
          return arith::ConstantIndexOp::create(builder, loc, d);
        });

    // adds[i] = base[i] + localOffset[i]
    SmallVector<Value> adds = llvm::map_to_vector(
        llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
          return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
                                                     std::get<1>(t));
        });

    // Wrap each coordinate back into the workgroup data:
    // mods[i] = adds[i] % sizePerWg[i].
    SmallVector<Value> mods = llvm::map_to_vector(
        llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
          return builder.createOrFold<index::RemUOp>(
              loc, std::get<0>(t),
              arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
        });

    offsets.push_back(mods);
  }
  return offsets;
}
91 
// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
                                         xegpu::DistributeLayoutAttr attr) {
  assert(attr && "Layout attribute is missing.");

  // Checks whether the given shape can be evenly distributed using the
  // specified layout and data attributes. If successful, it returns the work
  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
  // size per compute unit is calculated as follows:
  //   - If `data` is null: newShape[i] = shape[i] / layout[i]
  //   - If `data` is not null: newShape[i] = data[i]
  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
  // share the data.
  // NOTE(review): the lambda's `data` parameter line is not visible in this
  // excerpt of the file.
  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
                           SmallVector<int64_t> layout,
                           bool rr = true) -> optional<SmallVector<int64_t>> {
    llvm::SmallVector<int64_t> newShape(shape);
    // First divide by the layout; every dimension must divide exactly.
    if (layout.size()) {
      if (layout.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(shape, layout);
      if (!ratio.has_value())
        return std::nullopt;
      newShape = ratio.value();
    }

    if (data.size()) {
      if (data.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(newShape, data);
      // With round-robin, a shape smaller than `data` is also acceptable as
      // long as `data` divides evenly by it.
      if (!ratio.has_value() && rr)
        ratio = computeShapeRatio(data, newShape);
      if (!ratio.has_value())
        return std::nullopt;

      // if data is not null, we always return it for next phase.
      newShape = data;
    }
    return newShape;
  };

  // check the sgLayout and sgData
  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
                                    attr.getEffectiveSgDataAsInt());
  if (!maybeSgShape)
    return false;
  auto sgShape = maybeSgShape.value();

  // Check InstData; it has no layout component and does not use round-robin.
  auto maybeInstShape =
      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
  if (!maybeInstShape)
    return false;
  auto instShape = maybeInstShape.value();

  // check LaneLayout and LaneData
  auto maybeLaneShape =
      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
                    attr.getEffectiveLaneDataAsInt(), false);
  return maybeLaneShape.has_value();
}
156 
157 //===----------------------------------------------------------------------===//
158 // XeGPU_BlockTensorDescAttr
159 //===----------------------------------------------------------------------===//
160 BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
161  xegpu::MemorySpace memory_space,
162  int array_length,
163  bool boundary_check) {
164  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
165  auto lengthAttr =
166  IntegerAttr::get(IntegerType::get(context, 64), array_length);
167  auto boundaryAttr = BoolAttr::get(context, boundary_check);
168  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
169 }
170 
171 bool BlockTensorDescAttr::hasDefaultsOnly() {
172  return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
173  getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
174 }
175 
176 //===----------------------------------------------------------------------===//
177 // XeGPU_ScatterTensorDescAttr
178 //===----------------------------------------------------------------------===//
/// Builds a ScatterTensorDescAttr from a memory space and chunk size,
/// wrapping both in attribute form.
// NOTE(review): the first line of this signature (the `get(MLIRContext *,`
// part) is not visible in this excerpt of the file.
ScatterTensorDescAttr
                      xegpu::MemorySpace memory_space, int chunk_size) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  // Chunk size is stored as a 64-bit integer attribute.
  auto chunkSizeAttr =
      IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
  return Base::get(context, scopeAttr, chunkSizeAttr);
}
187 
/// Verifies a ScatterTensorDescAttr: the chunk size must be strictly
/// positive.
// NOTE(review): the emitError parameter line of this signature is not
// visible in this excerpt of the file.
LogicalResult ScatterTensorDescAttr::verify(
    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
  int64_t chunkSize = chunk_size.getInt();
  if (chunkSize <= 0)
    return emitError() << "invalid chunk size";

  return success();
}
197 
198 //===----------------------------------------------------------------------===//
199 // XeGPU_LayoutAttr
200 //===----------------------------------------------------------------------===//
/// Verifies structural consistency of a LayoutAttr: which components may
/// appear together and that co-occurring components agree in rank.
// NOTE(review): the first parameter line of this signature (the
// `LayoutAttr::verify(llvm::function_ref<...> emitError,` part) is not
// visible in this excerpt of the file.
LogicalResult
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {

  // A valid layout must include at least one of sg_layout and lane_layout.
  // sg_layout is essential for Workgroup layout, while lane_layout is
  // required for Subgroup layout.
  if (!sg_layout && !inst_data && !lane_layout) {
    return emitError()
           << "expected at least one of sg_layout, inst_data or lane_layout";
  }

  // sg_layout, inst_data and lane_layout must all have the same rank
  // whenever more than one of them is present.

  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError()
           << "expected inst_data and lane_layout to have the same rank";
  }

  // sg_data is optional for Workgroup layout, but its presence requires
  // sg_layout (and a matching rank).
  if (sg_data) {
    if (!sg_layout)
      return emitError() << "expected sg_layout being used with sg_data";
    if (sg_data.size() != sg_layout.size())
      return emitError()
             << "expected sg_data and sg_layout to have the same rank";
  }

  // lane_data is optional for Subgroup layout, but its presence requires
  // lane_layout (and a matching rank).
  if (lane_data) {
    if (!lane_layout)
      return emitError() << "expected lane_layout being used with lane_data";
    if (lane_data.size() != lane_layout.size())
      return emitError()
             << "expected lane_data and lane_layout to have the same rank";
  }

  // order must accompany sg_layout and/or lane_layout, and must match the
  // rank of whichever of the two is present.
  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";

    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";

    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}
269 
270 FailureOr<SmallVector<Value>>
271 LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
272  Value linearId) {
273  // delinearizeSubgroupId is only available for
274  // workgroup-level layout attribute
275  if (!isForWorkgroup())
276  return failure();
277 
278  // TODO: handle order attribute
279  auto hasDefaultOrder = [&]() {
280  DenseI32ArrayAttr order = getOrder();
281  return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
282  llvm::reverse(order.asArrayRef())));
283  };
284  if (!hasDefaultOrder())
285  return mlir::emitError(loc, "order attribute is currently not supported.");
286 
287  auto dims =
288  llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
289  return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
290  });
291 
292  return affine::delinearizeIndex(builder, loc, linearId, dims);
293 }
294 
295 /// Implements DistributeLayoutAttr::getOffsets to generate
296 /// instructions for computing multi-dimensional offsets when distributed by
297 /// LayoutAttr.
298 FailureOr<SmallVector<SmallVector<Value>>>
299 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
300  ArrayRef<int64_t> shape) {
301  if (!isForWorkgroup())
302  return failure();
303 
304  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
305  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
306  if (sgShape.empty()) {
307  if (auto derivedShape = computeShapeRatio(shape, sgLayout))
308  sgShape = derivedShape.value();
309  else
310  return failure();
311  }
312 
313  // delinearize Ids
314  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
315  if (failed(maybeIds))
316  return failure();
317  SmallVector<Value> sgIds = *maybeIds;
318 
319  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
320  shape);
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // XeGPU_SliceAttr
325 //===----------------------------------------------------------------------===//
/// Verifies a SliceAttr: the parent layout and dims must be present, and
/// every sliced dim must be unique and within the parent's rank.
// NOTE(review): the first line of this signature (`SliceAttr::verify(` with
// the emitError parameter) is not visible in this excerpt of the file.
LogicalResult
                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
  if (!parent || !dims)
    return emitError() << "expected parent layout and dims attribute";

  int64_t rank = parent.getRank();

  // check every element in dims is unique and smaller than rank
  llvm::SmallDenseSet<int64_t> seen;
  for (int64_t dim : dims.asArrayRef()) {
    if (dim < 0 || dim >= rank)
      return emitError() << "invalid dim (" << dim << ") in slice attribute.";
    if (!seen.insert(dim).second)
      return emitError() << "repeated dim (" << dim << ") in slice attribute.";
  }
  return success();
}
344 
345 SliceAttr SliceAttr::flatten() const {
346  xegpu::DistributeLayoutAttr parent = getParent();
347  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
348 
349  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
350  parent = sliceAttr.getParent();
351  slicedDims.push_back(sliceAttr.getDims());
352  }
353 
354  auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
355  SmallVector<int64_t> indices =
356  llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
357 
358  // get remaining dims (flattend) by applying slice ops with all slicedDims
359  SmallVector<int64_t> remainingDims(indices);
360  for (auto dim : llvm::reverse(slicedDims))
361  remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
362  dim.asArrayRef());
363 
364  // get flattend sliced dims by applying slice ops with the remaining dims
365  SmallVector<int64_t> flattendDims = XeGPUDialect::slice(
366  llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
367 
368  return xegpu::SliceAttr::get(
369  getContext(), layoutAttr,
370  DenseI64ArrayAttr::get(getContext(), flattendDims));
371 }
372 
373 FailureOr<SmallVector<Value>>
374 SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
375  Value linearId) {
376  SliceAttr attr = flatten();
377  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
378  return parent.delinearizeSubgroupId(builder, loc, linearId);
379 }
380 
381 /// Implements DistributeLayoutAttr::getOffsets to generate
382 /// instructions for computing multi-dimensional offsets when distributed by
383 /// SliceAttr.
384 FailureOr<SmallVector<SmallVector<Value>>>
385 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
386  ArrayRef<int64_t> shape) {
387  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
388  if (!isForWorkgroup())
389  return failure();
390 
391  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
392  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
393  if (sgShape.empty()) {
394  if (auto derivedShape = computeShapeRatio(shape, sgLayout))
395  sgShape = derivedShape.value();
396  else
397  return failure();
398  }
399 
400  // delinearize Ids
401  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
402  if (failed(maybeIds))
403  return failure();
404 
405  // The effective sgIds for offsets computing correspond
406  // to the dims that are not sliced.
407  ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
408  SmallVector<Value> sgIds =
409  XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
410 
411  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
412  shape);
413 }
414 
415 bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
416  auto flattenedThis = flatten();
417  // If other is a LayoutAttr, just compare directly with parent of
418  // flattenedThis.
419  if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
420  return flattenedThis.getParent() == otherLayout;
421  // If other is a SliceAttr, flatten it first before comparing.
422  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
423  // Both must have common parent LayoutAttr.
424  if (flattenedThis.getParent() != flattenedOther.getParent())
425  return false;
426  // otherFlattened's sliced dims must be a subset of flattenedThis's sliced
427  // dims.
428  llvm::SmallDenseSet<int64_t> thisDims(
429  flattenedThis.getDims().asArrayRef().begin(),
430  flattenedThis.getDims().asArrayRef().end());
431  return llvm::all_of(flattenedOther.getDims().asArrayRef(),
432  [&](int64_t dim) { return thisDims.contains(dim); });
433 }
434 
435 //===----------------------------------------------------------------------===//
436 // XeGPU_RangeAttr
437 //===----------------------------------------------------------------------===//
438 
/// Verifies a RangeAttr: `end` must be strictly greater than `start`.
// NOTE(review): the first line of this signature (`RangeAttr::verify(` with
// the emitError parameter) is not visible in this excerpt of the file.
LogicalResult
                  IntegerAttr startOfRange, IntegerAttr endOfRange) {
  if (startOfRange.getInt() >= endOfRange.getInt())
    return emitError() << "'end' : " << endOfRange.getInt()
                       << " must be greater than 'start' : "
                       << startOfRange.getInt();

  return success();
}
449 
450 //===----------------------------------------------------------------------===//
451 // XeGPU_TensorDescType
452 //===----------------------------------------------------------------------===//
453 
// NOTE(review): the first lines of this definition (the
// `TensorDescType::parse(AsmParser &parser)` signature and the declaration
// of `shape`) are not visible in this excerpt of the file.
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  // Parse the dimension list preceding the element type.
  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  // Parse the element type following the shape.
  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // parse optional attributes: a LayoutAttr and/or an encoding
  // (BlockTensorDescAttr or ScatterTensorDescAttr), each introduced by a
  // comma, in any order.
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<LayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    // Parse failure or an attribute of an unexpected kind: abort.
    return {};
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  // Build the type; a missing encoding defaults to an all-defaults block
  // encoding, and a missing layout defaults to a null attribute.
  MLIRContext *ctxt = parser.getContext();
  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
      layout.value_or(mlir::Attribute()));
}
503 
504 void TensorDescType::print(AsmPrinter &printer) const {
505  printer << "<";
506 
507  auto shape = getShape();
508  for (int64_t dim : shape) {
509  if (mlir::ShapedType::isDynamic(dim))
510  printer << '?';
511  else
512  printer << dim;
513  printer << 'x';
514  }
515 
516  printer << getElementType();
517 
518  auto encoding = getEncoding();
519  auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
520  if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
521  printer << ", " << encoding;
522 
523  if (auto layout = getLayout())
524  printer << ", " << layout;
525 
526  printer << ">";
527 }
528 
529 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
530  mlir::Type elementType, int array_length,
531  bool boundary_check,
532  MemorySpace memory_space,
533  mlir::Attribute layout) {
534  auto context = elementType.getContext();
535  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
536  boundary_check);
537  return Base::get(context, shape, elementType, attr, layout);
538 }
539 
540 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
541  mlir::Type elementType, int chunk_size,
542  MemorySpace memory_space,
543  mlir::Attribute layout) {
544  auto context = elementType.getContext();
545  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
546  return Base::get(context, shape, elementType, attr, layout);
547 }
548 
/// Verifies a TensorDescType against its encoding and layout attributes.
// NOTE(review): the first line of this signature (`TensorDescType::verify(`
// with the emitError parameter) is not visible in this excerpt of the file.
LogicalResult
                   llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
                   mlir::Attribute encoding, mlir::Attribute layout) {
  size_t rank = shape.size();

  // Zero-rank tensor descriptors are rejected.
  if (rank == 0)
    return emitError() << "expected non-zero rank tensor";

  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (blockAttr) {
    // Shared local memory (SLM) block descriptors must be 1D.
    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
    if (rank > 1 && memorySpaceAttr &&
        memorySpaceAttr.getValue() == MemorySpace::SLM)
      return emitError() << "SLM is only supported for 1D block tensor";
  }

  // for gather and scatter ops, Low-precision types are packed in 32-bit units.
  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
  // NOTE(review): the condition and true-branch of this conditional
  // expression (presumably derived from `packedSizeInBitsForGatherScatter`
  // and `bitWidth`) are not visible in this excerpt of the file.
  int chunkAlignmentFactor =
      : 1;
  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
  if (scatterAttr) {
    int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
    // 1D scattered descriptors must use a chunk size of 1.
    if (rank == 1 && chunkSize != 1)
      return emitError() << "expected non-contiguous elements for 1D tensor";

    // If chunk size > 1, the second dimension of the tensor shape must be
    // equal to chunk size and it must be a multiple of the
    // chunkAlignmentFactor.
    if (chunkSize > 1) {
      if (shape.back() != chunkSize)
        return emitError() << "expected last dim of tensor to match chunk size";
      if (shape.back() % chunkAlignmentFactor != 0)
        return emitError() << "expected last dim of tensor to be a multiple of "
                           << chunkAlignmentFactor;
    }
  }

  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
  if (layoutAttr) {
    // The layout must describe exactly as many dims as the tensor has.
    if (rank != (size_t)layoutAttr.getRank())
      return emitError() << "expected layout rank to match tensor rank";

    auto laneData = layoutAttr.getLaneData();
    if (scatterAttr && laneData) {
      // Validate subgroup mapping rules for scattered tensors.
      // if chunkSize > 1, the last dimension of the tensor should
      // be distributed in the units divisible by chunkAlignmentFactor.
      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
        return emitError()
               << "expected last dim of lane_data to be a multiple of: "
               << chunkAlignmentFactor;
    }

    // The shape must be evenly distributable under the layout.
    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
      std::string shapeStr;
      llvm::raw_string_ostream stream(shapeStr);
      llvm::interleaveComma(shape, stream);
      return emitError() << "cannot distribute [" << shapeStr << "] using "
                         << layoutAttr;
    }
  }
  return success();
}
617 
618 //===----------------------------------------------------------------------===//
619 // XeGPU_MemDescType
620 //===----------------------------------------------------------------------===//
// NOTE(review): the first lines of this definition (the
// `MemDescType::parse(AsmParser &parser)` signature and the declaration of
// `shape`) are not visible in this excerpt of the file.
  mlir::Type elementType;
  mlir::FailureOr<MemLayoutAttr> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  // Parse a static dimension list (dynamic dims disallowed, trailing 'x'
  // required — see the parseDimensionList(shape, false, true) arguments).
  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  // Parse the element type following the shape.
  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // parse optional attributes: a single trailing MemLayoutAttr.
  if (mlir::succeeded(parser.parseOptionalComma())) {
    MemLayoutAttr attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::failed(res))
      return {};
    layout = attr;
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  // Construct the type; a missing layout defaults to a null MemLayoutAttr.
  MLIRContext *ctxt = parser.getContext();
  return MemDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, layout.value_or(MemLayoutAttr()));
}
660 
661 void MemDescType::print(AsmPrinter &printer) const {
662  printer << "<";
663 
664  printer.printDimensionList(getShape());
665  printer << 'x';
666  printer << getElementType();
667 
668  if (auto layout = getMemLayout())
669  printer << ", " << layout;
670 
671  printer << ">";
672 }
673 
674 //===----------------------------------------------------------------------===//
675 // XeGPU_MemDescType
676 //===----------------------------------------------------------------------===//
677 
679 
// NOTE(review): the first line of this definition (the
// `MemLayoutAttr::parse(...)` signature) is not visible in this excerpt of
// the file.
  auto context = parser.getContext();
  llvm::SMLoc loc = parser.getCurrentLocation();

  // Collected `key = value` entries; `seenKeys` rejects duplicate keys.
  llvm::SmallDenseSet<StringRef> seenKeys;
  SmallVector<NamedAttribute> attributes;

  // Parses one `key = value` element of the comma-separated list.
  auto parseElt = [&]() -> ParseResult {
    StringRef nameId;
    if (failed(parser.parseKeyword(&nameId)))
      return parser.emitError(loc, "expected valid attribute name");

    if (!seenKeys.insert(nameId).second)
      // NOTE(review): this message opens a quote around the key but never
      // closes it ("duplicate key '<name> in ..."); confirm and fix the
      // literal separately.
      return parser.emitError(loc, "duplicate key '")
             << nameId << " in mem layout attribute";

    if (failed(parser.parseEqual()))
      return failure();

    Attribute attr;
    if (failed(parser.parseAttribute(attr)))
      return failure();
    attributes.emplace_back(nameId, attr);
    return success();
  };

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  // Parse the comma-separated `key = value` entries.
  if (failed(parser.parseCommaSeparatedList(parseElt)))
    return {};

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  // Wrap the entries in a DictionaryAttr and run the attribute verifier via
  // getChecked.
  return parser.getChecked<MemLayoutAttr>(
      loc, context, DictionaryAttr::get(context, attributes));
}
719 
720 void MemLayoutAttr::print(AsmPrinter &printer) const {
721  printer << "<";
722  ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
723  for (size_t i = 0; i < attrs.size(); i++) {
724  printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
725  if (i < attrs.size() - 1)
726  printer << ", ";
727  }
728  printer << ">";
729 }
730 
731 } // namespace xegpu
732 } // namespace mlir
733 
734 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
735 #define GET_ATTRDEF_CLASSES
736 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
737 #define GET_TYPEDEF_CLASSES
738 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printDimensionList(ArrayRef< int64_t > shape)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
This class helps build Operations.
Definition: Builders.h:207
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
Definition: Utils.cpp:1967
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
Definition: VectorUtils.h:130
constexpr unsigned packedSizeInBitsForGatherScatter
static SmallVector< SmallVector< Value > > genOffsetsComputingInsts(OpBuilder &builder, Location loc, SmallVector< Value > sgId, ArrayRef< int64_t > sgLayout, ArrayRef< int64_t > sizePerSg, ArrayRef< int64_t > sizePerWg)
Generates instructions to compute offsets for a subgroup identified by its multidimensional indices (...
Include the generated interface declarations.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423