MLIR 23.0.0git
XeGPUDialect.cpp
Go to the documentation of this file.
1//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/IR/Builders.h"
18#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/Debug.h"
21
22using std::optional;
23
24namespace mlir {
25namespace xegpu {
26
// Registers all XeGPU types, operations, and attributes with the dialect.
// The member lists are generated by TableGen and pulled in through the
// *.cpp.inc files below; the GET_* macros select which list each include
// expands to.
void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}
41#define GET_OP_INTERFACE_CLASSES
42#include "mlir/Dialect/XeGPU/IR/XeGPUOpInterface.cpp.inc"
43
44// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
45// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
46// within each distribution unit.
47// Example:
48// WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a
49// distribution unit of shape 64x64, we have 2x4 such distribution units.
50// `delinearizedId` is used to identify a 16x32 of a subgroup in each
51// distribution unit.
54 SmallVector<Value> delinearizedId,
55 ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
56 ArrayRef<int64_t> srcShape) {
58
59 // A distribution unit must be less than or equal to `srcShape`
60 SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
61 llvm::zip_equal(srcShape,
62 computeElementwiseMul(subShapesLayout, subShape)),
63 [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
64
65 // Get the offset of `subShape` within a distribution unit.
66 SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
67 llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
68 return builder.createOrFold<arith::MulIOp>(
69 loc, std::get<0>(t),
70 builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
71 });
72
73 // For each dist unit
74 for (SmallVector<int64_t> unitOffs :
75 StaticTileOffsetRange(srcShape, distUnitShape)) {
76 // Get dist unit offset within `srcShape`.
78 llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
79 return arith::ConstantIndexOp::create(builder, loc, d);
80 });
81 // Calculate `subShape` offset within `srcShape`.
83 llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
84 [&](const auto &t) -> Value {
85 return builder.createOrFold<arith::AddIOp>(
86 loc, std::get<0>(t), std::get<1>(t));
87 });
88 // Do not go beyond `srcShape` bounds.
89 SmallVector<Value> mods = llvm::map_to_vector(
90 llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
91 return builder.createOrFold<arith::RemUIOp>(
92 loc, std::get<0>(t),
93 arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
94 });
95
96 coordinates.push_back(mods);
97 }
98 return coordinates;
99}
100
104 // Compute distribution unit shape (clamped to srcShape).
105 SmallVector<int64_t> distUnitShape(shape.size());
106 for (size_t i = 0; i < shape.size(); ++i)
107 distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
108
109 // Compute local offset of this ID within a distribution unit.
110 SmallVector<int64_t> localOffset(shape.size());
111 for (size_t i = 0; i < shape.size(); ++i)
112 localOffset[i] = canonicalIds[i] * subShape[i];
113
114 // Enumerate all distribution units and compute coordinates.
116 for (SmallVector<int64_t> unitOffs :
117 StaticTileOffsetRange(shape, distUnitShape)) {
118 SmallVector<int64_t> coord(shape.size());
119 for (size_t i = 0; i < shape.size(); ++i)
120 coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
121 coordinates.push_back(coord);
122 }
123 return coordinates;
124}
125
126// Checks if the given memref type represents shared local memory (SLM).
127bool XeGPUDialect::isSharedMemory(const MemRefType &memrefTy) {
128 Attribute attr = memrefTy.getMemorySpace();
129 if (!attr)
130 return false; // Default memory space is not shared local memory
131 if (auto intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr))
132 return intAttr.getInt() == 3;
133 if (auto memrefSpace = llvm::dyn_cast_if_present<MemorySpaceAttr>(attr))
134 return memrefSpace.getValue() == MemorySpace::SLM;
135 if (auto xevmSpace = llvm::dyn_cast_if_present<xevm::AddrSpaceAttr>(attr))
136 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
137 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
138}
139
140//===----------------------------------------------------------------------===//
141// XeGPU_BlockTensorDescAttr
142//===----------------------------------------------------------------------===//
143BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
144 xegpu::MemorySpace memory_space,
145 int array_length,
146 bool boundary_check) {
147 auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
148 auto lengthAttr =
149 IntegerAttr::get(IntegerType::get(context, 64), array_length);
150 auto boundaryAttr = BoolAttr::get(context, boundary_check);
151 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
152}
153
154bool BlockTensorDescAttr::hasDefaultsOnly() {
155 return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
156 getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
157}
158
159//===----------------------------------------------------------------------===//
160// XeGPU_LayoutAttr
161//===----------------------------------------------------------------------===//
// Verifies structural consistency of a LayoutAttr: all present fields must
// agree on rank, each *_data field must accompany its *_layout field, and
// `order` must match the rank of whichever layout it is used with.
LogicalResult
LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {

  // Special case for store_matrix: a layout with none of sg_layout,
  // inst_data, or lane_layout is trivially valid.
  if (!sg_layout && !inst_data && !lane_layout)
    return success();

  // sg_layout, inst_data and lane_layout must have the same rank whenever
  // more than one of them is present.

  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError() << "expected inst_data and lane_layout to have the same "
                          "rank, got inst_data "
                       << inst_data.size() << ", lane_layout "
                       << lane_layout.size();
  }

  // sg_data is meaningless without sg_layout (and vice versa), and must
  // match its rank when both are present.
  if ((sg_layout && !sg_data) || (!sg_layout && sg_data))
    return emitError() << "sg_layout and sg_data must be used together";
  if (sg_layout && sg_data && sg_layout.size() != sg_data.size())
    return emitError()
           << "expected sg_data and sg_layout to have the same rank";

  // Same pairing rule for lane_layout / lane_data.
  if ((lane_layout && !lane_data) || (!lane_layout && lane_data))
    return emitError() << "lane_layout and lane_data must be used together";
  if (lane_layout && lane_data && lane_layout.size() != lane_data.size())
    return emitError()
           << "expected lane_data and lane_layout to have the same rank";

  // `order` is only meaningful alongside a layout, and must match its rank.
  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";

    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";

    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}
