MLIR 23.0.0git
XeGPUDialect.cpp
Go to the documentation of this file.
1//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/Builders.h"
16#include "llvm/ADT/SmallVectorExtras.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/Debug.h"
19
20using std::optional;
21
22namespace mlir {
23namespace xegpu {
24
/// Registers all TableGen-generated XeGPU types, operations and attributes
/// with the dialect. The GET_*_LIST macros expand to the comma-separated
/// class lists emitted by mlir-tblgen.
void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}
39#define GET_OP_INTERFACE_CLASSES
40#include "mlir/Dialect/XeGPU/IR/XeGPUOpInterface.cpp.inc"
41
42// A `srcShape` consists of N distribution units, each being `subShapesLayout` x
43// `subShape`. A `delinearizedId` is used to identify a particular `subShape`
44// within each distribution unit.
45// Example:
46// WG data is 128x256. SG data is 16x32, in 4x2 layout, this gives a
47// distribution unit of shape 64x64, we have 2x4 such distribution units.
48// `delinearizedId` is used to identify a 16x32 of a subgroup in each
49// distribution unit.
52 SmallVector<Value> delinearizedId,
53 ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
54 ArrayRef<int64_t> srcShape) {
56
57 // A distribution unit must be less than or equal to `srcShape`
58 SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
59 llvm::zip_equal(srcShape,
60 computeElementwiseMul(subShapesLayout, subShape)),
61 [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
62
63 // Get the offset of `subShape` within a distribution unit.
64 SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
65 llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
66 return builder.createOrFold<arith::MulIOp>(
67 loc, std::get<0>(t),
68 builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
69 });
70
71 // For each dist unit
72 for (SmallVector<int64_t> unitOffs :
73 StaticTileOffsetRange(srcShape, distUnitShape)) {
74 // Get dist unit offset within `srcShape`.
76 llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
77 return arith::ConstantIndexOp::create(builder, loc, d);
78 });
79 // Calculate `subShape` offset within `srcShape`.
81 llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
82 [&](const auto &t) -> Value {
83 return builder.createOrFold<arith::AddIOp>(
84 loc, std::get<0>(t), std::get<1>(t));
85 });
86 // Do not go beyond `srcShape` bounds.
87 SmallVector<Value> mods = llvm::map_to_vector(
88 llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
89 return builder.createOrFold<arith::RemUIOp>(
90 loc, std::get<0>(t),
91 arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
92 });
93
94 coordinates.push_back(mods);
95 }
96 return coordinates;
97}
98
102 // Compute distribution unit shape (clamped to srcShape).
103 SmallVector<int64_t> distUnitShape(shape.size());
104 for (size_t i = 0; i < shape.size(); ++i)
105 distUnitShape[i] = std::min(shape[i], layout[i] * subShape[i]);
106
107 // Compute local offset of this ID within a distribution unit.
108 SmallVector<int64_t> localOffset(shape.size());
109 for (size_t i = 0; i < shape.size(); ++i)
110 localOffset[i] = canonicalIds[i] * subShape[i];
111
112 // Enumerate all distribution units and compute coordinates.
114 for (SmallVector<int64_t> unitOffs :
115 StaticTileOffsetRange(shape, distUnitShape)) {
116 SmallVector<int64_t> coord(shape.size());
117 for (size_t i = 0; i < shape.size(); ++i)
118 coord[i] = (unitOffs[i] + localOffset[i]) % shape[i];
119 coordinates.push_back(coord);
120 }
121 return coordinates;
122}
123
124//===----------------------------------------------------------------------===//
125// XeGPU_BlockTensorDescAttr
126//===----------------------------------------------------------------------===//
127BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
128 xegpu::MemorySpace memory_space,
129 int array_length,
130 bool boundary_check) {
131 auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
132 auto lengthAttr =
133 IntegerAttr::get(IntegerType::get(context, 64), array_length);
134 auto boundaryAttr = BoolAttr::get(context, boundary_check);
135 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
136}
137
138bool BlockTensorDescAttr::hasDefaultsOnly() {
139 return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
140 getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
141}
142
143//===----------------------------------------------------------------------===//
144// XeGPU_LayoutAttr
145//===----------------------------------------------------------------------===//
/// Verifies consistency between the optional layout components: components
/// that are present must agree in rank, and each *_layout/*_data pair must
/// be specified together. `order`, when present, must match the rank of an
/// existing layout component.
LogicalResult
LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {

  // Special case for store_matrix: a layout with none of the layout
  // components is accepted as-is.
  if (!sg_layout && !inst_data && !lane_layout)
    return success();

  // sg_layout, inst_data and lane_layout must share the same rank whenever
  // any two of them are present.
  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError() << "expected inst_data and lane_layout to have the same "
                          "rank, got inst_data "
                       << inst_data.size() << ", lane_layout "
                       << lane_layout.size();
  }

  // sg_data is only meaningful with sg_layout, and they must agree in rank.
  if ((sg_layout && !sg_data) || (!sg_layout && sg_data))
    return emitError() << "sg_layout and sg_data must be used together";
  if (sg_layout && sg_data && sg_layout.size() != sg_data.size())
    return emitError()
           << "expected sg_data and sg_layout to have the same rank";

  // Same pairing rule for lane_layout/lane_data.
  if ((lane_layout && !lane_data) || (!lane_layout && lane_data))
    return emitError() << "lane_layout and lane_data must be used together";
  if (lane_layout && lane_data && lane_layout.size() != lane_data.size())
    return emitError()
           << "expected lane_data and lane_layout to have the same rank";

  // order must accompany at least one layout component and match its rank.
  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";

    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";

    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}
