1 //===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
15 #include "mlir/IR/Builders.h"
17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/Debug.h"
19 
20 using std::optional;
21 
22 namespace mlir {
23 namespace xegpu {
24 
25 void XeGPUDialect::initialize() {
26  addTypes<
27 #define GET_TYPEDEF_LIST
28 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
29  >();
30  addOperations<
31 #define GET_OP_LIST
32 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
33  >();
34  addAttributes<
35 #define GET_ATTRDEF_LIST
36 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
37  >();
38 }
39 
40 /// Generates instructions to compute offsets for a subgroup identified by
41 /// its multidimensional indices (sgId), using the specified subgroup layout
42 /// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
43 /// dimensions (sizePerWg).
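/// For example (illustrative values), with sgLayout = [2, 4],
/// sizePerSg = [16, 16] and sizePerWg = [64, 128], subgroup (1, 2) starts at
/// the local offset [16, 32]; since each distribution unit is [32, 64], it
/// also covers the round-robin offsets [16, 96], [48, 32] and [48, 96].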
44 static SmallVector<SmallVector<Value>>
45 genOffsetsComputingInsts(OpBuilder &builder, Location loc,
46  SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
47  ArrayRef<int64_t> sizePerSg,
48  ArrayRef<int64_t> sizePerWg) {
49 
50  SmallVector<SmallVector<Value>> offsets;
51 
52  // n-D local offset: localOffset[i] = sgId[i] * sizePerSg[i]
53  SmallVector<Value> localOffsets = llvm::map_to_vector(
54  llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
55  return builder.createOrFold<index::MulOp>(
56  loc, std::get<0>(t),
57  builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
58  });
59 
60  // distUnit[i] is the minimum value between sizePerWg[i] and
61  // sgLayout[i] * sizePerSg[i]
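  // e.g. sizePerWg = [32, 128], sgLayout = [2, 4], sizePerSg = [16, 16]
  // gives distUnit = [min(32, 32), min(128, 64)] = [32, 64].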
62  SmallVector<int64_t> distUnit = llvm::map_to_vector(
63  llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
64  [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
65 
66  for (SmallVector<int64_t> unitOffs :
67  StaticTileOffsetRange(sizePerWg, distUnit)) {
68  SmallVector<Value> base =
69  llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
70  return arith::ConstantIndexOp::create(builder, loc, d);
71  });
72 
73  SmallVector<Value> adds = llvm::map_to_vector(
74  llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
75  return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
76  std::get<1>(t));
77  });
78 
79  SmallVector<Value> mods = llvm::map_to_vector(
80  llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
81  return builder.createOrFold<index::RemUOp>(
82  loc, std::get<0>(t),
83  arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
84  });
85 
86  offsets.push_back(mods);
87  }
88  return offsets;
89 }
90 
91 // Checks if the given shape can be evenly distributed based on the layout
92 // and data factors provided by the LayoutAttr.
93 bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
94  xegpu::DistributeLayoutAttr attr) {
95  assert(attr && "Layout attribute is missing.");
96 
97  // Checks whether the given shape can be evenly distributed using the
98  // specified layout and data attributes. If successful, it returns the work
99  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
100  // size per compute unit is calculated as follows:
101  // - If `data` is null: newShape[i] = shape[i] / layout[i]
102  // - If `data` is not null: newShape[i] = data[i]
103  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
104  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
105  // share the data.
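  // For example, shape = [32, 32] with sg_layout = [4, 4] and no sg_data
  // yields [8, 8] per subgroup; with sg_data = [16, 16] the shape is smaller
  // than sg_layout * sg_data = [64, 64], so the round-robin path accepts it
  // and returns [16, 16].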
106  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
107  SmallVector<int64_t> layout,
108  SmallVector<int64_t> data,
109  bool rr = true) -> optional<SmallVector<int64_t>> {
110  llvm::SmallVector<int64_t> newShape(shape);
111  if (layout.size()) {
112  if (layout.size() != shape.size())
113  return std::nullopt;
114  auto ratio = computeShapeRatio(shape, layout);
115  if (ratio.has_value()) {
116  newShape = ratio.value();
117  } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
118  return std::nullopt;
119  }
120  // Round-robin case: continue with original newShape
121  }
122 
123  if (data.size()) {
124  if (data.size() != shape.size())
125  return std::nullopt;
126  auto ratio = computeShapeRatio(newShape, data);
127  if (!ratio.has_value() && rr)
128  ratio = computeShapeRatio(data, newShape);
129  if (!ratio.has_value())
130  return std::nullopt;
131 
132  // If data is not null, we always return it for the next phase.
133  newShape = data;
134  }
135  return newShape;
136  };
137 
138  // check the sgLayout and sgData
139  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
140  attr.getEffectiveSgDataAsInt());
141  if (!maybeSgShape)
142  return false;
143  auto sgShape = maybeSgShape.value();
144 
145  // check InstData; it has no layout and does not use round-robin
146  auto maybeInstShape =
147  tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
148  if (!maybeInstShape)
149  return false;
150  auto instShape = maybeInstShape.value();
151 
152  // check LaneLayout and LaneData
153  auto maybeLaneShape =
154  tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
155  attr.getEffectiveLaneDataAsInt(), false);
156  return maybeLaneShape.has_value();
157 }
158 
159 //===----------------------------------------------------------------------===//
160 // XeGPU_BlockTensorDescAttr
161 //===----------------------------------------------------------------------===//
162 BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
163  xegpu::MemorySpace memory_space,
164  int array_length,
165  bool boundary_check) {
166  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
167  auto lengthAttr =
168  IntegerAttr::get(IntegerType::get(context, 64), array_length);
169  auto boundaryAttr = BoolAttr::get(context, boundary_check);
170  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
171 }
172 
173 bool BlockTensorDescAttr::hasDefaultsOnly() {
174  return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
175  getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // XeGPU_ScatterTensorDescAttr
180 //===----------------------------------------------------------------------===//
181 ScatterTensorDescAttr
182 ScatterTensorDescAttr::get(mlir::MLIRContext *context,
183  xegpu::MemorySpace memory_space, int chunk_size) {
184  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
185  auto chunkSizeAttr =
186  IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
187  return Base::get(context, scopeAttr, chunkSizeAttr);
188 }
189 
190 LogicalResult ScatterTensorDescAttr::verify(
191  llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
192  MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
193  int64_t chunkSize = chunk_size.getInt();
194  if (chunkSize <= 0)
195  return emitError() << "invalid chunk size";
196 
197  return success();
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // XeGPU_LayoutAttr
202 //===----------------------------------------------------------------------===//
203 LogicalResult
204 LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
205  DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
206  DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
207  DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {
208 
209  // A valid layout must include at least one of sg_layout, inst_data, or
210  // lane_layout. sg_layout is essential for the workgroup-level layout, while
211  // lane_layout is required for the subgroup-level layout.
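  // e.g. a typical workgroup-level layout looks like:
  //   #xegpu.layout<sg_layout = [2, 4], sg_data = [16, 16],
  //                 lane_layout = [1, 16], lane_data = [1, 1]>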
212  if (!sg_layout && !inst_data && !lane_layout) {
213  return emitError()
214  << "expected at least one of sg_layout, inst_data or lane_layout";
215  }
216 
217  // Check that sg_layout, inst_data, and lane_layout have the same rank
218  // when they are present.
219 
220  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
221  return emitError()
222  << "expected sg_layout and inst_data to have the same rank";
223  }
224 
225  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
226  return emitError()
227  << "expected sg_layout and lane_layout to have the same rank";
228  }
229 
230  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
231  return emitError() << "expected inst_data and lane_layout to have the same "
232  "rank, got inst_data "
233  << inst_data.size() << ", lane_layout "
234  << lane_layout.size();
235  }
236 
237  // sg_data is optional for Workgroup layout, but its presence requires
238  // sg_layout.
239  if (sg_data) {
240  if (!sg_layout)
241  return emitError() << "expected sg_layout being used with sg_data";
242  if (sg_data.size() != sg_layout.size())
243  return emitError()
244  << "expected sg_data and sg_layout to have the same rank";
245  }
246 
247  // lane_data is optional for Subgroup layout, but its presence requires
248  // lane_layout.
249  if (lane_data) {
250  if (!lane_layout)
251  return emitError() << "expected lane_layout being used with lane_data";
252  if (lane_data.size() != lane_layout.size())
253  return emitError()
254  << "expected lane_data and lane_layout to have the same rank";
255  }
256 
257  if (order) {
258  if (!sg_layout && !lane_layout)
259  return emitError()
260  << "expected sg_layout/lane_layout being used with order";
261 
262  if (sg_layout && order.size() != sg_layout.size())
263  return emitError()
264  << "expected order and sg_layout to have the same rank";
265 
266  if (lane_layout && order.size() != lane_layout.size())
267  return emitError()
268  << "expected order and lane_layout to have the same rank";
269  }
270 
271  return success();
272 }
273 
274 FailureOr<SmallVector<Value>>
275 LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
276  Value linearId) {
277  // delinearizeSubgroupId is only available for
278  // workgroup-level layout attribute
279  if (!isForWorkgroup())
280  return failure();
281 
282  // TODO: handle order attribute
283  auto hasDefaultOrder = [&]() {
284  DenseI32ArrayAttr order = getOrder();
285  return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
286  llvm::reverse(order.asArrayRef())));
287  };
288  if (!hasDefaultOrder())
289  return mlir::emitError(loc, "order attribute is currently not supported.");
290 
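  // Delinearize over the effective sg_layout; e.g. with sg_layout = [4, 8],
  // the linear subgroup id 10 maps to the multi-dimensional id [1, 2].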
291  auto dims =
292  llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) -> Value {
293  return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
294  });
295 
296  return affine::delinearizeIndex(builder, loc, linearId, dims);
297 }
298 
299 /// Implements DistributeLayoutAttr::getOffsets to generate
300 /// instructions for computing multi-dimensional offsets when distributed by
301 /// LayoutAttr.
302 FailureOr<SmallVector<SmallVector<Value>>>
303 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
304  ArrayRef<int64_t> shape) {
305  if (!isForWorkgroup())
306  return failure();
307 
308  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
309  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
310  if (sgShape.empty()) {
311  if (auto derivedShape = computeShapeRatio(shape, sgLayout))
312  sgShape = derivedShape.value();
313  else
314  return failure();
315  }
316 
317  // delinearize Ids
318  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
319  if (failed(maybeIds))
320  return failure();
321  SmallVector<Value> sgIds = *maybeIds;
322 
323  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
324  shape);
325 }
326 
327 //===----------------------------------------------------------------------===//
328 // XeGPU_SliceAttr
329 //===----------------------------------------------------------------------===//
330 LogicalResult
331 SliceAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
332  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
333  if (!parent || !dims)
334  return emitError() << "expected parent layout and dims attribute";
335 
336  int64_t rank = parent.getRank();
337 
338  // check every element in dims is unique and smaller than rank
339  llvm::SmallDenseSet<int64_t> seen;
340  for (int64_t dim : dims.asArrayRef()) {
341  if (dim < 0 || dim >= rank)
342  return emitError() << "invalid dim (" << dim << ") in slice attribute.";
343  if (!seen.insert(dim).second)
344  return emitError() << "repeated dim (" << dim << ") in slice attribute.";
345  }
346  return success();
347 }
348 
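// Flatten nested SliceAttrs into a single slice over the root LayoutAttr.
// For example, slicing dim 1 of a rank-3 layout and then dim 0 of the result
// is equivalent to a single slice of dims [0, 1] over the original layout.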
349 SliceAttr SliceAttr::flatten() const {
350  xegpu::DistributeLayoutAttr parent = getParent();
351  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
352 
353  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
354  parent = sliceAttr.getParent();
355  slicedDims.push_back(sliceAttr.getDims());
356  }
357 
358  auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
359  SmallVector<int64_t> indices =
360  llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
361 
362  // get remaining dims (flattened) by applying slice ops with all slicedDims
363  SmallVector<int64_t> remainingDims(indices);
364  for (auto dim : llvm::reverse(slicedDims))
365  remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
366  dim.asArrayRef());
367 
368  // get flattened sliced dims by applying slice ops with the remaining dims
369  SmallVector<int64_t> flattendDims = XeGPUDialect::slice(
370  llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
371 
372  return xegpu::SliceAttr::get(
373  getContext(), layoutAttr,
374  DenseI64ArrayAttr::get(getContext(), flattendDims));
375 }
376 
377 FailureOr<SmallVector<Value>>
378 SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
379  Value linearId) {
380  SliceAttr attr = flatten();
381  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
382  return parent.delinearizeSubgroupId(builder, loc, linearId);
383 }
384 
385 /// Implements DistributeLayoutAttr::getOffsets to generate
386 /// instructions for computing multi-dimensional offsets when distributed by
387 /// SliceAttr.
388 FailureOr<SmallVector<SmallVector<Value>>>
389 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
390  ArrayRef<int64_t> shape) {
391  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
392  if (!isForWorkgroup())
393  return failure();
394 
395  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
396  SmallVector<int64_t> sgShape = getEffectiveSgDataAsInt();
397  if (sgShape.empty()) {
398  if (auto derivedShape = computeShapeRatio(shape, sgLayout))
399  sgShape = derivedShape.value();
400  else
401  return failure();
402  }
403 
404  // delinearize Ids
405  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
406  if (failed(maybeIds))
407  return failure();
408 
409  // The effective sgIds for offset computation correspond
410  // to the dims that are not sliced.
411  ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
412  SmallVector<Value> sgIds =
413  XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
414 
415  return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
416  shape);
417 }
418 
419 bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
420  auto flattenedThis = flatten();
421  // If other is a LayoutAttr, just compare directly with parent of
422  // flattenedThis.
423  if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
424  return flattenedThis.getParent() == otherLayout;
425  // If other is a SliceAttr, flatten it first before comparing.
426  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
427  // Both must have a common parent LayoutAttr.
428  if (flattenedThis.getParent() != flattenedOther.getParent())
429  return false;
430  // flattenedOther's sliced dims must be a subset of flattenedThis's sliced
431  // dims.
432  llvm::SmallDenseSet<int64_t> thisDims(
433  flattenedThis.getDims().asArrayRef().begin(),
434  flattenedThis.getDims().asArrayRef().end());
435  return llvm::all_of(flattenedOther.getDims().asArrayRef(),
436  [&](int64_t dim) { return thisDims.contains(dim); });
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // XeGPU_RangeAttr
441 //===----------------------------------------------------------------------===//
442 
443 LogicalResult
444 RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
445  IntegerAttr startOfRange, IntegerAttr endOfRange) {
446  if (startOfRange.getInt() >= endOfRange.getInt())
447  return emitError() << "'end' : " << endOfRange.getInt()
448  << " must be greater than 'start' : "
449  << startOfRange.getInt();
450 
451  return success();
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // XeGPU_TensorDescType
456 //===----------------------------------------------------------------------===//
457 
458 mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
459  llvm::SmallVector<int64_t> shape;
460  mlir::Type elementType;
461  mlir::FailureOr<mlir::Attribute> encoding;
462  mlir::FailureOr<mlir::Attribute> layout;
463 
464  // Parse literal '<'
465  if (parser.parseLess())
466  return {};
467 
468  auto shapeLoc = parser.getCurrentLocation();
469  if (mlir::failed(parser.parseDimensionList(shape))) {
470  parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
471  return {};
472  }
473 
474  auto elemTypeLoc = parser.getCurrentLocation();
475  if (mlir::failed(parser.parseType(elementType))) {
476  parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
477  return {};
478  }
479 
480  // parse optional attributes
481  while (mlir::succeeded(parser.parseOptionalComma())) {
482  mlir::Attribute attr;
483  ParseResult res = parser.parseAttribute(attr);
484  if (mlir::succeeded(res)) {
485  if (mlir::isa<LayoutAttr>(attr)) {
486  layout = attr;
487  continue;
488  }
489  if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
490  encoding = attr;
491  continue;
492  }
493  }
494  return {};
495  }
496 
497  // Parse literal '>'
498  if (parser.parseGreater())
499  return {};
500 
501  MLIRContext *ctxt = parser.getContext();
502  return TensorDescType::getChecked(
503  [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
504  elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
505  layout.value_or(mlir::Attribute()));
506 }
507 
508 void TensorDescType::print(AsmPrinter &printer) const {
509  printer << "<";
510 
511  auto shape = getShape();
512  for (int64_t dim : shape) {
513  if (mlir::ShapedType::isDynamic(dim))
514  printer << '?';
515  else
516  printer << dim;
517  printer << 'x';
518  }
519 
520  printer << getElementType();
521 
522  auto encoding = getEncoding();
523  auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
524  if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
525  printer << ", " << encoding;
526 
527  if (auto layout = getLayout())
528  printer << ", " << layout;
529 
530  printer << ">";
531 }
532 
533 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
534  mlir::Type elementType, int array_length,
535  bool boundary_check,
536  MemorySpace memory_space,
537  mlir::Attribute layout) {
538  auto context = elementType.getContext();
539  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
540  boundary_check);
541  return Base::get(context, shape, elementType, attr, layout);
542 }
543 
544 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
545  mlir::Type elementType, int chunk_size,
546  MemorySpace memory_space,
547  mlir::Attribute layout) {
548  auto context = elementType.getContext();
549  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
550  return Base::get(context, shape, elementType, attr, layout);
551 }
552 
553 LogicalResult
554 TensorDescType::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
555  llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
556  mlir::Attribute encoding, mlir::Attribute layout) {
557  size_t rank = shape.size();
558 
559  if (rank == 0)
560  return emitError() << "expected non-zero rank tensor";
561 
562  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
563  if (blockAttr) {
564  MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
565  if (rank > 1 && memorySpaceAttr &&
566  memorySpaceAttr.getValue() == MemorySpace::SLM)
567  return emitError() << "SLM is only supported for 1D block tensor";
568  }
569 
570  // For gather and scatter ops, low-precision types are packed in 32-bit units.
571  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
572  int chunkAlignmentFactor =
573  bitWidth < uArch::generalPackedFormatBitSize
574  ? uArch::generalPackedFormatBitSize / bitWidth
575  : 1;
576  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
577  if (scatterAttr) {
578  int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
579  if (rank == 1 && chunkSize != 1)
580  return emitError() << "expected non-contiguous elements for 1D tensor";
581 
582  // If chunk size > 1, the last dimension of the tensor shape must be
583  // equal to chunk size and it must be a multiple of the
584  // chunkAlignmentFactor.
585  if (chunkSize > 1) {
586  if (shape.back() != chunkSize)
587  return emitError() << "expected last dim of tensor to match chunk size";
588  if (shape.back() % chunkAlignmentFactor != 0)
589  return emitError() << "expected last dim of tensor to be a multiple of "
590  << chunkAlignmentFactor;
591  }
592  }
593 
594  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
595  if (layoutAttr) {
596  if (rank != (size_t)layoutAttr.getRank())
597  return emitError() << "expected layout rank to match tensor rank";
598 
599  auto laneData = layoutAttr.getLaneData();
600  if (scatterAttr && laneData) {
601  // Validate subgroup mapping rules for scattered tensors.
602  // if chunkSize > 1, the last dimension of the tensor should
603  // be distributed in the units divisible by chunkAlignmentFactor.
604  int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
605  if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
606  return emitError()
607  << "expected last dim of lane_data to be a multiple of: "
608  << chunkAlignmentFactor;
609  }
610 
611  if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
612  std::string shapeStr;
613  llvm::raw_string_ostream stream(shapeStr);
614  llvm::interleaveComma(shape, stream);
615  return emitError() << "cannot distribute [" << shapeStr << "] using "
616  << layoutAttr;
617  }
618  }
619  return success();
620 }
621 
622 //===----------------------------------------------------------------------===//
623 // XeGPU_MemDescType
624 //===----------------------------------------------------------------------===//
625 mlir::Type MemDescType::parse(::mlir::AsmParser &parser) {
626  llvm::SmallVector<int64_t> shape;
627  mlir::Type elementType;
628  mlir::FailureOr<MemLayoutAttr> layout;
629 
630  // Parse literal '<'
631  if (parser.parseLess())
632  return {};
633 
634  auto shapeLoc = parser.getCurrentLocation();
635  if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
636  parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
637  return {};
638  }
639 
640  auto elemTypeLoc = parser.getCurrentLocation();
641  if (mlir::failed(parser.parseType(elementType))) {
642  parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
643  return {};
644  }
645 
646  // parse optional attributes
647  if (mlir::succeeded(parser.parseOptionalComma())) {
648  MemLayoutAttr attr;
649  ParseResult res = parser.parseAttribute(attr);
650  if (mlir::failed(res))
651  return {};
652  layout = attr;
653  }
654 
655  // Parse literal '>'
656  if (parser.parseGreater())
657  return {};
658 
659  MLIRContext *ctxt = parser.getContext();
660  return MemDescType::getChecked(
661  [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
662  elementType, layout.value_or(MemLayoutAttr()));
663 }
664 
665 void MemDescType::print(AsmPrinter &printer) const {
666  printer << "<";
667 
668  printer.printDimensionList(getShape());
669  printer << 'x';
670  printer << getElementType();
671 
672  if (auto layout = getMemLayout())
673  printer << ", " << layout;
674 
675  printer << ">";
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // XeGPU_MemLayoutAttr
680 //===----------------------------------------------------------------------===//
681 
682 Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
683 
684  auto context = parser.getContext();
685  llvm::SMLoc loc = parser.getCurrentLocation();
686 
687  llvm::SmallDenseSet<StringRef> seenKeys;
688  SmallVector<NamedAttribute> attributes;
689 
690  auto parseElt = [&]() -> ParseResult {
691  StringRef nameId;
692  if (failed(parser.parseKeyword(&nameId)))
693  return parser.emitError(loc, "expected valid attribute name");
694 
695  if (!seenKeys.insert(nameId).second)
696  return parser.emitError(loc, "duplicate key '")
697  << nameId << "' in mem layout attribute";
698 
699  if (failed(parser.parseEqual()))
700  return failure();
701 
702  Attribute attr;
703  if (failed(parser.parseAttribute(attr)))
704  return failure();
705  attributes.emplace_back(nameId, attr);
706  return success();
707  };
708 
709  // Parse literal '<'
710  if (parser.parseLess())
711  return {};
712 
713  if (failed(parser.parseCommaSeparatedList(parseElt)))
714  return {};
715 
716  // Parse literal '>'
717  if (parser.parseGreater())
718  return {};
719 
720  return parser.getChecked<MemLayoutAttr>(
721  loc, context, DictionaryAttr::get(context, attributes));
722 }
723 
724 void MemLayoutAttr::print(AsmPrinter &printer) const {
725  printer << "<";
726  ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
727  for (size_t i = 0; i < attrs.size(); i++) {
728  printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
729  if (i < attrs.size() - 1)
730  printer << ", ";
731  }
732  printer << ">";
733 }
734 // a helper utility to perform a binary operation on OpFoldResults.
735 // Each operand is materialized as an index value (creating a constant op
736 // if it is an attribute), and the corresponding arith op is generated
737 // on the two values.
738 template <typename ArithOp>
739 OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
740  OpBuilder &builder) {
741  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
742  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
743  return ArithOp::create(builder, loc, aVal, bVal).getResult();
744 }
745 
746 // a helper utility to perform division operation on OpFoldResult and int64_t.
747 #define div(a, b) \
748  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
749 
750 // a helper utility to perform a remainder operation on OpFoldResult and int64_t.
751 #define rem(a, b) \
752  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
753 
754 // a helper utility to perform multiply operation on OpFoldResult and int64_t.
755 #define mul(a, b) \
756  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
757 
758 // a helper utility to perform addition operation on two OpFoldResult.
759 #define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
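// e.g. div(ofs, 8) materializes an index constant 8 and emits
// arith.divsi(ofs, 8); add(a, b) emits arith.addi on the two values.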
760 
761 // block the given offsets according to the block shape
762 // say the original offset is [y, x], and the block shape is [By, Bx],
763 // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
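// e.g. offsets = [10, 7] with blockShape = [8, 4] becomes [1, 1, 2, 3].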
764 SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
765  ArrayRef<OpFoldResult> offsets,
766  ArrayRef<int64_t> blockShape) {
767 
768  assert(offsets.size() == blockShape.size() &&
769  "offsets and blockShape must have the same size");
770  SmallVector<OpFoldResult> blockedOffsets;
771  SmallVector<OpFoldResult> divs, rems;
772 
773  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
774  divs.push_back(div(offset, block));
775  rems.push_back(rem(offset, block));
776  }
777  blockedOffsets.append(divs.begin(), divs.end());
778  blockedOffsets.append(rems.begin(), rems.end());
779 
780  return blockedOffsets;
781 }
782 
783 // Get the strides of the blocked layout, as a vector of integers, for MemDesc.
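// For example (illustrative values), a 64x32 mem_desc with row-major strides
// [32, 1] and block shape [8, 16] yields blocked strides [256, 128, 16, 1]:
// [256, 128] steps between 8x16 blocks and [16, 1] steps within a block.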
784 SmallVector<int64_t> MemDescType::getStrideShape() {
785 
786  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
787 
788  ArrayAttr strideAttr = getStrideAttr();
789  SmallVector<int64_t> strides;
790  for (Attribute attr : strideAttr.getValue()) {
791  strides.push_back(cast<IntegerAttr>(attr).getInt());
792  }
793 
794  SmallVector<int64_t> innerBlkShape = getBlockShape();
795 
796  // get perm from the fastest-changing dim (FCD) to the slowest-changing dim
797  // perm[i] = the dim with the i-th smallest stride
798  SmallVector<int, 4> perm =
799  llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
800  llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
801 
802  assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
803 
804  SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
805  innerBlkStride[perm[0]] = 1;
806  for (size_t i = 1; i < perm.size(); ++i)
807  innerBlkStride[perm[i]] =
808  innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
809 
810  // compute the original matrix shape using the stride info
811  // and compute the number of blocks in each dimension
812  // The shape of highest dim can't be derived from stride info,
813  // but doesn't impact the stride computation for blocked layout.
814  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
815  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
816  for (size_t i = 0; i < perm.size() - 1; ++i) {
817  matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
818  BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
819  }
820 
821  int64_t innerBlkSize = 1;
822  for (auto s : innerBlkShape)
823  innerBlkSize *= s;
824 
825  SmallVector<int64_t> outerBlkStride(matrixShape.size());
826  outerBlkStride[perm[0]] = innerBlkSize;
827  for (size_t i = 0; i < perm.size() - 1; ++i) {
828  outerBlkStride[perm[i + 1]] =
829  outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
830  }
831 
832  // combine the inner and outer strides
833  SmallVector<int64_t> blockedStrides;
834  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
835  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
836 
837  return blockedStrides;
838 }
839 
840 // Calculate the linear offset using the blocked offsets and stride
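// i.e. linearOffset = sum_i(offsets[i] * strides[i]) over the (possibly
// blocked) offsets, materialized with arith.muli / arith.addi ops.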
841 Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
842  ArrayRef<OpFoldResult> offsets) {
843 
844  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
845  SmallVector<int64_t> blockShape = getBlockShape();
846  SmallVector<int64_t> strides = getStrideShape();
847  SmallVector<OpFoldResult> blockedOffsets;
848 
849  // blockShape equal to matrixShape means no blocking
850  if (llvm::equal(blockShape, matrixShape)) {
851  // remove the outer dims from strides
852  strides.erase(strides.begin(), strides.begin() + matrixShape.size());
853  } else {
854  assert(offsets.size() == blockShape.size() &&
855  "offsets and blockShape must have the same size");
856  // say the original offset is [y, x], and the block shape is [By, Bx],
857  // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
858 
859  SmallVector<OpFoldResult> divs, rems;
860 
861  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
862  divs.push_back(div(offset, block));
863  rems.push_back(rem(offset, block));
864  }
865  blockedOffsets.append(divs.begin(), divs.end());
866  blockedOffsets.append(rems.begin(), rems.end());
867  offsets = blockedOffsets;
868  }
869 
870  // Start with initial value as matrix descriptor's base offset.
871  Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
872  for (size_t i = 0; i < offsets.size(); ++i) {
873  OpFoldResult mulResult = mul(offsets[i], strides[i]);
874  Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
875  linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
876  }
877 
878  return linearOffset;
879 }
880 
881 } // namespace xegpu
882 } // namespace mlir
883 
884 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
885 #define GET_ATTRDEF_CLASSES
886 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
887 #define GET_TYPEDEF_CLASSES
888 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>