220
// Decomposes a linear subgroup (or lane) id into per-dimension ids according
// to this layout's effective sg_layout / lane_layout and its `order`
// attribute (order[0] is the fastest-changing dimension). Emits RemUI/DivUI
// ops via `builder`. Returns failure when the attribute is neither a
// workgroup nor a subgroup layout, or when the order rank mismatches the
// layout rank.
FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {

  SmallVector<int64_t> sgLayoutInt;
  if (isForWorkgroup()) {
    sgLayoutInt = getEffectiveSgLayoutAsInt();
  } else if (isForSubgroup()) {
    sgLayoutInt = getEffectiveLaneLayoutAsInt();
  } else {
    return failure();
  }

  DenseI32ArrayAttr orderAttr = getOrder();

  // Handle order attribute
  SmallVector<int64_t> order;
  if (orderAttr && !orderAttr.empty()) {
    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
      return static_cast<int64_t>(idx);
    });
  } else {
    // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
    order = llvm::to_vector(
        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
  }

  if (order.size() != sgLayoutInt.size()) {
    return failure();
  }

  SmallVector<Value> result(sgLayoutInt.size());
  Value remaining = linearId;

  /// Process dimensions in the order they appear in the order array
  /// The first dimension in order is the fastest-changing
  ///
  /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
  ///
  /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
  /// result=[?,?,?]
  ///
  /// i=0 (process columns, dimIdx=2, dimSize=4):
  ///   result[2] = 22 % 4 = 2 (column coordinate)
  ///   remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
  ///
  /// i=1 (process rows, dimIdx=1, dimSize=4):
  ///   result[1] = 5 % 4 = 1 (row coordinate)
  ///   remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
  ///
  /// i=2 (process layers, dimIdx=0, dimSize=2):
  ///   result[0] = 1 % 2 = 1 (layer coordinate)
  ///   (no remaining update - last iteration)
  ///
  /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    int64_t dimSize = sgLayoutInt[dimIdx];

    Value dimSizeVal =
        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);

    /// Extract the coordinate for this dimension using modulo operation
    /// This gives us "how far within this dimension" we are
    /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
    /// this dimension)
    result[dimIdx] =
        builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);

    /// Update remaining for the next dimension by removing what we've already
    /// processed. Division tells us "how many complete groups of this dimension
    /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
    /// completed 5 groups of 4) Skip this for the last iteration since there's
    /// no next dimension to process
    if (i < order.size() - 1) {
      remaining =
          builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
    }
  }
  return result;
}
301
302/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
303/// instructions for computing multi-dimensional offsets when distributed by
304/// LayoutAttr.
305FailureOr<SmallVector<SmallVector<Value>>>
306LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
307 Value linearId, ArrayRef<int64_t> shape) {
308 SmallVector<int64_t> layout;
309 SmallVector<int64_t> subShape;
310 if (isForWorkgroup()) {
311 layout = getEffectiveSgLayoutAsInt();
312 subShape = getEffectiveSgDataAsInt();
313 } else if (isForSubgroup()) {
314 layout = getEffectiveLaneLayoutAsInt();
315 subShape = getEffectiveLaneDataAsInt();
316 } else {
317 return failure();
318 }
319 assert(!subShape.empty() && "sgdata or lanedata cannot be empty for "
320 "distributed coordinates computation");
321
322 // delinearize Ids
323 auto maybeIds = delinearizeId(builder, loc, linearId);
324 if (failed(maybeIds))
325 return failure();
326 SmallVector<Value> ids = *maybeIds;
327
328 return genCoordinates(builder, loc, ids, layout, subShape, shape);
329}
330
331bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
332 if (dyn_cast<xegpu::SliceAttr>(other))
333 return false;
334
335 return *this == dyn_cast<xegpu::LayoutAttr>(other);
336}
337
/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
/// compute multi-dimensional offsets for a given linear ID when distributed by
/// LayoutAttr.
SmallVector<SmallVector<int64_t>>
LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
                                           ArrayRef<int64_t> shape) {
  SmallVector<int64_t> layoutVec;
  SmallVector<int64_t> subShape;
  SmallVector<int64_t> instData;
  // Pick the distribution granularity: sg_layout/sg_data at workgroup level,
  // lane_layout/lane_data at subgroup level.
  if (isForWorkgroup()) {
    layoutVec = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    instData = getEffectiveInstDataAsInt();
    layoutVec = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  }
  // When inst_data is present, it becomes the distribution unit and the
  // coordinates are computed as for id 0.
  if (!instData.empty()) {
    linearId = 0;
    subShape = instData;
  }
  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");

  // Delinearize the linear ID using the order attribute.
  // order[0] is the fastest-changing dimension.
  SmallVector<int64_t> order = getEffectiveOrderAsInt();
  SmallVector<int64_t> delinearizedId(layoutVec.size());
  int64_t remaining = linearId;
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    delinearizedId[dimIdx] = remaining % layoutVec[dimIdx];
    remaining = remaining / layoutVec[dimIdx];
  }

  return genStaticCoordinates(delinearizedId, layoutVec, subShape, shape);
}
373
374// set the layout for unit dims: sg_data, inst_data and lane_data to 1
375DistributeLayoutAttr
376LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
377 auto sgDataOpt = getSgData();
378 auto instDataOpt = getInstData();
379 auto laneDataOpt = getLaneData();
380
381 SmallVector<int32_t> sgData;
382 SmallVector<int32_t> instData;
383 SmallVector<int32_t> laneData;
384
385 if (sgDataOpt)
386 sgData = llvm::to_vector(sgDataOpt.asArrayRef());
387
388 if (instDataOpt)
389 instData = llvm::to_vector(instDataOpt.asArrayRef());
390
391 if (laneDataOpt)
392 laneData = llvm::to_vector(laneDataOpt.asArrayRef());
393
394 for (auto dim : unitDims) {
395 if (dim < static_cast<int64_t>(sgData.size()))
396 sgData[dim] = 1;
397 if (dim < static_cast<int64_t>(instData.size()))
398 instData[dim] = 1;
399 if (dim < static_cast<int64_t>(laneData.size()))
400 laneData[dim] = 1;
401 }
402
403 return LayoutAttr::get(
404 getContext(), getSgLayout(),
405 sgData.empty() ? DenseI32ArrayAttr()
407 instData.empty() ? DenseI32ArrayAttr()
408 : DenseI32ArrayAttr::get(getContext(), instData),
409 getLaneLayout(),
410 laneData.empty() ? DenseI32ArrayAttr()
411 : DenseI32ArrayAttr::get(getContext(), laneData),
412 getOrder());
413}
414
415// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
416DistributeLayoutAttr
417LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
418 auto sgLayoutOpt = getSgLayout();
419 auto laneLayoutOpt = getLaneLayout();
420
421 SmallVector<int32_t> sgLayout;
422 SmallVector<int32_t> laneLayout;
423
424 if (sgLayoutOpt)
425 sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
426 if (laneLayoutOpt)
427 laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
428
429 for (auto dim : unitDims) {
430 if (dim < static_cast<int64_t>(sgLayout.size()))
431 sgLayout[dim] = 1;
432 if (dim < static_cast<int64_t>(laneLayout.size()))
433 laneLayout[dim] = 1;
434 }
435
436 return LayoutAttr::get(
437 getContext(),
438 sgLayout.empty() ? DenseI32ArrayAttr()
439 : DenseI32ArrayAttr::get(getContext(), sgLayout),
440 getSgData(), getInstData(),
441 laneLayout.empty() ? DenseI32ArrayAttr()
442 : DenseI32ArrayAttr::get(getContext(), laneLayout),
443 getLaneData(), getOrder());
444}
445
446// Derive a new layout with sg_data, inst_data and lane_data set to the
447// specified values for the given dimension
448DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
449 int64_t instData,
450 int64_t laneData) {
451
452 SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
453 SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
454 SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
455
456 if (dim < static_cast<int64_t>(sgDataVec.size()) && sgData != -1)
457 sgDataVec[dim] = sgData;
458 if (dim < static_cast<int64_t>(instDataVec.size()) && instData != -1)
459 instDataVec[dim] = instData;
460 if (dim < static_cast<int64_t>(laneDataVec.size()) && laneData != -1)
461 laneDataVec[dim] = laneData;
462
463 SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
464 SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
465 SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());
466
467 return LayoutAttr::get(
468 getContext(), getSgLayout(),
469 sgDataVec.empty() ? DenseI32ArrayAttr()
470 : DenseI32ArrayAttr::get(getContext(), sgDataVec32),
471 instDataVec.empty() ? DenseI32ArrayAttr()
472 : DenseI32ArrayAttr::get(getContext(), instDataVec32),
473 getLaneLayout(),
474 laneDataVec.empty() ? DenseI32ArrayAttr()
475 : DenseI32ArrayAttr::get(getContext(), laneDataVec32),
476 getOrder());
477}
478
// Derive a new layout by removing dimensions.
// `dimGroup` specifies a group of dimensions to be removed in the derived
// layout.
DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {

  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
  SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
  SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
  SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
  SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
  SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();

  SmallVector<int64_t> sortedDimGroup = dimGroup;
  llvm::sort(sortedDimGroup);

  // Erase from highest index to lowest so earlier erasures do not shift the
  // positions of dims still to be removed.
  for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
    if (!sgLayout.empty()) {
      sgLayout.erase(sgLayout.begin() + dimIdx);
      sgData.erase(sgData.begin() + dimIdx);
    }
    if (!instData.empty())
      instData.erase(instData.begin() + dimIdx);
    if (!laneLayout.empty()) {
      laneLayout.erase(laneLayout.begin() + dimIdx);
      laneData.erase(laneData.begin() + dimIdx);
    }
  }

  // Rebuild the order: drop entries for removed dims and renumber survivors
  // (subtracting the number of removed dims below each) so the result stays a
  // dense permutation of [0, new rank).
  SmallVector<int64_t> newOrder;
  for (int64_t d : origOrder) {
    if (llvm::is_contained(dimGroup, d))
      continue;
    int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
    newOrder.push_back(d - offset);
  }
  // An order is dropped entirely when no layout remains or only one dim is
  // left (a 1-D order carries no information).
  if ((sgLayout.empty() && laneLayout.empty()) || newOrder.size() == 1)
    newOrder.clear();

  // Narrow to i32 and wrap; empty vectors become null (absent) attributes.
  auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
    if (v.empty())
      return DenseI32ArrayAttr();
    SmallVector<int32_t> v32(v.begin(), v.end());
    return DenseI32ArrayAttr::get(getContext(), v32);
  };
  auto droppedLayout = xegpu::LayoutAttr::get(
      getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
      toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
  return droppedLayout;
}
528
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout.
DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {

  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
  SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
  SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
  SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
  SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
  SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();

  // Validate that the dims being collapsed are adjacent; aborts via
  // report_fatal_error otherwise.
  SmallVector<int64_t> sortedDimGroup = dimGroup;
  llvm::sort(sortedDimGroup);
  int64_t dimBeforeCurrent = -1;
  for (auto dimIdx : sortedDimGroup) {
    // when order attr is present, adjacency dims are values like [3, 2, 1, 0]
    // in decreasing order; otherwise based on dim indices like [0, 1, 2, 3]
    // in increasing order
    if (dimBeforeCurrent >= 0) {
      if (getOrder() && !getOrder().empty()) {
        int64_t orderBefore = origOrder[dimBeforeCurrent];
        int64_t orderCurrent = origOrder[dimIdx];
        if (orderBefore != (orderCurrent - 1))
          llvm::report_fatal_error(
              "dimensions being collapsed must be adjacent in order");
      } else {
        if (dimIdx != (dimBeforeCurrent + 1))
          llvm::report_fatal_error(
              "dimensions being collapsed must be adjacent");
      }
    }
    dimBeforeCurrent = dimIdx;
  }

  // The collapsed dimension takes the slot of the smallest dim in the group.
  int firstDim = sortedDimGroup.front();

  // collapse the dimensions in dimGroup into one dimension by multiplying their
  // sizes together

  if (!sgLayout.empty()) {
    int64_t collapsedSglayout = 1, collapsedSgData = 1;
    for (auto dimIdx : dimGroup) {
      collapsedSglayout *= sgLayout[dimIdx];
      collapsedSgData *= sgData[dimIdx];
    }
    // Erase highest index first so earlier erasures don't shift later ones,
    // then insert the product at the collapsed slot.
    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
      sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
      sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
    }
    sgLayout.insert(sgLayout.begin() + firstDim, collapsedSglayout);
    sgData.insert(sgData.begin() + firstDim, collapsedSgData);
  }

  if (!instData.empty()) {
    int64_t collapsedInstData = 1;
    for (auto dimIdx : dimGroup)
      collapsedInstData *= instData[dimIdx];
    for (auto dimIdx : llvm::reverse(sortedDimGroup))
      instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
    instData.insert(instData.begin() + firstDim, collapsedInstData);
  }

  if (!laneLayout.empty()) {
    int64_t collapsedLaneLayout = 1, collapsedLaneData = 1;
    for (auto dimIdx : dimGroup) {
      collapsedLaneLayout *= laneLayout[dimIdx];
      collapsedLaneData *= laneData[dimIdx];
    }
    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
      laneLayout.erase(laneLayout.begin() + dimIdx,
                       laneLayout.begin() + dimIdx + 1);
      laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
    }
    laneLayout.insert(laneLayout.begin() + firstDim, collapsedLaneLayout);
    laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
  }

  // Rebuild the order for the reduced rank: drop the collapsed dims' entries
  // (keeping firstDim's), then renumber by relative rank so the result is a
  // dense permutation again.
  SmallVector<int64_t> newOrder;
  DenseI32ArrayAttr orderAttr = getOrder();
  if (orderAttr && !orderAttr.empty()) {

    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
      if (dimIdx != firstDim)
        origOrder.erase(origOrder.begin() + dimIdx);
    }
    // say we have orderVec = {5, 3, 2, 1, 0}
    // Create indices [0, 1, 2, 3, 4]
    SmallVector<size_t> indices =
        llvm::to_vector(llvm::seq<size_t>(0, origOrder.size()));

    // Sort indices based on corresponding values
    llvm::sort(indices,
               [&](size_t a, size_t b) { return origOrder[a] < origOrder[b]; });

    newOrder = llvm::to_vector(llvm::map_range(
        indices, [&](size_t i) { return static_cast<int64_t>(i); }));
  }

  // Narrow to i32 and wrap; empty vectors become null (absent) attributes.
  auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
    if (v.empty())
      return DenseI32ArrayAttr();
    SmallVector<int32_t> v32(v.begin(), v.end());
    return DenseI32ArrayAttr::get(getContext(), v32);
  };
  auto collapsedLayout = xegpu::LayoutAttr::get(
      getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
      toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
  return collapsedLayout;
}
639
640// Derive a new layout by transpose the layout using `permutation`.
641DistributeLayoutAttr LayoutAttr::transposeDims(ArrayRef<int64_t> permutation) {
642
643 SmallVector<int64_t> origSgLayout = getEffectiveSgLayoutAsInt();
644 SmallVector<int64_t> origSgData = getEffectiveSgDataAsInt();
645 SmallVector<int64_t> origInstData = getEffectiveInstDataAsInt();
646 SmallVector<int64_t> origLaneLayout = getEffectiveLaneLayoutAsInt();
647 SmallVector<int64_t> origLaneData = getEffectiveLaneDataAsInt();
648 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
649
650 SmallVector<int32_t> sgLayout;
651 SmallVector<int32_t> sgData;
652 SmallVector<int32_t> instData;
653 SmallVector<int32_t> laneLayout;
654 SmallVector<int32_t> laneData;
655 SmallVector<int32_t> order;
656
657 for (int64_t idx : permutation) {
658 if (!origLaneLayout.empty()) {
659 laneLayout.push_back(static_cast<int32_t>(origLaneLayout[idx]));
660 laneData.push_back(static_cast<int32_t>(origLaneData[idx]));
661 }
662 if (!origInstData.empty())
663 instData.push_back(static_cast<int32_t>(origInstData[idx]));
664 if (!origSgLayout.empty()) {
665 sgLayout.push_back(static_cast<int32_t>(origSgLayout[idx]));
666 sgData.push_back(static_cast<int32_t>(origSgData[idx]));
667 }
668 order.push_back(static_cast<int32_t>(origOrder[idx]));
669 }
670 if (origLaneLayout.empty() && origSgLayout.empty())
671 order.clear();
672
673 auto toAttr = [&](ArrayRef<int32_t> v) -> DenseI32ArrayAttr {
674 return v.empty() ? nullptr : DenseI32ArrayAttr::get(getContext(), v);
675 };
676 return xegpu::LayoutAttr::get(getContext(), toAttr(sgLayout), toAttr(sgData),
677 toAttr(instData), toAttr(laneLayout),
678 toAttr(laneData), toAttr(order));
679}
680
/// Check if this layout is a transpose of another layout.
/// For every field selected by `kind`, requires this[perm[i]] ==
/// other[i] for all i (see `checkTranspose` below).
bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
                               ArrayRef<int64_t> perm,
                               const xegpu::LayoutKind kind) {
  if (!other)
    return false;
  // Ranks of both layouts and of the permutation itself must agree.
  if (getRank() != other.getRank() ||
      perm.size() != static_cast<size_t>(getRank()))
    return false;
  if (!isPermutationVector(perm))
    return false;
  // Returns true iff `src` equals `dst` re-indexed through `perm`,
  // i.e. src[i] == dst[perm[i]] for all i.
  auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src,
                           ArrayRef<int64_t> perm) {
    for (const auto &ta : llvm::enumerate(perm)) {
      if (src[ta.index()] != dst[ta.value()])
        return false;
    }
    return true;
  };
  // Compare only the fields relevant to the requested layout kind.
  if (kind == xegpu::LayoutKind::Subgroup)
    return checkTranspose(getEffectiveSgLayoutAsInt(),
                          other.getEffectiveSgLayoutAsInt(), perm) &&
           checkTranspose(getEffectiveSgDataAsInt(),
                          other.getEffectiveSgDataAsInt(), perm) &&
           checkTranspose(getEffectiveOrderAsInt(),
                          other.getEffectiveOrderAsInt(), perm);
  if (kind == xegpu::LayoutKind::InstData)
    return checkTranspose(getEffectiveInstDataAsInt(),
                          other.getEffectiveInstDataAsInt(), perm);
  if (kind == xegpu::LayoutKind::Lane)
    return checkTranspose(getEffectiveLaneLayoutAsInt(),
                          other.getEffectiveLaneLayoutAsInt(), perm) &&
           checkTranspose(getEffectiveLaneDataAsInt(),
                          other.getEffectiveLaneDataAsInt(), perm) &&
           checkTranspose(getEffectiveOrderAsInt(),
                          other.getEffectiveOrderAsInt(), perm);

  // Any other layout kind is not supported.
  return false;
}
720
// Returns true when this layout and `other` distribute `shape` identically
// at the given `level`: either the relevant fields match outright, or the
// statically computed coordinates agree for every id.
bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                                  SmallVector<int64_t> shape,
                                  xegpu::LayoutKind level) {
  if (!other)
    return false;
  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
    // short cut when order is the same, no need to compute coords and compare
    if (level == xegpu::LayoutKind::Subgroup)
      if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
          getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
        return true;
    if (level == xegpu::LayoutKind::Lane)
      if (getEffectiveLaneLayoutAsInt() ==
              other.getEffectiveLaneLayoutAsInt() &&
          getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
        return true;
  }

  // Fallback: both layouts are compatible iff every id is assigned the same
  // coordinates by both.
  auto compareCoordsForAllIds = [&](int64_t size) {
    for (int64_t id : llvm::seq<int64_t>(0, size)) {
      auto coords = computeStaticDistributedCoords(id, shape);
      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
      if (coords != otherCoords)
        return false;
    }
    return true;
  };

  if (level == xegpu::LayoutKind::Subgroup) {
    int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
    return compareCoordsForAllIds(wgSize);
  }
  // inst_data does not involve ids; plain equality decides compatibility.
  if (level == xegpu::LayoutKind::InstData) {
    return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
  }
  if (level == xegpu::LayoutKind::Lane) {
    int64_t subgroupSize = computeProduct(getEffectiveLaneLayoutAsInt());
    return compareCoordsForAllIds(subgroupSize);
  }
  return true;
}
762
763//===----------------------------------------------------------------------===//
764// XeGPU_SliceAttr
765//===----------------------------------------------------------------------===//
// Verifies a SliceAttr: `dims` must be present, and each entry must be
// non-negative and unique. Note that `parent` is accepted for interface
// compatibility but is not validated here; in particular, dims are NOT
// range-checked against the parent's rank.
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {

  if (!dims)
    return emitError() << "expected dims attribute";

  // Check that every element in dims is non-negative and unique.
  llvm::SmallDenseSet<int64_t> seen;
  for (int64_t dim : dims.asArrayRef()) {
    if (dim < 0)
      return emitError() << "invalid dim (" << dim << ") in slice attribute.";
    if (!seen.insert(dim).second)
      return emitError() << "repeated dim (" << dim << ") in slice attribute.";
  }
  return success();
}
783
// Folds a chain of nested SliceAttrs into a single SliceAttr over the root
// LayoutAttr, with all sliced dims expressed in the root layout's dimension
// space.
SliceAttr SliceAttr::flatten() const {
  xegpu::DistributeLayoutAttr parent = getParent();
  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});

  // Walk up the chain of SliceAttrs, collecting each level's sliced dims.
  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
    parent = sliceAttr.getParent();
    slicedDims.push_back(sliceAttr.getDims());
  }

  // The root of a slice chain is a LayoutAttr; indices enumerates its dims.
  auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
  SmallVector<int64_t> indices =
      llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));

  // get remaining dims (flattened) by applying slice ops with all slicedDims,
  // innermost slice applied last (hence the reverse walk).
  SmallVector<int64_t> remainingDims(indices);
  for (auto dim : llvm::reverse(slicedDims))
    remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
                                        dim.asArrayRef());

  // get flattened sliced dims by applying slice ops with the remaining dims
  SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
      llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));

  return xegpu::SliceAttr::get(
      getContext(), layoutAttr,
      DenseI64ArrayAttr::get(getContext(), flattenedDims));
}
811
812FailureOr<SmallVector<Value>>
813SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
814 SliceAttr attr = flatten();
815 auto parent = dyn_cast<LayoutAttr>(attr.getParent());
816 return parent.delinearizeId(builder, loc, linearId);
817}
818
// Implements DistributeLayoutAttr::computeDistributedCoords to generate
// instructions for computing multi-dimensional offsets when distributed by
// SliceAttr. Delinearizes the id via the parent layout, then keeps only the
// ids of the non-sliced dims for coordinate generation.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                    Value linearId, ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");

  // Pick the distribution granularity for this level.
  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  } else {
    return failure();
  }

  if (subShape.empty())
    return failure();

  // delinearize Ids
  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();

  // The effective sgIds for offsets computing correspond
  // to the dims that are not sliced.
  ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
  SmallVector<Value> canonicalIds =
      XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);

  return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
}
855
/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
/// compute multi-dimensional offsets for a given linear ID when distributed by
/// SliceAttr. Delegates delinearization to the parent LayoutAttr, then uses
/// only the non-sliced dimensions for coordinate computation.
SmallVector<SmallVector<int64_t>>
SliceAttr::computeStaticDistributedCoords(int64_t linearId,
                                          ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");

  // Select the effective layout/data sizes depending on whether this slice
  // describes a workgroup-level or a subgroup-level distribution.
  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  SmallVector<int64_t> instData;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    instData = getEffectiveInstDataAsInt();
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  }
  // When inst_data is present it overrides the unit shape, and coordinates
  // are computed as if for id 0 (the result then does not depend on the
  // incoming linearId — NOTE(review): looks intentional; confirm).
  if (!instData.empty()) {
    linearId = 0;
    subShape = instData;
  }

  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");

  // Delinearize the ID using the parent layout (same as the IR version).
  // NOTE(review): dyn_cast result is used unchecked; flatten() presumably
  // guarantees the parent is a LayoutAttr.
  SliceAttr flattened = flatten();
  auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
  SmallVector<int64_t> parentLayoutVec;
  if (parent.isForWorkgroup())
    parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
  else
    parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();

  // `order[0]` is the dim extracted first by the modulo below, i.e. the dim
  // in which linearId varies fastest; the division then peels that dim off.
  SmallVector<int64_t> order = parent.getEffectiveOrderAsInt();
  SmallVector<int64_t> allIds(parentLayoutVec.size());
  int64_t remaining = linearId;
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
    if (i < order.size() - 1)
      remaining = remaining / parentLayoutVec[dimIdx];
  }

  // The effective IDs for coordinate computation correspond
  // to the dims that are not sliced.
  ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
  SmallVector<int64_t> canonicalIds =
      XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);

  return genStaticCoordinates(canonicalIds, layout, subShape, shape);
}
910
911bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
912 auto flattenedThis = flatten();
913 // If other is a LayoutAttr, just compare directly with parent of
914 // flattenedThis.
915 if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
916 return flattenedThis.getParent() == otherLayout;
917 // If other is a SliceAttr, flatten it first before comparing.
918 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
919 // Both must have common parent LayoutAttr.
920 if (flattenedThis.getParent() != flattenedOther.getParent())
921 return false;
922 // otherFlattened's sliced dims must be a subset of flattenedThis's sliced
923 // dims.
924 llvm::SmallDenseSet<int64_t> thisDims(
925 flattenedThis.getDims().asArrayRef().begin(),
926 flattenedThis.getDims().asArrayRef().end());
927 return llvm::all_of(flattenedOther.getDims().asArrayRef(),
928 [&](int64_t dim) { return thisDims.contains(dim); });
929}
930
931bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
932 if (dyn_cast<xegpu::LayoutAttr>(other))
933 return false;
934
935 auto flattenedThis = flatten();
936 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
937
938 return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
939 (flattenedThis.getDims() == flattenedOther.getDims()));
940}
941
/// Implements DistributeLayoutAttr::isCompatibleWith. Two layouts are
/// considered compatible at the given `level` when they assign the same
/// static coordinates in `shape` to every linear id (or, for InstData, when
/// their inst_data values match). Unknown levels conservatively return true.
bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                                 SmallVector<int64_t> shape,
                                 xegpu::LayoutKind level) {
  if (!other)
    return false;
  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
    // short cut when order is the same, no need to compute coords and compare
    if (level == xegpu::LayoutKind::Subgroup)
      if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
          getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
        return true;
    if (level == xegpu::LayoutKind::Lane)
      if (getEffectiveLaneLayoutAsInt() ==
              other.getEffectiveLaneLayoutAsInt() &&
          getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
        return true;
  }

  // Brute-force check: both layouts must produce identical coordinates for
  // every id in [0, size).
  auto compareCoordsForAllIds = [&](int64_t size) {
    for (int64_t id : llvm::seq<int64_t>(0, size)) {
      auto coords = computeStaticDistributedCoords(id, shape);
      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
      if (coords != otherCoords)
        return false;
    }
    return true;
  };

  // The number of ids to enumerate comes from the flattened parent layout.
  // NOTE(review): dyn_cast is used unchecked here; flatten() presumably
  // guarantees a LayoutAttr parent.
  auto flattenedThis = flatten();
  auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
  if (level == xegpu::LayoutKind::Subgroup) {
    int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
    return compareCoordsForAllIds(wgSize);
  }
  if (level == xegpu::LayoutKind::InstData) {
    return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
  }
  if (level == xegpu::LayoutKind::Lane) {
    int64_t subgroupSize = computeProduct(parent.getEffectiveLaneLayoutAsInt());
    return compareCoordsForAllIds(subgroupSize);
  }
  return true;
}
985
986xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
987 if (sliceDimsToDrop.empty())
988 return *this;
989 SmallVector<int64_t> sliceDims{getDims().asArrayRef()};
990 for (auto dim : sliceDimsToDrop) {
991 auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), dim);
992 assert(foundIt != sliceDims.end() &&
993 "Expected to find the specified reduction dim in slice dims");
994 sliceDims.erase(foundIt);
995 }
996
997 auto sliceWithoutDims = xegpu::SliceAttr::get(
998 this->getContext(), getParent(),
999 DenseI64ArrayAttr::get(this->getContext(), sliceDims));
1000
1001 return sliceWithoutDims;
1002}
1003
1004// Helper function to adjust dimensions from sliced space to parent space
1005// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
1006// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
1007// to dim [0] in parent space; if we want to set unit dim [1] in sliced space,
1008// it maps to dim [2] in parent space.
1009static SmallVector<int64_t>
1011 ArrayRef<int64_t> sliceDims) {
1012 // Rather than recovering the exact parent rank, we compute a safe upper
1013 // bound so that dimsToMap can be adjusted safely. This upper bound is
1014 // defined as max(dimsToMap, sliceDims) + 1 + sliceDims.size().
1015 int64_t maxDim = -1;
1016 maxDim =
1017 std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
1018 maxDim =
1019 std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
1020 int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
1021
1022 // get remaining dims in parent space after applying slicing with parent's
1023 // slice Dims
1024 llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
1025 sliceDims.end());
1026 SmallVector<int64_t> remainingDims;
1027 for (int64_t i = 0; i < parentSpaceRank; ++i) {
1028 if (!slicedDimsSet.contains(i))
1029 remainingDims.push_back(i);
1030 }
1031
1032 // Map unit dims from sliced space to parent space
1033 SmallVector<int64_t> adjustUnitDims;
1034 for (auto dim : dimsToMap) {
1035 int64_t mappedDim = remainingDims[dim];
1036 adjustUnitDims.push_back(mappedDim);
1037 }
1038
1039 return adjustUnitDims;
1040}
1041
1042// set the layout for unit dims: sg_data, inst_data and lane_data to 1
1043DistributeLayoutAttr
1044SliceAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
1045 DistributeLayoutAttr parentLayout = getParent();
1046
1047 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1048
1049 SmallVector<int64_t> adjustUnitDims =
1050 mapSlicedDimsToParentSpace(unitDims, sliceDims);
1051
1052 return SliceAttr::get(getContext(),
1053 parentLayout.setUnitDimData(adjustUnitDims), getDims());
1054}
1055
1056// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
1057DistributeLayoutAttr
1058SliceAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
1059 DistributeLayoutAttr parentLayout = getParent();
1060
1061 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1062
1063 SmallVector<int64_t> adjustUnitDims =
1064 mapSlicedDimsToParentSpace(unitDims, sliceDims);
1065
1066 return SliceAttr::get(
1067 getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
1068}
1069
1070// Derive a new layout with sg_data, inst_data and lane_data set to the
1071// specified values for the given dimension
1072DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
1073 int64_t instData, int64_t laneData) {
1074 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1075 auto parent = getParent();
1076
1077 SmallVector<int64_t> dimSet;
1078 dimSet.push_back(dim);
1079 SmallVector<int64_t> adjustDims =
1080 mapSlicedDimsToParentSpace(dimSet, sliceDims);
1081 return SliceAttr::get(
1082 getContext(),
1083 parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
1084}
1085
// Derive a new layout by removing dimensions. `dimGroup` specifies a group of
// dimensions (expressed in sliced space) to be removed in the derived layout.
//
// Example: drop the 2nd dimension from a rank-3 sliced view.
//
// Suppose:
//   xegpu.layout = slice<layout<[V0, V1, V2, V3, V4]>, [1, 3]>
//
// The slice removes parent dims [1, 3], so the sliced-space dims map to
// parent dims [V0, V2, V4].
//
// If we drop sliced-space dim 1 (the 2nd dim), that corresponds to dropping
// parent dim 2, result in parent layout [V0, V1, V3, V4] after dropping.
// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to [1,
// 2].
//
// Result:
//   xegpu.layout = slice<layout<[0, 1, 3, 4]>, [1, 2]>
DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
  // Translate the dims to drop from sliced space into parent space.
  SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
  SmallVector<int64_t> dimsInParentSpace =
      mapSlicedDimsToParentSpace(dimGroup, sliceDims);

  // Drop the translated dims from the parent layout.
  auto droppedParent = getParent().dropDims(dimsInParentSpace);

  // Adjust the sliced dims after dropping dims in parent space. For example, if
  // we drop dim 2 in parent space, the dims after dim 2 will all be shifted by
  // 1, so sliced dim 3 will be adjusted to 2.
  SmallVector<int64_t> newSliceDims;
  for (int64_t d : sliceDims) {
    // Shift each slice dim down by the number of dropped dims below it.
    int64_t offset =
        llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
    newSliceDims.push_back(d - offset);
  }

  return SliceAttr::get(getContext(), droppedParent,
                        DenseI64ArrayAttr::get(getContext(), newSliceDims));
}
1125
1126// Derive a new layout by collapsing dimensions.
1127// `dimGroup` specifies a group of adjacent dimensions
1128// that are collapsed into a single dimension in the derived layout.
1129DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
1130
1131 // Map the sliced dims from parent space to collapsed space
1132 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1133 assert("expect sliceDims not being collapsed" &&
1134 llvm::none_of(dimGroup, [&](int64_t dim) {
1135 return llvm::is_contained(sliceDims, dim);
1136 }));
1137 SmallVector<int64_t> dimsInParentSpace =
1138 mapSlicedDimsToParentSpace(dimGroup, sliceDims);
1139
1140 auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
1141 return SliceAttr::get(getContext(), collapsedParent,
1142 DenseI64ArrayAttr::get(getContext(), sliceDims));
1143}
1144
1146 ArrayRef<int64_t> permutation) {
1147 SmallVector<int64_t> sortedSliceDims = llvm::to_vector(sliceDims);
1148 llvm::sort(sortedSliceDims);
1149
1150 for (size_t i = 1; i < sortedSliceDims.size(); ++i) {
1151 assert((sortedSliceDims[i] == sortedSliceDims[i - 1] + 1) &&
1152 "slice dims non consecutive, cannot be transposed");
1153 }
1154
1155 SmallVector<int64_t> permForParent;
1156 if (sortedSliceDims.front() == 0) {
1157 // Example: sliceDims.size() = 2, permutation= {1, 0}
1158 // result: {3, 2, 1, 0}.
1159 for (int64_t dim : permutation)
1160 permForParent.push_back(dim + sortedSliceDims.size());
1161 for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1162 permForParent.push_back(i);
1163 } else {
1164 // Example: sliceDims.size() = 2, permutation = {0, 1}
1165 // result: {3, 2, 0, 1}.
1166 for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1167 permForParent.push_back(i + permutation.size());
1168 for (int64_t dim : permutation)
1169 permForParent.push_back(dim);
1170 }
1171 return permForParent;
1172}
1173
1174// Derive a new layout by transpose the layout using `permutation`.
1175DistributeLayoutAttr SliceAttr::transposeDims(ArrayRef<int64_t> permutation) {
1176 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1177 DistributeLayoutAttr parent = getParent();
1178 SmallVector<int64_t> permForParent =
1179 getPermForParentLayout(sliceDims, permutation);
1180 auto transposedParent = parent.transposeDims(permForParent);
1181 return SliceAttr::get(getContext(), transposedParent,
1182 DenseI64ArrayAttr::get(getContext(), sliceDims));
1183}
1184
1185/// Check if this layout is a transpose of another layout.
1186bool SliceAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
1187 ArrayRef<int64_t> perm,
1188 const xegpu::LayoutKind kind) {
1189 // other must be a SliceAttr with the same slice dims.
1190 auto otherSlice = dyn_cast<xegpu::SliceAttr>(other);
1191 if (!otherSlice || getDims() != otherSlice.getDims())
1192 return false;
1193 // check whether the parent layout is transpose of each other.
1194 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1195 DistributeLayoutAttr parent = getParent();
1196 SmallVector<int64_t> permForParent = getPermForParentLayout(sliceDims, perm);
1197 auto otherParent = otherSlice.getParent();
1198 return parent.isTransposeOf(otherParent, permForParent, kind);
1199}
1200
1201//===----------------------------------------------------------------------===//
1202// XeGPU_RangeAttr
1203//===----------------------------------------------------------------------===//
1204
1205LogicalResult
1206RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1207 IntegerAttr startOfRange, IntegerAttr endOfRange) {
1208 if (startOfRange.getInt() >= endOfRange.getInt())
1209 return emitError() << "'end' : " << endOfRange.getInt()
1210 << " must be greater than 'start' : "
1211 << startOfRange.getInt();
1212
1213 return success();
1214}
1215
1216//===----------------------------------------------------------------------===//
1217// XeGPU_TensorDescType
1218//===----------------------------------------------------------------------===//
1219
/// Parses a TensorDescType with the syntax
///   tensor_desc<shape x elemType [, encoding] [, layout]>
/// where the two optional trailing attributes may appear in either order.
/// Returns a null type on any parse failure.
mlir::Type TensorDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // parse optional attributes: a DistributeLayoutAttr is taken as the layout,
  // a BlockTensorDescAttr as the encoding; any other attribute (or a parse
  // failure) aborts with a null type.
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<DistributeLayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    return {};
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  // getChecked runs the type verifier; a default BlockTensorDescAttr is used
  // when no encoding was given, and the layout defaults to a null attribute.
  MLIRContext *ctxt = parser.getContext();
  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
      layout.value_or(mlir::Attribute()));
}
1269
1270void TensorDescType::print(AsmPrinter &printer) const {
1271 printer << "<";
1272
1273 auto shape = getShape();
1274 for (int64_t dim : shape) {
1275 if (mlir::ShapedType::isDynamic(dim))
1276 printer << '?';
1277 else
1278 printer << dim;
1279 printer << 'x';
1280 }
1281
1282 printer << getElementType();
1283
1284 auto encoding = getEncoding();
1285 auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
1286 if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
1287 printer << ", " << encoding;
1288
1289 if (auto layout = getLayout())
1290 printer << ", " << layout;
1291
1292 printer << ">";
1293}
1294
1295TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
1296 mlir::Type elementType, int array_length,
1297 bool boundary_check,
1298 MemorySpace memory_space,
1299 mlir::Attribute layout) {
1300 auto *context = elementType.getContext();
1301 auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
1302 boundary_check);
1303 return Base::get(context, shape, elementType, attr, layout);
1304}
1305
1306LogicalResult
1307TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
1308 llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
1309 mlir::Attribute encoding, mlir::Attribute layout) {
1310 size_t rank = shape.size();
1311
1312 if (rank == 0)
1313 return emitError() << "expected non-zero rank tensor";
1314
1315 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
1316 if (blockAttr) {
1317 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
1318 if (rank > 1 && memorySpaceAttr &&
1319 memorySpaceAttr.getValue() == MemorySpace::SLM)
1320 return emitError() << "SLM is only supported for 1D block tensor";
1321 }
1322
1323 if (!elementType.isIntOrFloat())
1324 return emitError() << "unsupported element type " << elementType
1325 << ": expected integer or float";
1326
1327 if (auto layoutAttr =
1328 mlir::dyn_cast_if_present<DistributeLayoutAttr>(layout)) {
1329 if (rank != (size_t)layoutAttr.getRank())
1330 return emitError() << "expected layout rank to match tensor rank";
1331
1332 if (!layoutAttr.isDistributable(SmallVector<int64_t>(shape))) {
1333 std::string shapeStr;
1334 llvm::raw_string_ostream stream(shapeStr);
1335 llvm::interleaveComma(shape, stream);
1336 return emitError() << "cannot distribute [" << shapeStr << "] using "
1337 << layoutAttr;
1338 }
1339 }
1340
1341 return success();
1342}
1343
1344//===----------------------------------------------------------------------===//
1345// XeGPU_MemDescType
1346//===----------------------------------------------------------------------===//
/// Parses a MemDescType with the syntax
///   mem_desc<shape x elemType [, mem-layout-attr]>
/// Returns a null type on any parse failure.
mlir::Type MemDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<MemLayoutAttr> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  // allowDynamic=false: mem_desc shapes must be fully static;
  // withTrailingX=true consumes the 'x' separating shape and element type.
  if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // parse the single optional MemLayoutAttr
  if (mlir::succeeded(parser.parseOptionalComma())) {
    MemLayoutAttr attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::failed(res))
      return {};
    layout = attr;
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  // getChecked runs the type verifier; the layout defaults to null.
  MLIRContext *ctxt = parser.getContext();
  return MemDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, layout.value_or(MemLayoutAttr()));
}
1386
1387void MemDescType::print(AsmPrinter &printer) const {
1388 printer << "<";
1389
1390 printer.printDimensionList(getShape());
1391 printer << 'x';
1392 printer << getElementType();
1393
1394 if (auto layout = getMemLayout())
1395 printer << ", " << layout;
1396
1397 printer << ">";
1398}
1399
1400//===----------------------------------------------------------------------===//
1401// XeGPU_MemDescType
1402//===----------------------------------------------------------------------===//
1403
1404Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
1405
1406 auto *context = parser.getContext();
1407 llvm::SMLoc loc = parser.getCurrentLocation();
1408
1409 llvm::SmallDenseSet<StringRef> seenKeys;
1410 SmallVector<NamedAttribute> attributes;
1411
1412 auto parseElt = [&]() -> ParseResult {
1413 StringRef nameId;
1414 if (failed(parser.parseKeyword(&nameId)))
1415 return parser.emitError(loc, "expected valid attribute name");
1416
1417 if (!seenKeys.insert(nameId).second)
1418 return parser.emitError(loc, "duplicate key '")
1419 << nameId << " in mem layout attribute";
1420
1421 if (failed(parser.parseEqual()))
1422 return failure();
1423
1424 Attribute attr;
1425 if (failed(parser.parseAttribute(attr)))
1426 return failure();
1427 attributes.emplace_back(nameId, attr);
1428 return success();
1429 };
1430
1431 // Parse literal '<'
1432 if (parser.parseLess())
1433 return {};
1434
1435 if (failed(parser.parseCommaSeparatedList(parseElt)))
1436 return {};
1437
1438 // Parse literal '>'
1439 if (parser.parseGreater())
1440 return {};
1441
1442 return parser.getChecked<MemLayoutAttr>(
1443 loc, context, DictionaryAttr::get(context, attributes));
1444}
1445
1446void MemLayoutAttr::print(AsmPrinter &printer) const {
1447 printer << "<";
1448 ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
1449 for (size_t i = 0; i < attrs.size(); i++) {
1450 printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
1451 if (i < attrs.size() - 1)
1452 printer << ", ";
1453 }
1454 printer << ">";
1455}
1456// a helper utility to perform binary operation on OpFoldResult.
1457// If both a and b are attributes, it will simply return the result.
1458// Otherwise, the corresponding arith op will be generated, and an
1459// contant op will be created if one of them is an attribute.
1460template <typename ArithOp>
1462 OpBuilder &builder) {
1463 auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
1464 auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
1465 return ArithOp::create(builder, loc, aVal, bVal).getResult();
1466}
1467
1468// a helper utility to perform division operation on OpFoldResult and int64_t.
1469#define div(a, b) \
1470 genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
1471
1472// a helper utility to perform reminder operation on OpFoldResult and int64_t.
1473#define rem(a, b) \
1474 genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
1475
1476// a helper utility to perform multiply operation on OpFoldResult and int64_t.
1477#define mul(a, b) \
1478 genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
1479
1480// a helper utility to perform addition operation on two OpFoldResult.
1481#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
1482
1483// block the given offsets according to the block shape
1484// say the original offset is [y, x], and the block shape is [By, Bx],
1485// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
1487 ArrayRef<OpFoldResult> offsets,
1488 ArrayRef<int64_t> blockShape) {
1489
1490 assert(offsets.size() == blockShape.size() &&
1491 "offsets and blockShape must have the same size");
1492 SmallVector<OpFoldResult> blockedOffsets;
1493 SmallVector<OpFoldResult> divs, rems;
1494
1495 for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
1496 divs.push_back(div(offset, block));
1497 rems.push_back(rem(offset, block));
1498 }
1499 blockedOffsets.append(divs.begin(), divs.end());
1500 blockedOffsets.append(rems.begin(), rems.end());
1501
1502 return blockedOffsets;
1503}
1504
// Get strides as vector of integer for MemDesc.
// Returns [outer-block strides..., inner-block strides...] for the blocked
// layout implied by the stride attribute and the block shape.
SmallVector<int64_t> MemDescType::getStrideShape() {

  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());

  // Read the per-dim strides from the stride attribute.
  ArrayAttr strideAttr = getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }

  SmallVector<int64_t> innerBlkShape = getBlockShape();

  // get perm from FCD to LCD
  // perm[i] = the dim with i-th smallest stride
  SmallVector<int, 4> perm =
      llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
  llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });

  assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");

  // Strides within one block: contiguous in the same dim order as `perm`.
  SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
  innerBlkStride[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    innerBlkStride[perm[i]] =
        innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];

  // compute the original matrix shape using the stride info
  // and compute the number of blocks in each dimension
  // The shape of highest dim can't be derived from stride info,
  // but doesn't impact the stride computation for blocked layout.
  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
    BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
  }

  // Total number of elements per block.
  int64_t innerBlkSize = 1;
  for (auto s : innerBlkShape)
    innerBlkSize *= s;

  // Strides between blocks: whole blocks are laid out contiguously, so the
  // innermost outer stride is the block size.
  SmallVector<int64_t> outerBlkStride(matrixShape.size());
  outerBlkStride[perm[0]] = innerBlkSize;
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    outerBlkStride[perm[i + 1]] =
        outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
  }

  // combine the inner and outer strides
  SmallVector<int64_t> blockedStrides;
  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());

  return blockedStrides;
}
1561
1562// Calculate the linear offset using the blocked offsets and stride
1563Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
1564 ArrayRef<OpFoldResult> offsets) {
1565
1566 SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
1567 SmallVector<int64_t> blockShape = getBlockShape();
1568 SmallVector<int64_t> strides = getStrideShape();
1569 SmallVector<OpFoldResult> blockedOffsets;
1570
1571 // blockshape equal to matrixshape means no blocking
1572 if (llvm::equal(blockShape, matrixShape)) {
1573 // remove the outer dims from strides
1574 strides.erase(strides.begin(), strides.begin() + matrixShape.size());
1575 } else {
1576 assert(offsets.size() == blockShape.size() &&
1577 "offsets and blockShape must have the same size");
1578 // say the original offset is [y, x], and the block shape is [By, Bx],
1579 // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
1580
1581 SmallVector<OpFoldResult> divs, rems;
1582
1583 for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
1584 divs.push_back(div(offset, block));
1585 rems.push_back(rem(offset, block));
1586 }
1587 blockedOffsets.append(divs.begin(), divs.end());
1588 blockedOffsets.append(rems.begin(), rems.end());
1589 offsets = blockedOffsets;
1590 }
1591
1592 // Start with initial value as matrix descriptor's base offset.
1593 Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
1594 for (size_t i = 0; i < offsets.size(); ++i) {
1595 OpFoldResult mulResult = mul(offsets[i], strides[i]);
1596 Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
1597 linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
1598 }
1599
1600 return linearOffset;
1601}
1602
1603} // namespace xegpu
1604} // namespace mlir
1605
1606#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
1607#define GET_ATTRDEF_CLASSES
1608#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
1609#define GET_TYPEDEF_CLASSES
1610#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
return success()
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define div(a, b)
#define rem(a, b)
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printDimensionList(ArrayRef< int64_t > shape)
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
static SmallVector< SmallVector< int64_t > > genStaticCoordinates(llvm::ArrayRef< int64_t > canonicalIds, llvm::ArrayRef< int64_t > layout, llvm::ArrayRef< int64_t > subShape, llvm::ArrayRef< int64_t > shape)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
static SmallVector< int64_t > mapSlicedDimsToParentSpace(const SmallVector< int64_t > &dimsToMap, ArrayRef< int64_t > sliceDims)
SmallVector< OpFoldResult > getBlockedOffsets(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > offsets, ArrayRef< int64_t > blockShape)
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, OpBuilder &builder)
static SmallVector< SmallVector< Value > > genCoordinates(OpBuilder &builder, Location loc, SmallVector< Value > delinearizedId, ArrayRef< int64_t > subShapesLayout, ArrayRef< int64_t > subShape, ArrayRef< int64_t > srcShape)
SmallVector< int64_t > getPermForParentLayout(ArrayRef< int64_t > sliceDims, ArrayRef< int64_t > permutation)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.