204
/// Delinearizes `linearId` into one index Value per dimension of the
/// effective layout (sg_layout for workgroup-level layouts, lane_layout for
/// subgroup-level ones), honoring the `order` attribute. Fails when the
/// attribute describes neither level or when `order` has the wrong rank.
FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {

  SmallVector<int64_t> sgLayoutInt;
  if (isForWorkgroup()) {
    sgLayoutInt = getEffectiveSgLayoutAsInt();
  } else if (isForSubgroup()) {
    sgLayoutInt = getEffectiveLaneLayoutAsInt();
  } else {
    return failure();
  }

  DenseI32ArrayAttr orderAttr = getOrder();

  // Handle order attribute
  SmallVector<int64_t> order;
  if (orderAttr && !orderAttr.empty()) {
    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
      return static_cast<int64_t>(idx);
    });
  } else {
    // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
    order = llvm::to_vector(
        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
  }

  if (order.size() != sgLayoutInt.size()) {
    return failure();
  }

  SmallVector<Value> result(sgLayoutInt.size());
  Value remaining = linearId;

  /// Process dimensions in the order they appear in the order array
  /// The first dimension in order is the fastest-changing
  ///
  /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
  ///
  /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
  /// result=[?,?,?]
  ///
  /// i=0 (process columns, dimIdx=2, dimSize=4):
  ///   result[2] = 22 % 4 = 2  (column coordinate)
  ///   remaining = 22 / 4 = 5  (5 complete groups of 4 columns processed)
  ///
  /// i=1 (process rows, dimIdx=1, dimSize=4):
  ///   result[1] = 5 % 4 = 1   (row coordinate)
  ///   remaining = 5 / 4 = 1   (1 complete group of 4 rows processed)
  ///
  /// i=2 (process layers, dimIdx=0, dimSize=2):
  ///   result[0] = 1 % 2 = 1   (layer coordinate)
  ///   (no remaining update - last iteration)
  ///
  /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    int64_t dimSize = sgLayoutInt[dimIdx];

    Value dimSizeVal =
        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);

    /// Extract the coordinate for this dimension using modulo operation
    /// This gives us "how far within this dimension" we are
    /// e.g., linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within
    /// this dimension)
    result[dimIdx] =
        builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);

    /// Update remaining for the next dimension by removing what we've already
    /// processed. Division tells us "how many complete groups of this dimension
    /// we've gone through" e.g., linearId=22, dimSize=4: 22 / 4 = 5 (we've
    /// completed 5 groups of 4) Skip this for the last iteration since there's
    /// no next dimension to process
    if (i < order.size() - 1) {
      remaining =
          builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
    }
  }
  return result;
}
285
286/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
287/// instructions for computing multi-dimensional offsets when distributed by
288/// LayoutAttr.
289FailureOr<SmallVector<SmallVector<Value>>>
290LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
291 Value linearId, ArrayRef<int64_t> shape) {
292 SmallVector<int64_t> layout;
293 SmallVector<int64_t> subShape;
294 if (isForWorkgroup()) {
295 layout = getEffectiveSgLayoutAsInt();
296 subShape = getEffectiveSgDataAsInt();
297 } else if (isForSubgroup()) {
298 layout = getEffectiveLaneLayoutAsInt();
299 subShape = getEffectiveLaneDataAsInt();
300 } else {
301 return failure();
302 }
303 assert(!subShape.empty() && "sgdata or lanedata cannot be empty for "
304 "distributed coordinates computation");
305
306 // delinearize Ids
307 auto maybeIds = delinearizeId(builder, loc, linearId);
308 if (failed(maybeIds))
309 return failure();
310 SmallVector<Value> ids = *maybeIds;
311
312 return genCoordinates(builder, loc, ids, layout, subShape, shape);
313}
314
315bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
316 if (dyn_cast<xegpu::SliceAttr>(other))
317 return false;
318
319 return *this == dyn_cast<xegpu::LayoutAttr>(other);
320}
321
322/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
323/// compute multi-dimensional offsets for a given linear ID when distributed by
324/// LayoutAttr.
325SmallVector<SmallVector<int64_t>>
326LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
327 ArrayRef<int64_t> shape) {
328 SmallVector<int64_t> layoutVec;
329 SmallVector<int64_t> subShape;
330 SmallVector<int64_t> instData;
331 if (isForWorkgroup()) {
332 layoutVec = getEffectiveSgLayoutAsInt();
333 subShape = getEffectiveSgDataAsInt();
334 } else if (isForSubgroup()) {
335 instData = getEffectiveInstDataAsInt();
336 layoutVec = getEffectiveLaneLayoutAsInt();
337 subShape = getEffectiveLaneDataAsInt();
338 }
339 if (!instData.empty()) {
340 linearId = 0;
341 subShape = instData;
342 }
343 assert(!subShape.empty() && "sgdata or lanedata cannot be empty");
344
345 // Delinearize the linear ID using the order attribute.
346 SmallVector<int64_t> order = getEffectiveOrderAsInt();
347 SmallVector<int64_t> delinearizedId(layoutVec.size());
348 int64_t remaining = linearId;
349 for (size_t i = 0; i < order.size(); ++i) {
350 int64_t dimIdx = order[i];
351 delinearizedId[dimIdx] = remaining % layoutVec[dimIdx];
352 remaining = remaining / layoutVec[dimIdx];
353 }
354
355 return genStaticCoordinates(delinearizedId, layoutVec, subShape, shape);
356}
357
358// set the layout for unit dims: sg_data, inst_data and lane_data to 1
359DistributeLayoutAttr
360LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
361 auto sgDataOpt = getSgData();
362 auto instDataOpt = getInstData();
363 auto laneDataOpt = getLaneData();
364
365 SmallVector<int32_t> sgData;
366 SmallVector<int32_t> instData;
367 SmallVector<int32_t> laneData;
368
369 if (sgDataOpt)
370 sgData = llvm::to_vector(sgDataOpt.asArrayRef());
371
372 if (instDataOpt)
373 instData = llvm::to_vector(instDataOpt.asArrayRef());
374
375 if (laneDataOpt)
376 laneData = llvm::to_vector(laneDataOpt.asArrayRef());
377
378 for (auto dim : unitDims) {
379 if (dim < static_cast<int64_t>(sgData.size()))
380 sgData[dim] = 1;
381 if (dim < static_cast<int64_t>(instData.size()))
382 instData[dim] = 1;
383 if (dim < static_cast<int64_t>(laneData.size()))
384 laneData[dim] = 1;
385 }
386
387 return LayoutAttr::get(
388 getContext(), getSgLayout(),
389 sgData.empty() ? DenseI32ArrayAttr()
391 instData.empty() ? DenseI32ArrayAttr()
392 : DenseI32ArrayAttr::get(getContext(), instData),
393 getLaneLayout(),
394 laneData.empty() ? DenseI32ArrayAttr()
395 : DenseI32ArrayAttr::get(getContext(), laneData),
396 getOrder());
397}
398
399// set the layout for the sepcified unit dims: sg_lane and lane_layout to 1
400DistributeLayoutAttr
401LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
402 auto sgLayoutOpt = getSgLayout();
403 auto laneLayoutOpt = getLaneLayout();
404
405 SmallVector<int32_t> sgLayout;
406 SmallVector<int32_t> laneLayout;
407
408 if (sgLayoutOpt)
409 sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
410 if (laneLayoutOpt)
411 laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
412
413 for (auto dim : unitDims) {
414 if (dim < static_cast<int64_t>(sgLayout.size()))
415 sgLayout[dim] = 1;
416 if (dim < static_cast<int64_t>(laneLayout.size()))
417 laneLayout[dim] = 1;
418 }
419
420 return LayoutAttr::get(
421 getContext(),
422 sgLayout.empty() ? DenseI32ArrayAttr()
423 : DenseI32ArrayAttr::get(getContext(), sgLayout),
424 getSgData(), getInstData(),
425 laneLayout.empty() ? DenseI32ArrayAttr()
426 : DenseI32ArrayAttr::get(getContext(), laneLayout),
427 getLaneData(), getOrder());
428}
429
430// Derive a new layout with sg_data, inst_data and lane_data set to the
431// specified values for the given dimension
432DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
433 int64_t instData,
434 int64_t laneData) {
435
436 SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
437 SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
438 SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
439
440 if (dim < static_cast<int64_t>(sgDataVec.size()) && sgData != -1)
441 sgDataVec[dim] = sgData;
442 if (dim < static_cast<int64_t>(instDataVec.size()) && instData != -1)
443 instDataVec[dim] = instData;
444 if (dim < static_cast<int64_t>(laneDataVec.size()) && laneData != -1)
445 laneDataVec[dim] = laneData;
446
447 SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
448 SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
449 SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());
450
451 return LayoutAttr::get(
452 getContext(), getSgLayout(),
453 sgDataVec.empty() ? DenseI32ArrayAttr()
454 : DenseI32ArrayAttr::get(getContext(), sgDataVec32),
455 instDataVec.empty() ? DenseI32ArrayAttr()
456 : DenseI32ArrayAttr::get(getContext(), instDataVec32),
457 getLaneLayout(),
458 laneDataVec.empty() ? DenseI32ArrayAttr()
459 : DenseI32ArrayAttr::get(getContext(), laneDataVec32),
460 getOrder());
461}
462
// Derive a new layout by removing dimensions.
// `dimGroup` specifies a group of dimensions to be removed in the derived
// layout. Components absent on this layout stay absent; the order attribute
// is rebuilt to refer to the surviving, renumbered dimensions.
DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {

  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
  SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
  SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
  SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
  SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
  SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();

  SmallVector<int64_t> sortedDimGroup = dimGroup;
  llvm::sort(sortedDimGroup);

  // Erase from highest index to lowest so earlier erases do not shift the
  // positions of dims still to be removed.
  for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
    if (!sgLayout.empty()) {
      sgLayout.erase(sgLayout.begin() + dimIdx);
      sgData.erase(sgData.begin() + dimIdx);
    }
    if (!instData.empty())
      instData.erase(instData.begin() + dimIdx);
    if (!laneLayout.empty()) {
      laneLayout.erase(laneLayout.begin() + dimIdx);
      laneData.erase(laneData.begin() + dimIdx);
    }
  }

  // Rebuild the order: drop entries for removed dims and shift each
  // surviving entry down by the number of removed dims below it.
  SmallVector<int64_t> newOrder;
  for (int64_t d : origOrder) {
    if (llvm::is_contained(dimGroup, d))
      continue;
    int64_t offset = llvm::count_if(dimGroup, [&](int64_t s) { return s < d; });
    newOrder.push_back(d - offset);
  }
  // An order is redundant when no layout component remains or only a single
  // dimension is left.
  if ((sgLayout.empty() && laneLayout.empty()) || newOrder.size() == 1)
    newOrder.clear();

  // Empty vectors map back to null attributes.
  auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
    if (v.empty())
      return DenseI32ArrayAttr();
    SmallVector<int32_t> v32(v.begin(), v.end());
    return DenseI32ArrayAttr::get(getContext(), v32);
  };
  auto droppedLayout = xegpu::LayoutAttr::get(
      getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
      toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
  return droppedLayout;
}
512
// Derive a new layout by collapsing dimensions.
// `dimGroup` specifies a group of adjacent dimensions
// that are collapsed into a single dimension in the derived layout. The
// collapsed dimension's size is the product of the group's sizes, and it is
// placed at the position of the group's lowest dim index. Non-adjacent
// groups abort via report_fatal_error.
DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {

  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
  SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
  SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
  SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
  SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
  SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();

  SmallVector<int64_t> sortedDimGroup = dimGroup;
  llvm::sort(sortedDimGroup);
  // Validate adjacency of the group before mutating anything.
  int64_t dimBeforeCurrent = -1;
  for (auto dimIdx : sortedDimGroup) {
    // when order attr is present, adjacency dims are values like [3, 2, 1, 0]
    // in decreasing order; otherwise based on dim indices like [0, 1, 2, 3]
    // in increasing order
    if (dimBeforeCurrent >= 0) {
      if (getOrder() && !getOrder().empty()) {
        int64_t orderBefore = origOrder[dimBeforeCurrent];
        int64_t orderCurrent = origOrder[dimIdx];
        if (orderBefore != (orderCurrent - 1))
          llvm::report_fatal_error(
              "dimensions being collapsed must be adjacent in order");
      } else {
        if (dimIdx != (dimBeforeCurrent + 1))
          llvm::report_fatal_error(
              "dimensions being collapsed must be adjacent");
      }
    }
    dimBeforeCurrent = dimIdx;
  }

  // The collapsed dimension takes the slot of the group's smallest index.
  int firstDim = sortedDimGroup.front();

  // collapse the dimensions in dimGroup into one dimension by multiplying their
  // sizes together

  if (!sgLayout.empty()) {
    int64_t collapsedSglayout = 1, collapsedSgData = 1;
    for (auto dimIdx : dimGroup) {
      collapsedSglayout *= sgLayout[dimIdx];
      collapsedSgData *= sgData[dimIdx];
    }
    // Erase from highest index to lowest so positions stay valid.
    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
      sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
      sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
    }
    sgLayout.insert(sgLayout.begin() + firstDim, collapsedSglayout);
    sgData.insert(sgData.begin() + firstDim, collapsedSgData);
  }

  if (!instData.empty()) {
    int64_t collapsedInstData = 1;
    for (auto dimIdx : dimGroup)
      collapsedInstData *= instData[dimIdx];
    for (auto dimIdx : llvm::reverse(sortedDimGroup))
      instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
    instData.insert(instData.begin() + firstDim, collapsedInstData);
  }

  if (!laneLayout.empty()) {
    int64_t collapsedLaneLayout = 1, collapsedLaneData = 1;
    for (auto dimIdx : dimGroup) {
      collapsedLaneLayout *= laneLayout[dimIdx];
      collapsedLaneData *= laneData[dimIdx];
    }
    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
      laneLayout.erase(laneLayout.begin() + dimIdx,
                       laneLayout.begin() + dimIdx + 1);
      laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
    }
    laneLayout.insert(laneLayout.begin() + firstDim, collapsedLaneLayout);
    laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
  }

  // Rebuild the order attribute for the reduced rank: drop the collapsed
  // dims (keeping the group's first slot), then renumber by rank of the
  // remaining order values.
  SmallVector<int64_t> newOrder;
  DenseI32ArrayAttr orderAttr = getOrder();
  if (orderAttr && !orderAttr.empty()) {

    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
      if (dimIdx != firstDim)
        origOrder.erase(origOrder.begin() + dimIdx);
    }
    // say we have orderVec = {5, 3, 2, 1, 0}
    // Create indices [0, 1, 2, 3, 4]
    SmallVector<size_t> indices =
        llvm::to_vector(llvm::seq<size_t>(0, orderAttr.size()));

    // Sort indices based on corresponding values
    llvm::sort(indices,
               [&](size_t a, size_t b) { return origOrder[a] < origOrder[b]; });

    newOrder = llvm::to_vector(llvm::map_range(
        indices, [&](size_t i) { return static_cast<int64_t>(i); }));
  }

  // Empty vectors map back to null attributes.
  auto toAttr = [&](ArrayRef<int64_t> v) -> DenseI32ArrayAttr {
    if (v.empty())
      return DenseI32ArrayAttr();
    SmallVector<int32_t> v32(v.begin(), v.end());
    return DenseI32ArrayAttr::get(getContext(), v32);
  };
  auto collapsedLayout = xegpu::LayoutAttr::get(
      getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
      toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
  return collapsedLayout;
}
623
624// Derive a new layout by transpose the layout using `permutation`.
625DistributeLayoutAttr LayoutAttr::transposeDims(ArrayRef<int64_t> permutation) {
626
627 SmallVector<int64_t> origSgLayout = getEffectiveSgLayoutAsInt();
628 SmallVector<int64_t> origSgData = getEffectiveSgDataAsInt();
629 SmallVector<int64_t> origInstData = getEffectiveInstDataAsInt();
630 SmallVector<int64_t> origLaneLayout = getEffectiveLaneLayoutAsInt();
631 SmallVector<int64_t> origLaneData = getEffectiveLaneDataAsInt();
632 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
633
634 SmallVector<int32_t> sgLayout;
635 SmallVector<int32_t> sgData;
636 SmallVector<int32_t> instData;
637 SmallVector<int32_t> laneLayout;
638 SmallVector<int32_t> laneData;
639 SmallVector<int32_t> order;
640
641 for (int64_t idx : permutation) {
642 if (!origLaneLayout.empty()) {
643 laneLayout.push_back(static_cast<int32_t>(origLaneLayout[idx]));
644 laneData.push_back(static_cast<int32_t>(origLaneData[idx]));
645 }
646 if (!origInstData.empty())
647 instData.push_back(static_cast<int32_t>(origInstData[idx]));
648 if (!origSgLayout.empty()) {
649 sgLayout.push_back(static_cast<int32_t>(origSgLayout[idx]));
650 sgData.push_back(static_cast<int32_t>(origSgData[idx]));
651 }
652 order.push_back(static_cast<int32_t>(origOrder[idx]));
653 }
654 if (origLaneLayout.empty() && origSgLayout.empty())
655 order.clear();
656
657 auto toAttr = [&](ArrayRef<int32_t> v) -> DenseI32ArrayAttr {
658 return v.empty() ? nullptr : DenseI32ArrayAttr::get(getContext(), v);
659 };
660 return xegpu::LayoutAttr::get(getContext(), toAttr(sgLayout), toAttr(sgData),
661 toAttr(instData), toAttr(laneLayout),
662 toAttr(laneData), toAttr(order));
663}
664
665/// Check if this layout is a transpose of another layout.
666bool LayoutAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
667 ArrayRef<int64_t> perm,
668 const xegpu::LayoutKind kind) {
669 if (!other)
670 return false;
671 if (getRank() != other.getRank() ||
672 perm.size() != static_cast<size_t>(getRank()))
673 return false;
674 if (!isPermutationVector(perm))
675 return false;
676 auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src,
677 ArrayRef<int64_t> perm) {
678 for (const auto &ta : llvm::enumerate(perm)) {
679 if (src[ta.index()] != dst[ta.value()])
680 return false;
681 }
682 return true;
683 };
684 if (kind == xegpu::LayoutKind::Subgroup)
685 return checkTranspose(getEffectiveSgLayoutAsInt(),
686 other.getEffectiveSgLayoutAsInt(), perm) &&
687 checkTranspose(getEffectiveSgDataAsInt(),
688 other.getEffectiveSgDataAsInt(), perm) &&
689 checkTranspose(getEffectiveOrderAsInt(),
690 other.getEffectiveOrderAsInt(), perm);
691 if (kind == xegpu::LayoutKind::InstData)
692 return checkTranspose(getEffectiveInstDataAsInt(),
693 other.getEffectiveInstDataAsInt(), perm);
694 if (kind == xegpu::LayoutKind::Lane)
695 return checkTranspose(getEffectiveLaneLayoutAsInt(),
696 other.getEffectiveLaneLayoutAsInt(), perm) &&
697 checkTranspose(getEffectiveLaneDataAsInt(),
698 other.getEffectiveLaneDataAsInt(), perm) &&
699 checkTranspose(getEffectiveOrderAsInt(),
700 other.getEffectiveOrderAsInt(), perm);
701
702 return false;
703}
704
/// Returns true when this layout and `other` distribute `shape` identically
/// at the given `level`: either their components match syntactically, or the
/// concrete coordinates they produce agree for every id at that level.
bool LayoutAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                                  SmallVector<int64_t> shape,
                                  xegpu::LayoutKind level) {
  if (!other)
    return false;
  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
    // short cut when order is the same, no need to compute coords and compare
    if (level == xegpu::LayoutKind::Subgroup)
      if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
          getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
        return true;
    if (level == xegpu::LayoutKind::Lane)
      if (getEffectiveLaneLayoutAsInt() ==
              other.getEffectiveLaneLayoutAsInt() &&
          getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
        return true;
  }

  // Fall back to comparing the concrete coordinates both layouts produce for
  // every id in [0, size).
  auto compareCoordsForAllIds = [&](int64_t size) {
    for (int64_t id : llvm::seq<int64_t>(0, size)) {
      auto coords = computeStaticDistributedCoords(id, shape);
      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
      if (coords != otherCoords)
        return false;
    }
    return true;
  };

  if (level == xegpu::LayoutKind::Subgroup) {
    int64_t wgSize = computeProduct(getEffectiveSgLayoutAsInt());
    return compareCoordsForAllIds(wgSize);
  }
  if (level == xegpu::LayoutKind::InstData) {
    // inst_data carries no id-dependent distribution; plain equality suffices.
    return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
  }
  if (level == xegpu::LayoutKind::Lane) {
    int64_t subgroupSize = computeProduct(getEffectiveLaneLayoutAsInt());
    return compareCoordsForAllIds(subgroupSize);
  }
  return true;
}
746
747//===----------------------------------------------------------------------===//
748// XeGPU_SliceAttr
749//===----------------------------------------------------------------------===//
/// Verifies a SliceAttr: `dims` must be present and its elements must be
/// non-negative and unique.
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {

  if (!dims)
    return emitError() << "expected dims attribute";

  // Check every element in dims is non-negative and unique.
  // NOTE(review): no upper-bound check against the parent's rank is done
  // here — confirm whether that is enforced elsewhere.
  llvm::SmallDenseSet<int64_t> seen;
  for (int64_t dim : dims.asArrayRef()) {
    if (dim < 0)
      return emitError() << "invalid dim (" << dim << ") in slice attribute.";
    if (!seen.insert(dim).second)
      return emitError() << "repeated dim (" << dim << ") in slice attribute.";
  }
  return success();
}
767
768SliceAttr SliceAttr::flatten() const {
769 xegpu::DistributeLayoutAttr parent = getParent();
770 SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
771
772 while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
773 parent = sliceAttr.getParent();
774 slicedDims.push_back(sliceAttr.getDims());
775 }
776
777 auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
778 SmallVector<int64_t> indices =
779 llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
780
781 // get remaining dims (flattend) by applying slice ops with all slicedDims
782 SmallVector<int64_t> remainingDims(indices);
783 for (auto dim : llvm::reverse(slicedDims))
784 remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
785 dim.asArrayRef());
786
787 // get flattend sliced dims by applying slice ops with the remaining dims
788 SmallVector<int64_t> flattendDims = XeGPUDialect::slice(
789 llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
790
791 return xegpu::SliceAttr::get(
792 getContext(), layoutAttr,
793 DenseI64ArrayAttr::get(getContext(), flattendDims));
794}
795
796FailureOr<SmallVector<Value>>
797SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
798 SliceAttr attr = flatten();
799 auto parent = dyn_cast<LayoutAttr>(attr.getParent());
800 return parent.delinearizeId(builder, loc, linearId);
801}
802
803// Implements DistributeLayoutAttr::computeDistributedCoords to generate
804// instructions for computing multi-dimensional offsets when distributed by
805// LayoutAttr.
806FailureOr<SmallVector<SmallVector<Value>>>
807SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
808 Value linearId, ArrayRef<int64_t> shape) {
809 assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
810
811 SmallVector<int64_t> layout;
812 SmallVector<int64_t> subShape;
813 if (isForWorkgroup()) {
814 layout = getEffectiveSgLayoutAsInt();
815 subShape = getEffectiveSgDataAsInt();
816 } else if (isForSubgroup()) {
817 layout = getEffectiveLaneLayoutAsInt();
818 subShape = getEffectiveLaneDataAsInt();
819 } else {
820 return failure();
821 }
822
823 if (subShape.empty())
824 return failure();
825
826 // delinearize Ids
827 auto maybeIds = delinearizeId(builder, loc, linearId);
828 if (failed(maybeIds))
829 return failure();
830
831 // The effective sgIds for offsets computing correspond
832 // to the dims that are not sliced.
833 ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
834 SmallVector<Value> canonicalIds =
835 XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
836
837 return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
838}
839
/// Implements DistributeLayoutAttr::computeStaticDistributedCoords to
/// compute multi-dimensional offsets for a given linear ID when distributed by
/// SliceAttr. Delegates delinearization to the parent LayoutAttr, then uses
/// only the non-sliced dimensions for coordinate computation.
SmallVector<SmallVector<int64_t>>
SliceAttr::computeStaticDistributedCoords(int64_t linearId,
                                          ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");

  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  SmallVector<int64_t> instData;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    instData = getEffectiveInstDataAsInt();
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  }
  // When inst_data is present it takes over as the sub-shape, and the id
  // does not select a tile (it is forced to 0).
  if (!instData.empty()) {
    linearId = 0;
    subShape = instData;
  }

  assert(!subShape.empty() && "sgdata or lanedata cannot be empty");

  // Delinearize the ID using the parent layout (same as the IR version).
  SliceAttr flattened = flatten();
  auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
  SmallVector<int64_t> parentLayoutVec;
  if (parent.isForWorkgroup())
    parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
  else
    parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();

  // The first entry of `order` is the fastest-changing dimension.
  SmallVector<int64_t> order = parent.getEffectiveOrderAsInt();
  SmallVector<int64_t> allIds(parentLayoutVec.size());
  int64_t remaining = linearId;
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
    if (i < order.size() - 1)
      remaining = remaining / parentLayoutVec[dimIdx];
  }

  // The effective IDs for coordinate computation correspond
  // to the dims that are not sliced.
  ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
  SmallVector<int64_t> canonicalIds =
      XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);

  return genStaticCoordinates(canonicalIds, layout, subShape, shape);
}
894
895bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
896 auto flattenedThis = flatten();
897 // If other is a LayoutAttr, just compare directly with parent of
898 // flattenedThis.
899 if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
900 return flattenedThis.getParent() == otherLayout;
901 // If other is a SliceAttr, flatten it first before comparing.
902 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
903 // Both must have common parent LayoutAttr.
904 if (flattenedThis.getParent() != flattenedOther.getParent())
905 return false;
906 // otherFlattened's sliced dims must be a subset of flattenedThis's sliced
907 // dims.
908 llvm::SmallDenseSet<int64_t> thisDims(
909 flattenedThis.getDims().asArrayRef().begin(),
910 flattenedThis.getDims().asArrayRef().end());
911 return llvm::all_of(flattenedOther.getDims().asArrayRef(),
912 [&](int64_t dim) { return thisDims.contains(dim); });
913}
914
915bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
916 if (dyn_cast<xegpu::LayoutAttr>(other))
917 return false;
918
919 auto flattenedThis = flatten();
920 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
921
922 return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
923 (flattenedThis.getDims() == flattenedOther.getDims()));
924}
925
/// Returns true when distributing data with this slice layout yields the same
/// per-id coordinates as distributing it with `other` at the given hierarchy
/// `level`.
bool SliceAttr::isCompatibleWith(const xegpu::DistributeLayoutAttr &other,
                                 SmallVector<int64_t> shape,
                                 xegpu::LayoutKind level) {
  if (!other)
    return false;
  if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
    // short cut when order is the same, no need to compute coords and compare
    if (level == xegpu::LayoutKind::Subgroup)
      if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
          getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
        return true;
    if (level == xegpu::LayoutKind::Lane)
      if (getEffectiveLaneLayoutAsInt() ==
              other.getEffectiveLaneLayoutAsInt() &&
          getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
        return true;
  }

  // Fallback: brute-force comparison of the statically computed coordinates
  // for every linear id in [0, size).
  auto compareCoordsForAllIds = [&](int64_t size) {
    for (int64_t id : llvm::seq<int64_t>(0, size)) {
      auto coords = computeStaticDistributedCoords(id, shape);
      auto otherCoords = other.computeStaticDistributedCoords(id, shape);
      if (coords != otherCoords)
        return false;
    }
    return true;
  };

  // NOTE(review): assumes the flattened parent is always a LayoutAttr;
  // `parent` is used without a null check — confirm flatten() guarantees it.
  auto flattenedThis = flatten();
  auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
  if (level == xegpu::LayoutKind::Subgroup) {
    int64_t wgSize = computeProduct(parent.getEffectiveSgLayoutAsInt());
    return compareCoordsForAllIds(wgSize);
  }
  // Inst-data layouts are compared structurally, not by coordinates.
  if (level == xegpu::LayoutKind::InstData) {
    return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
  }
  if (level == xegpu::LayoutKind::Lane) {
    int64_t subgroupSize = computeProduct(parent.getEffectiveLaneLayoutAsInt());
    return compareCoordsForAllIds(subgroupSize);
  }
  return true;
}
969
970xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
971 if (sliceDimsToDrop.empty())
972 return *this;
973 SmallVector<int64_t> sliceDims{getDims().asArrayRef()};
974 for (auto dim : sliceDimsToDrop) {
975 auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), dim);
976 assert(foundIt != sliceDims.end() &&
977 "Expected to find the specified reduction dim in slice dims");
978 sliceDims.erase(foundIt);
979 }
980
981 auto sliceWithoutDims = xegpu::SliceAttr::get(
982 this->getContext(), getParent(),
983 DenseI64ArrayAttr::get(this->getContext(), sliceDims));
984
985 return sliceWithoutDims;
986}
987
988// Helper function to adjust dimensions from sliced space to parent space
989// say we have a parent shape of rank 4, and slice dims [1,3], so the sliced
990// shape is of rank 2, if we want to set unit dim [0] in sliced space, it maps
991// to dim [0] in parent space; if we want to set unit dim [1] in sliced space,
992// it maps to dim [2] in parent space.
993static SmallVector<int64_t>
995 ArrayRef<int64_t> sliceDims) {
996 // Rather than recovering the exact parent rank, we compute a safe upper
997 // bound so that dimsToMap can be adjusted safely. This upper bound is
998 // defined as max(dimsToMap, sliceDims) + 1 + sliceDims.size().
999 int64_t maxDim = -1;
1000 maxDim =
1001 std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
1002 maxDim =
1003 std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
1004 int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
1005
1006 // get remaining dims in parent space after applying slicing with parent's
1007 // slice Dims
1008 llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
1009 sliceDims.end());
1010 SmallVector<int64_t> remainingDims;
1011 for (int64_t i = 0; i < parentSpaceRank; ++i) {
1012 if (!slicedDimsSet.contains(i))
1013 remainingDims.push_back(i);
1014 }
1015
1016 // Map unit dims from sliced space to parent space
1017 SmallVector<int64_t> adjustUnitDims;
1018 for (auto dim : dimsToMap) {
1019 int64_t mappedDim = remainingDims[dim];
1020 adjustUnitDims.push_back(mappedDim);
1021 }
1022
1023 return adjustUnitDims;
1024}
1025
1026// set the layout for unit dims: sg_data, inst_data and lane_data to 1
1027DistributeLayoutAttr
1028SliceAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
1029 DistributeLayoutAttr parentLayout = getParent();
1030
1031 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1032
1033 SmallVector<int64_t> adjustUnitDims =
1034 mapSlicedDimsToParentSpace(unitDims, sliceDims);
1035
1036 return SliceAttr::get(getContext(),
1037 parentLayout.setUnitDimData(adjustUnitDims), getDims());
1038}
1039
// Set the layout fields (sg_layout and lane_layout) to 1 for the specified
// unit dims. The dims are given in sliced space and are remapped to the
// parent layout's space before being applied.
DistributeLayoutAttr
SliceAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
  DistributeLayoutAttr parentLayout = getParent();

  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();

  SmallVector<int64_t> adjustUnitDims =
      mapSlicedDimsToParentSpace(unitDims, sliceDims);

  return SliceAttr::get(
      getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
}
1053
1054// Derive a new layout with sg_data, inst_data and lane_data set to the
1055// specified values for the given dimension
1056DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
1057 int64_t instData, int64_t laneData) {
1058 ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
1059 auto parent = getParent();
1060
1061 SmallVector<int64_t> dimSet;
1062 dimSet.push_back(dim);
1063 SmallVector<int64_t> adjustDims =
1064 mapSlicedDimsToParentSpace(dimSet, sliceDims);
1065 return SliceAttr::get(
1066 getContext(),
1067 parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
1068}
1069
1070// Derive a new layout by removing dimensions. `dimGroup` specifies a group of
1071// dimensions to be removed in the derived layout.
1072//
1073// Example: drop the 2nd dimension from a rank-3 sliced view.
1074//
1075// Suppose:
1076// xegpu.layout = slice<layout<[V0, V1, V2, V3, V4]>, [1, 3]>
1077//
1078// The slice removes parent dims [1, 3], so the sliced-space dims map to
1079// parent dims [V0, V2, V4].
1080//
1081// If we drop sliced-space dim 1 (the 2nd dim), that corresponds to dropping
1082// parent dim 2, result in parent layout [V0, V1, V3, V4] after dropping.
1083// After parent dim 2 is removed, sliced dims [1, 3] must be reindexed to [1,
1084// 2].
1085//
1086// Result:
1087// xegpu.layout = slice<layout<[0, 1, 3, 4]>, [1, 2]>
1088DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
1089 // Map the sliced dims from parent space to collapsed space
1090 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1091 SmallVector<int64_t> dimsInParentSpace =
1092 mapSlicedDimsToParentSpace(dimGroup, sliceDims);
1093
1094 auto droppedParent = getParent().dropDims(dimsInParentSpace);
1095
1096 // Adjust the sliced dims after dropping dims in parent space. For example, if
1097 // we drop dim 2 in parent space, the dims after dim 2 will all be shifted by
1098 // 1, so sliced dim 3 will be adjusted to 2.
1099 SmallVector<int64_t> newSliceDims;
1100 for (int64_t d : sliceDims) {
1101 int64_t offset =
1102 llvm::count_if(dimsInParentSpace, [&](int64_t s) { return s < d; });
1103 newSliceDims.push_back(d - offset);
1104 }
1105
1106 return SliceAttr::get(getContext(), droppedParent,
1107 DenseI64ArrayAttr::get(getContext(), newSliceDims));
1108}
1109
1110// Derive a new layout by collapsing dimensions.
1111// `dimGroup` specifies a group of adjacent dimensions
1112// that are collapsed into a single dimension in the derived layout.
1113DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
1114
1115 // Map the sliced dims from parent space to collapsed space
1116 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1117 assert("expect sliceDims not being collapsed" &&
1118 llvm::none_of(dimGroup, [&](int64_t dim) {
1119 return llvm::is_contained(sliceDims, dim);
1120 }));
1121 SmallVector<int64_t> dimsInParentSpace =
1122 mapSlicedDimsToParentSpace(dimGroup, sliceDims);
1123
1124 auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
1125 return SliceAttr::get(getContext(), collapsedParent,
1126 DenseI64ArrayAttr::get(getContext(), sliceDims));
1127}
1128
1130 ArrayRef<int64_t> permutation) {
1131 SmallVector<int64_t> sortedSliceDims = llvm::to_vector(sliceDims);
1132 llvm::sort(sortedSliceDims);
1133
1134 for (size_t i = 1; i < sortedSliceDims.size(); ++i) {
1135 assert((sortedSliceDims[i] == sortedSliceDims[i - 1] + 1) &&
1136 "slice dims non consecutive, cannot be transposed");
1137 }
1138
1139 SmallVector<int64_t> permForParent;
1140 if (sortedSliceDims.front() == 0) {
1141 // Example: sliceDims.size() = 2, permutation= {1, 0}
1142 // result: {3, 2, 1, 0}.
1143 for (int64_t dim : permutation)
1144 permForParent.push_back(dim + sortedSliceDims.size());
1145 for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1146 permForParent.push_back(i);
1147 } else {
1148 // Example: sliceDims.size() = 2, permutation = {0, 1}
1149 // result: {3, 2, 0, 1}.
1150 for (int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1151 permForParent.push_back(i + permutation.size());
1152 for (int64_t dim : permutation)
1153 permForParent.push_back(dim);
1154 }
1155 return permForParent;
1156}
1157
1158// Derive a new layout by transpose the layout using `permutation`.
1159DistributeLayoutAttr SliceAttr::transposeDims(ArrayRef<int64_t> permutation) {
1160 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1161 DistributeLayoutAttr parent = getParent();
1162 SmallVector<int64_t> permForParent =
1163 getPermForParentLayout(sliceDims, permutation);
1164 auto transposedParent = parent.transposeDims(permForParent);
1165 return SliceAttr::get(getContext(), transposedParent,
1166 DenseI64ArrayAttr::get(getContext(), sliceDims));
1167}
1168
1169/// Check if this layout is a transpose of another layout.
1170bool SliceAttr::isTransposeOf(const xegpu::DistributeLayoutAttr &other,
1171 ArrayRef<int64_t> perm,
1172 const xegpu::LayoutKind kind) {
1173 // other must be a SliceAttr with the same slice dims.
1174 auto otherSlice = dyn_cast<xegpu::SliceAttr>(other);
1175 if (!otherSlice || getDims() != otherSlice.getDims())
1176 return false;
1177 // check whether the parent layout is transpose of each other.
1178 SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
1179 DistributeLayoutAttr parent = getParent();
1180 SmallVector<int64_t> permForParent = getPermForParentLayout(sliceDims, perm);
1181 auto otherParent = otherSlice.getParent();
1182 return parent.isTransposeOf(otherParent, permForParent, kind);
1183}
1184
1185//===----------------------------------------------------------------------===//
1186// XeGPU_RangeAttr
1187//===----------------------------------------------------------------------===//
1188
1189LogicalResult
1190RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
1191 IntegerAttr startOfRange, IntegerAttr endOfRange) {
1192 if (startOfRange.getInt() >= endOfRange.getInt())
1193 return emitError() << "'end' : " << endOfRange.getInt()
1194 << " must be greater than 'start' : "
1195 << startOfRange.getInt();
1196
1197 return success();
1198}
1199
1200//===----------------------------------------------------------------------===//
1201// XeGPU_TensorDescType
1202//===----------------------------------------------------------------------===//
1203
/// Parses a TensorDescType of the form `<shape x elemType[, attrs]>`. The
/// optional trailing attributes may include a DistributeLayoutAttr and/or a
/// BlockTensorDescAttr encoding, in any order; any other attribute aborts the
/// parse. Missing encoding defaults to BlockTensorDescAttr, missing layout to
/// a null attribute.
mlir::Type TensorDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // parse optional attributes
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<DistributeLayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    // Unrecognized attribute kind or parse failure: abort without a type.
    return {};
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  // getChecked runs the TensorDescType verifier and emits a diagnostic at
  // the type's name location on failure.
  MLIRContext *ctxt = parser.getContext();
  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
      layout.value_or(mlir::Attribute()));
}
1253
1254void TensorDescType::print(AsmPrinter &printer) const {
1255 printer << "<";
1256
1257 auto shape = getShape();
1258 for (int64_t dim : shape) {
1259 if (mlir::ShapedType::isDynamic(dim))
1260 printer << '?';
1261 else
1262 printer << dim;
1263 printer << 'x';
1264 }
1265
1266 printer << getElementType();
1267
1268 auto encoding = getEncoding();
1269 auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
1270 if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
1271 printer << ", " << encoding;
1272
1273 if (auto layout = getLayout())
1274 printer << ", " << layout;
1275
1276 printer << ">";
1277}
1278
1279TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
1280 mlir::Type elementType, int array_length,
1281 bool boundary_check,
1282 MemorySpace memory_space,
1283 mlir::Attribute layout) {
1284 auto *context = elementType.getContext();
1285 auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
1286 boundary_check);
1287 return Base::get(context, shape, elementType, attr, layout);
1288}
1289
1290LogicalResult
1291TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
1292 llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
1293 mlir::Attribute encoding, mlir::Attribute layout) {
1294 size_t rank = shape.size();
1295
1296 if (rank == 0)
1297 return emitError() << "expected non-zero rank tensor";
1298
1299 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
1300 if (blockAttr) {
1301 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
1302 if (rank > 1 && memorySpaceAttr &&
1303 memorySpaceAttr.getValue() == MemorySpace::SLM)
1304 return emitError() << "SLM is only supported for 1D block tensor";
1305 }
1306
1307 if (!elementType.isIntOrFloat())
1308 return emitError() << "unsupported element type " << elementType
1309 << ": expected integer or float";
1310
1311 if (auto layoutAttr =
1312 mlir::dyn_cast_if_present<DistributeLayoutAttr>(layout)) {
1313 if (rank != (size_t)layoutAttr.getRank())
1314 return emitError() << "expected layout rank to match tensor rank";
1315
1316 if (!layoutAttr.isDistributable(SmallVector<int64_t>(shape))) {
1317 std::string shapeStr;
1318 llvm::raw_string_ostream stream(shapeStr);
1319 llvm::interleaveComma(shape, stream);
1320 return emitError() << "cannot distribute [" << shapeStr << "] using "
1321 << layoutAttr;
1322 }
1323 }
1324
1325 return success();
1326}
1327
1328//===----------------------------------------------------------------------===//
1329// XeGPU_MemDescType
1330//===----------------------------------------------------------------------===//
/// Parses a MemDescType of the form `<shape x elemType[, memLayout]>`.
/// The shape must be static (dynamic dims are disallowed); the optional
/// trailing attribute is a MemLayoutAttr.
mlir::Type MemDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<MemLayoutAttr> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  // allowDynamic=false: mem-desc shapes must be fully static.
  if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // parse optional attributes
  if (mlir::succeeded(parser.parseOptionalComma())) {
    MemLayoutAttr attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::failed(res))
      return {};
    layout = attr;
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  // getChecked runs the type verifier and reports failures at the type's
  // name location; a missing layout defaults to a null MemLayoutAttr.
  MLIRContext *ctxt = parser.getContext();
  return MemDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, layout.value_or(MemLayoutAttr()));
}
1370
1371void MemDescType::print(AsmPrinter &printer) const {
1372 printer << "<";
1373
1374 printer.printDimensionList(getShape());
1375 printer << 'x';
1376 printer << getElementType();
1377
1378 if (auto layout = getMemLayout())
1379 printer << ", " << layout;
1380
1381 printer << ">";
1382}
1383
1384//===----------------------------------------------------------------------===//
1385// XeGPU_MemDescType
1386//===----------------------------------------------------------------------===//
1387
1388Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
1389
1390 auto *context = parser.getContext();
1391 llvm::SMLoc loc = parser.getCurrentLocation();
1392
1393 llvm::SmallDenseSet<StringRef> seenKeys;
1394 SmallVector<NamedAttribute> attributes;
1395
1396 auto parseElt = [&]() -> ParseResult {
1397 StringRef nameId;
1398 if (failed(parser.parseKeyword(&nameId)))
1399 return parser.emitError(loc, "expected valid attribute name");
1400
1401 if (!seenKeys.insert(nameId).second)
1402 return parser.emitError(loc, "duplicate key '")
1403 << nameId << " in mem layout attribute";
1404
1405 if (failed(parser.parseEqual()))
1406 return failure();
1407
1408 Attribute attr;
1409 if (failed(parser.parseAttribute(attr)))
1410 return failure();
1411 attributes.emplace_back(nameId, attr);
1412 return success();
1413 };
1414
1415 // Parse literal '<'
1416 if (parser.parseLess())
1417 return {};
1418
1419 if (failed(parser.parseCommaSeparatedList(parseElt)))
1420 return {};
1421
1422 // Parse literal '>'
1423 if (parser.parseGreater())
1424 return {};
1425
1426 return parser.getChecked<MemLayoutAttr>(
1427 loc, context, DictionaryAttr::get(context, attributes));
1428}
1429
1430void MemLayoutAttr::print(AsmPrinter &printer) const {
1431 printer << "<";
1432 ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
1433 for (size_t i = 0; i < attrs.size(); i++) {
1434 printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
1435 if (i < attrs.size() - 1)
1436 printer << ", ";
1437 }
1438 printer << ">";
1439}
1440// a helper utility to perform binary operation on OpFoldResult.
1441// If both a and b are attributes, it will simply return the result.
1442// Otherwise, the corresponding arith op will be generated, and an
1443// contant op will be created if one of them is an attribute.
1444template <typename ArithOp>
1446 OpBuilder &builder) {
1447 auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
1448 auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
1449 return ArithOp::create(builder, loc, aVal, bVal).getResult();
1450}
1451
// a helper utility to perform division operation on OpFoldResult and int64_t.
// NOTE: these macros capture `builder` and `loc` from the enclosing scope.
#define div(a, b)                                                              \
  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)

// a helper utility to perform remainder operation on OpFoldResult and int64_t.
#define rem(a, b)                                                              \
  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)

// a helper utility to perform multiply operation on OpFoldResult and int64_t.
#define mul(a, b)                                                              \
  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)

// a helper utility to perform addition operation on two OpFoldResult.
#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
1466
1467// block the given offsets according to the block shape
1468// say the original offset is [y, x], and the block shape is [By, Bx],
1469// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
1471 ArrayRef<OpFoldResult> offsets,
1472 ArrayRef<int64_t> blockShape) {
1473
1474 assert(offsets.size() == blockShape.size() &&
1475 "offsets and blockShape must have the same size");
1476 SmallVector<OpFoldResult> blockedOffsets;
1477 SmallVector<OpFoldResult> divs, rems;
1478
1479 for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
1480 divs.push_back(div(offset, block));
1481 rems.push_back(rem(offset, block));
1482 }
1483 blockedOffsets.append(divs.begin(), divs.end());
1484 blockedOffsets.append(rems.begin(), rems.end());
1485
1486 return blockedOffsets;
1487}
1488
// Get strides as vector of integer for MemDesc.
// Returns the strides of the blocked layout as one vector of 2*rank entries:
// the first `rank` entries are the strides of the outer (block-index) dims,
// the last `rank` entries are the strides within a single inner block.
SmallVector<int64_t> MemDescType::getStrideShape() {

  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());

  // Strides of the (unblocked) matrix, read from the stride attribute.
  ArrayAttr strideAttr = getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }

  SmallVector<int64_t> innerBlkShape = getBlockShape();

  // get perm from FCD to LCD
  // perm[i] = the dim with i-th smallest stride
  SmallVector<int, 4> perm =
      llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
  llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });

  assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");

  // Strides inside one block: contiguous in the same dim order as the
  // matrix strides (innermost dim has stride 1).
  SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
  innerBlkStride[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    innerBlkStride[perm[i]] =
        innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];

  // compute the original matrix shape using the stride info
  // and compute the number of blocks in each dimension
  // The shape of highest dim can't be derived from stride info,
  // but doesn't impact the stride computation for blocked layout.
  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
    BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
  }

  // Total number of elements in one inner block; a whole block is the unit
  // of the outer strides.
  int64_t innerBlkSize = 1;
  for (auto s : innerBlkShape)
    innerBlkSize *= s;

  // Outer (block-index) strides, counted in elements.
  SmallVector<int64_t> outerBlkStride(matrixShape.size());
  outerBlkStride[perm[0]] = innerBlkSize;
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    outerBlkStride[perm[i + 1]] =
        outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
  }

  // combine the inner and outer strides
  SmallVector<int64_t> blockedStrides;
  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());

  return blockedStrides;
}
1545
1546// Calculate the linear offset using the blocked offsets and stride
1547Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
1548 ArrayRef<OpFoldResult> offsets) {
1549
1550 SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
1551 SmallVector<int64_t> blockShape = getBlockShape();
1552 SmallVector<int64_t> strides = getStrideShape();
1553 SmallVector<OpFoldResult> blockedOffsets;
1554
1555 // blockshape equal to matrixshape means no blocking
1556 if (llvm::equal(blockShape, matrixShape)) {
1557 // remove the outer dims from strides
1558 strides.erase(strides.begin(), strides.begin() + matrixShape.size());
1559 } else {
1560 assert(offsets.size() == blockShape.size() &&
1561 "offsets and blockShape must have the same size");
1562 // say the original offset is [y, x], and the block shape is [By, Bx],
1563 // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
1564
1565 SmallVector<OpFoldResult> divs, rems;
1566
1567 for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
1568 divs.push_back(div(offset, block));
1569 rems.push_back(rem(offset, block));
1570 }
1571 blockedOffsets.append(divs.begin(), divs.end());
1572 blockedOffsets.append(rems.begin(), rems.end());
1573 offsets = blockedOffsets;
1574 }
1575
1576 // Start with initial value as matrix descriptor's base offset.
1577 Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
1578 for (size_t i = 0; i < offsets.size(); ++i) {
1579 OpFoldResult mulResult = mul(offsets[i], strides[i]);
1580 Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
1581 linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
1582 }
1583
1584 return linearOffset;
1585}
1586
1587} // namespace xegpu
1588} // namespace mlir
1589
1590#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
1591#define GET_ATTRDEF_CLASSES
1592#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
1593#define GET_TYPEDEF_CLASSES
1594#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
return success()
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
#define mul(a, b)
#define div(a, b)
#define rem(a, b)
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printDimensionList(ArrayRef< int64_t > shape)
Attributes are known-constant values of operations.
Definition Attributes.h:25
static BoolAttr get(MLIRContext *context, bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:209
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
This class represents a single result from folding an operation.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Specialization of arith.constant op that returns an integer of index type.
Definition Arith.h:113
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
static SmallVector< SmallVector< int64_t > > genStaticCoordinates(llvm::ArrayRef< int64_t > canonicalIds, llvm::ArrayRef< int64_t > layout, llvm::ArrayRef< int64_t > subShape, llvm::ArrayRef< int64_t > shape)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
Definition XeGPU.h:32
static SmallVector< int64_t > mapSlicedDimsToParentSpace(const SmallVector< int64_t > &dimsToMap, ArrayRef< int64_t > sliceDims)
SmallVector< OpFoldResult > getBlockedOffsets(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > offsets, ArrayRef< int64_t > blockShape)
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, OpBuilder &builder)
static SmallVector< SmallVector< Value > > genCoordinates(OpBuilder &builder, Location loc, SmallVector< Value > delinearizedId, ArrayRef< int64_t > subShapesLayout, ArrayRef< int64_t > subShape, ArrayRef< int64_t > srcShape)
SmallVector< int64_t > getPermForParentLayout(ArrayRef< int64_t > sliceDims, ArrayRef< int64_t > permutation)
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition Utils.cpp:114
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.