//===- XeGPUDialect.cpp - MLIR XeGPU dialect implementation -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

using std::optional;

namespace mlir {
namespace xegpu {

void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}

#define GET_OP_INTERFACE_CLASSES
#include "mlir/Dialect/XeGPU/IR/XeGPUOpInterface.cpp.inc"

// A `srcShape` consists of N distribution units, each being `subShapesLayout`
// x `subShape`. A `delinearizedId` is used to identify a particular
// `subShape` within each distribution unit.
// Example:
//   WG data is 128x256 and SG data is 16x32 in a 4x2 layout. This gives a
//   distribution unit of shape 64x64, so we have 2x4 such distribution units.
//   `delinearizedId` identifies a subgroup's 16x32 tile within each
//   distribution unit.
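// Continuing the example above with hand-computed values (illustrative
// only): for delinearizedId = (1, 1), the local offset within a distribution
// unit is (1*16, 1*32) = (16, 32), so the eight generated coordinates are
// (16,32), (16,96), (16,160), (16,224), (80,32), (80,96), (80,160), (80,224),
// one per 64x64 distribution unit.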
static SmallVector<SmallVector<Value>>
genCoordinates(OpBuilder &builder, Location loc,
               SmallVector<Value> delinearizedId,
               ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
               ArrayRef<int64_t> srcShape) {
  SmallVector<SmallVector<Value>> coordinates;

  // A distribution unit must be less than or equal to `srcShape`.
  SmallVector<int64_t> distUnitShape = llvm::map_to_vector(
      llvm::zip_equal(srcShape,
                      computeElementwiseMul(subShapesLayout, subShape)),
      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

  // Get the offset of `subShape` within a distribution unit.
  SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
      llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
        return builder.createOrFold<arith::MulIOp>(
            loc, std::get<0>(t),
            builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
      });

  // For each distribution unit:
  for (SmallVector<int64_t> unitOffs :
       StaticTileOffsetRange(srcShape, distUnitShape)) {
    // Get the distribution unit's offset within `srcShape`.
    SmallVector<Value> base =
        llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
          return arith::ConstantIndexOp::create(builder, loc, d);
        });
    // Calculate the `subShape` offset within `srcShape`.
    SmallVector<Value> adds =
        llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
                            [&](const auto &t) -> Value {
                              return builder.createOrFold<arith::AddIOp>(
                                  loc, std::get<0>(t), std::get<1>(t));
                            });
    // Do not go beyond `srcShape` bounds.
    SmallVector<Value> mods = llvm::map_to_vector(
        llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
          return builder.createOrFold<arith::RemUIOp>(
              loc, std::get<0>(t),
              arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
        });

    coordinates.push_back(mods);
  }
  return coordinates;
}

// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
                                         xegpu::DistributeLayoutAttr attr) {
  assert(attr && "Layout attribute is missing.");

  // Checks whether the given shape can be evenly distributed using the
  // specified layout and data attributes. If successful, it returns the work
  // size for each compute unit; otherwise, it returns `std::nullopt`. The work
  // size per compute unit is calculated as follows:
  //   - If `data` is null: newShape[i] = shape[i] / layout[i]
  //   - If `data` is not null: newShape[i] = data[i]
  // When round-robin distribution (`rr`) is enabled, `shape[i]` can be
  // smaller than `layout[i] * data[i]`, allowing multiple compute units to
  // share the data.
  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
                           llvm::ArrayRef<int64_t> layout,
                           llvm::ArrayRef<int64_t> data,
                           bool rr = true) -> optional<SmallVector<int64_t>> {
    SmallVector<int64_t> newShape(shape);
    if (layout.size()) {
      if (layout.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(shape, layout);
      if (ratio.has_value()) {
        newShape = ratio.value();
      } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
        return std::nullopt;
      }
      // Round-robin case: continue with the original newShape.
    }

    if (data.size()) {
      if (data.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(newShape, data);
      if (!ratio.has_value() && rr)
        ratio = computeShapeRatio(data, newShape);
      if (!ratio.has_value())
        return std::nullopt;

      // If data is not null, we always return it for the next phase.
      newShape = data;
    }
    return newShape;
  };

  // Check sgLayout and sgData.
  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
                                    attr.getEffectiveSgDataAsInt());
  if (!maybeSgShape)
    return false;
  auto sgShape = maybeSgShape.value();

  // Check instData; it has neither a layout nor round-robin distribution.
  auto maybeInstShape =
      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
  if (!maybeInstShape)
    return false;
  auto instShape = maybeInstShape.value();

  // Check laneLayout and laneData.
  auto maybeLaneShape =
      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
                    attr.getEffectiveLaneDataAsInt(), false);
  return maybeLaneShape.has_value();
}
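
// Worked example (hand-computed, illustrative): for shape = [128, 256] with
// sg_layout = [4, 2] and sg_data = [16, 32], the subgroup phase yields
// [16, 32] ([128, 256] / [4, 2] = [32, 128], which divides evenly by
// [16, 32]). A lane phase with lane_layout = [1, 16] and lane_data = [16, 2]
// then succeeds, since [16, 32] / [1, 16] = [16, 2].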

//===----------------------------------------------------------------------===//
// XeGPU_BlockTensorDescAttr
//===----------------------------------------------------------------------===//
BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
                                             xegpu::MemorySpace memory_space,
                                             int array_length,
                                             bool boundary_check) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto lengthAttr =
      IntegerAttr::get(IntegerType::get(context, 64), array_length);
  auto boundaryAttr = BoolAttr::get(context, boundary_check);
  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
}

bool BlockTensorDescAttr::hasDefaultsOnly() {
  return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
         getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
}

//===----------------------------------------------------------------------===//
// XeGPU_ScatterTensorDescAttr
//===----------------------------------------------------------------------===//
ScatterTensorDescAttr
ScatterTensorDescAttr::get(mlir::MLIRContext *context,
                           xegpu::MemorySpace memory_space, int chunk_size) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto chunkSizeAttr =
      IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
  return Base::get(context, scopeAttr, chunkSizeAttr);
}

LogicalResult ScatterTensorDescAttr::verify(
    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
  int64_t chunkSize = chunk_size.getInt();
  if (chunkSize <= 0)
    return emitError() << "invalid chunk size";

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_LayoutAttr
//===----------------------------------------------------------------------===//
LogicalResult
LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {

  // Special case for store_matrix.
  if (!sg_layout && !inst_data && !lane_layout)
    return success();

  // Check that sg_layout, inst_data, and lane_layout have the same rank when
  // present.

  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError() << "expected inst_data and lane_layout to have the same "
                          "rank, got inst_data "
                       << inst_data.size() << ", lane_layout "
                       << lane_layout.size();
  }

  // sg_data is optional for the workgroup layout, but its presence requires
  // sg_layout.
  if (sg_data) {
    if (!sg_layout)
      return emitError() << "expected sg_layout being used with sg_data";
    if (sg_data.size() != sg_layout.size())
      return emitError()
             << "expected sg_data and sg_layout to have the same rank";
  }

  // lane_data is optional for the subgroup layout, but its presence requires
  // lane_layout.
  if (lane_data) {
    if (!lane_layout)
      return emitError() << "expected lane_layout being used with lane_data";
    if (lane_data.size() != lane_layout.size())
      return emitError()
             << "expected lane_data and lane_layout to have the same rank";
  }

  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";

    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";

    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}
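
// Illustrative example of a layout that passes the checks above (attribute
// syntax assumed from the XeGPU dialect conventions; all fields have rank 2):
//   #xegpu.layout<sg_layout = [4, 2], sg_data = [16, 32],
//                 lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>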

FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {

  SmallVector<int64_t> sgLayoutInt;
  if (isForWorkgroup()) {
    sgLayoutInt = getEffectiveSgLayoutAsInt();
  } else if (isForSubgroup()) {
    sgLayoutInt = getEffectiveLaneLayoutAsInt();
  } else {
    return failure();
  }

  DenseI32ArrayAttr orderAttr = getOrder();

  // Handle the order attribute.
  SmallVector<int64_t> order;
  if (orderAttr && !orderAttr.empty()) {
    order = llvm::to_vector(
        llvm::map_range(orderAttr.asArrayRef(),
                        [](int32_t idx) { return static_cast<int64_t>(idx); }));
  } else {
    // Default order: [1, 0] for 2D (row-major), [2, 1, 0] for 3D, etc.
    order = llvm::to_vector(
        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
  }

  if (order.size() != sgLayoutInt.size()) {
    return failure();
  }

  SmallVector<Value> result(sgLayoutInt.size());
  Value remaining = linearId;

  /// Process dimensions in the order they appear in the order array.
  /// The first dimension in order is the fastest-changing.
  ///
  /// Example walkthrough for linearId=22, sgLayout=[2,4,4], order=[2,1,0]:
  ///
  /// Initial: remaining=22, dimIdx = order[i], dimSize = sgLayout[dimIdx],
  /// result=[?,?,?]
  ///
  /// i=0 (process columns, dimIdx=2, dimSize=4):
  ///   result[2] = 22 % 4 = 2 (column coordinate)
  ///   remaining = 22 / 4 = 5 (5 complete groups of 4 columns processed)
  ///
  /// i=1 (process rows, dimIdx=1, dimSize=4):
  ///   result[1] = 5 % 4 = 1 (row coordinate)
  ///   remaining = 5 / 4 = 1 (1 complete group of 4 rows processed)
  ///
  /// i=2 (process layers, dimIdx=0, dimSize=2):
  ///   result[0] = 1 % 2 = 1 (layer coordinate)
  ///   (no remaining update - last iteration)
  ///
  /// Final result: [1,1,2] = Layer 1, Row 1, Column 2
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    int64_t dimSize = sgLayoutInt[dimIdx];

    Value dimSizeVal =
        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);

    /// Extract the coordinate for this dimension using a modulo operation.
    /// This gives us "how far within this dimension" we are, e.g.,
    /// linearId=22, dimSize=4: 22 % 4 = 2 (we're at position 2 within this
    /// dimension).
    result[dimIdx] =
        builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);

    /// Update remaining for the next dimension by removing what we've already
    /// processed. Division tells us "how many complete groups of this
    /// dimension we've gone through", e.g., linearId=22, dimSize=4:
    /// 22 / 4 = 5 (we've completed 5 groups of 4). Skip this for the last
    /// iteration since there's no next dimension to process.
    if (i < order.size() - 1) {
      remaining =
          builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
    }
  }
  return result;
}

/// Implements DistributeLayoutAttr::computeDistributedCoords to generate
/// instructions for computing multi-dimensional offsets when distributed by
/// LayoutAttr.
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                     Value linearId, ArrayRef<int64_t> shape) {
  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  } else {
    return failure();
  }
  if (subShape.empty()) {
    if (auto derivedShape = computeShapeRatio(shape, layout))
      subShape = derivedShape.value();
    else
      return failure();
  }

  // Delinearize the id.
  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();
  SmallVector<Value> ids = *maybeIds;

  return genCoordinates(builder, loc, ids, layout, subShape, shape);
}

bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
  if (dyn_cast<xegpu::SliceAttr>(other))
    return false;

  return *this == dyn_cast<xegpu::LayoutAttr>(other);
}

// Set the data fields (sg_data, inst_data, and lane_data) to 1 for the
// specified unit dims.
DistributeLayoutAttr
LayoutAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
  auto sgDataOpt = getSgData();
  auto instDataOpt = getInstData();
  auto laneDataOpt = getLaneData();

  SmallVector<int32_t> sgData;
  SmallVector<int32_t> instData;
  SmallVector<int32_t> laneData;

  if (sgDataOpt) {
    sgData = llvm::to_vector(sgDataOpt.asArrayRef());
  }
  if (instDataOpt) {
    instData = llvm::to_vector(instDataOpt.asArrayRef());
  }
  if (laneDataOpt) {
    laneData = llvm::to_vector(laneDataOpt.asArrayRef());
  }

  for (auto dim : unitDims) {
    if (dim < static_cast<int64_t>(sgData.size()))
      sgData[dim] = 1;
    if (dim < static_cast<int64_t>(instData.size()))
      instData[dim] = 1;
    if (dim < static_cast<int64_t>(laneData.size()))
      laneData[dim] = 1;
  }

  return LayoutAttr::get(
      getContext(), getSgLayout(),
      sgData.empty() ? DenseI32ArrayAttr()
                     : DenseI32ArrayAttr::get(getContext(), sgData),
      instData.empty() ? DenseI32ArrayAttr()
                       : DenseI32ArrayAttr::get(getContext(), instData),
      getLaneLayout(),
      laneData.empty() ? DenseI32ArrayAttr()
                       : DenseI32ArrayAttr::get(getContext(), laneData),
      getOrder());
}

// Set the layout fields (sg_layout and lane_layout) to 1 for the specified
// unit dims.
DistributeLayoutAttr
LayoutAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
  auto sgLayoutOpt = getSgLayout();
  auto laneLayoutOpt = getLaneLayout();

  SmallVector<int32_t> sgLayout;
  SmallVector<int32_t> laneLayout;

  if (sgLayoutOpt) {
    sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
  }
  if (laneLayoutOpt) {
    laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
  }

  for (auto dim : unitDims) {
    if (dim < static_cast<int64_t>(sgLayout.size()))
      sgLayout[dim] = 1;
    if (dim < static_cast<int64_t>(laneLayout.size()))
      laneLayout[dim] = 1;
  }

  return LayoutAttr::get(
      getContext(),
      sgLayout.empty() ? DenseI32ArrayAttr()
                       : DenseI32ArrayAttr::get(getContext(), sgLayout),
      getSgData(), getInstData(),
      laneLayout.empty() ? DenseI32ArrayAttr()
                         : DenseI32ArrayAttr::get(getContext(), laneLayout),
      getLaneData(), getOrder());
}

//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {

  if (!dims)
    return emitError() << "expected dims attribute";

  // Check that every element in dims is non-negative and unique.
  llvm::SmallDenseSet<int64_t> seen;
  for (int64_t dim : dims.asArrayRef()) {
    if (dim < 0)
      return emitError() << "invalid dim (" << dim << ") in slice attribute.";
    if (!seen.insert(dim).second)
      return emitError() << "repeated dim (" << dim << ") in slice attribute.";
  }
  return success();
}

SliceAttr SliceAttr::flatten() const {
  xegpu::DistributeLayoutAttr parent = getParent();
  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});

  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
    parent = sliceAttr.getParent();
    slicedDims.push_back(sliceAttr.getDims());
  }

  auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
  SmallVector<int64_t> indices =
      llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));

  // Get the remaining (flattened) dims by applying the slice ops with all
  // slicedDims.
  SmallVector<int64_t> remainingDims(indices);
  for (auto dim : llvm::reverse(slicedDims))
    remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
                                        dim.asArrayRef());

  // Get the flattened sliced dims by applying the slice ops with the
  // remaining dims.
  SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
      llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));

  return xegpu::SliceAttr::get(
      getContext(), layoutAttr,
      DenseI64ArrayAttr::get(getContext(), flattenedDims));
}
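
// Worked example (hand-computed, slice syntax assumed): for a rank-3 parent
// layout L, flattening #xegpu.slice<#xegpu.slice<L, dims = [1]>, dims = [0]>
// proceeds as follows: slicing [0, 1, 2] by [1] leaves [0, 2], slicing
// [0, 2] by [0] leaves [2], so the flattened result is
// #xegpu.slice<L, dims = [0, 1]>.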

FailureOr<SmallVector<Value>>
SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
  SliceAttr attr = flatten();
  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
  return parent.delinearizeId(builder, loc, linearId);
}

// Implements DistributeLayoutAttr::computeDistributedCoords to generate
// instructions for computing multi-dimensional offsets when distributed by
// SliceAttr.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                    Value linearId, ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
  if (!isForWorkgroup())
    return failure();

  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  } else {
    return failure();
  }

  if (subShape.empty()) {
    if (auto derivedShape = computeShapeRatio(shape, layout))
      subShape = derivedShape.value();
    else
      return failure();
  }

  // Delinearize the id.
  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();

  // The effective sgIds for offset computation correspond to the dims that
  // are not sliced.
  ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
  SmallVector<Value> sgIds =
      XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);

  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
}

bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
  auto flattenedThis = flatten();
  // If other is a LayoutAttr, just compare directly with the parent of
  // flattenedThis.
  if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
    return flattenedThis.getParent() == otherLayout;
  // If other is a SliceAttr, flatten it first before comparing.
  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
  // Both must have a common parent LayoutAttr.
  if (flattenedThis.getParent() != flattenedOther.getParent())
    return false;
  // flattenedOther's sliced dims must be a subset of flattenedThis's sliced
  // dims.
  llvm::SmallDenseSet<int64_t> thisDims(
      flattenedThis.getDims().asArrayRef().begin(),
      flattenedThis.getDims().asArrayRef().end());
  return llvm::all_of(flattenedOther.getDims().asArrayRef(),
                      [&](int64_t dim) { return thisDims.contains(dim); });
}
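
// Illustrative example (slice syntax assumed): for a layout L,
// #xegpu.slice<L, dims = [0, 1]>.isSliceOf(#xegpu.slice<L, dims = [0]>)
// returns true, since [0] is a subset of [0, 1]; with the operands swapped it
// returns false, since [0, 1] is not a subset of [0].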

xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
  if (sliceDimsToDrop.empty())
    return *this;
  SmallVector<int64_t> sliceDims{getDims().asArrayRef()};
  for (auto dim : sliceDimsToDrop) {
    auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), dim);
    assert(foundIt != sliceDims.end() &&
           "Expected to find the specified reduction dim in slice dims");
    sliceDims.erase(foundIt);
  }

  auto sliceWithoutDims = xegpu::SliceAttr::get(
      this->getContext(), getParent(),
      DenseI64ArrayAttr::get(this->getContext(), sliceDims));

  return sliceWithoutDims;
}

bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
  if (dyn_cast<xegpu::LayoutAttr>(other))
    return false;

  auto flattenedThis = flatten();
  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();

  return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
          (flattenedThis.getDims() == flattenedOther.getDims()));
}

// Helper function to adjust dimensions from the sliced space to the parent
// space. Say we have a parent shape of rank 4 and slice dims [1, 3], so the
// sliced shape is of rank 2. If we want to set unit dim [0] in the sliced
// space, it maps to dim [0] in the parent space; if we want to set unit dim
// [1] in the sliced space, it maps to dim [2] in the parent space.
static SetVector<int64_t>
mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
                           ArrayRef<int64_t> sliceDims) {
  // Rather than recovering the exact parent rank, we compute a safe upper
  // bound so that dimsToMap can be adjusted safely. This upper bound is
  // max(max(dimsToMap), max(sliceDims)) + 1 + sliceDims.size().
  int64_t maxDim = -1;
  maxDim =
      std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
  maxDim =
      std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
  int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;

  // Get the remaining dims in the parent space after applying slicing with
  // the parent's slice dims.
  llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
                                             sliceDims.end());
  SmallVector<int64_t> remainingDims;
  for (int64_t i = 0; i < parentSpaceRank; ++i) {
    if (!slicedDimsSet.contains(i))
      remainingDims.push_back(i);
  }

  // Map unit dims from the sliced space to the parent space.
  SetVector<int64_t> adjustUnitDims;
  for (auto dim : dimsToMap) {
    int64_t mappedDim = remainingDims[dim];
    adjustUnitDims.insert(mappedDim);
  }

  return adjustUnitDims;
}
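
// Worked example (hand-computed): for dimsToMap = {1} and sliceDims = [1, 3],
// the rank upper bound is max(3, 1) + 1 + 2 = 6, the remaining dims are
// [0, 2, 4, 5], and dim 1 in the sliced space maps to remainingDims[1] = 2 in
// the parent space, matching the example in the header comment above.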

// Set the data fields (sg_data, inst_data, and lane_data) to 1 for the
// specified unit dims.
DistributeLayoutAttr
SliceAttr::setUnitDimData(SetVector<int64_t> unitDims) const {
  DistributeLayoutAttr parentLayout = getParent();

  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();

  SetVector<int64_t> adjustUnitDims =
      mapSlicedDimsToParentSpace(unitDims, sliceDims);

  return SliceAttr::get(getContext(),
                        parentLayout.setUnitDimData(adjustUnitDims), getDims());
}

// Set the layout fields (sg_layout and lane_layout) to 1 for the specified
// unit dims.
DistributeLayoutAttr
SliceAttr::setUnitDimLayout(SetVector<int64_t> unitDims) const {
  DistributeLayoutAttr parentLayout = getParent();

  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();

  SetVector<int64_t> adjustUnitDims =
      mapSlicedDimsToParentSpace(unitDims, sliceDims);

  return SliceAttr::get(
      getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
}

//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//

LogicalResult
RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                  IntegerAttr startOfRange, IntegerAttr endOfRange) {
  if (startOfRange.getInt() >= endOfRange.getInt())
    return emitError() << "'end' : " << endOfRange.getInt()
                       << " must be greater than 'start' : "
                       << startOfRange.getInt();

  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//

mlir::Type TensorDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // parse optional attributes
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<LayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    return {};
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  MLIRContext *ctxt = parser.getContext();
  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
      layout.value_or(mlir::Attribute()));
}

void TensorDescType::print(AsmPrinter &printer) const {
  printer << "<";

  auto shape = getShape();
  for (int64_t dim : shape) {
    if (mlir::ShapedType::isDynamic(dim))
      printer << '?';
    else
      printer << dim;
    printer << 'x';
  }

  printer << getElementType();

  auto encoding = getEncoding();
  auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
    printer << ", " << encoding;

  if (auto layout = getLayout())
    printer << ", " << layout;

  printer << ">";
}
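
// Illustrative printed forms (attribute mnemonics assumed): a block
// tensor_desc with an all-default encoding elides the encoding and prints as
// !xegpu.tensor_desc<8x16xf32>, while a scattered one prints as, e.g.,
// !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>.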

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int array_length,
                                   bool boundary_check,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
                                       boundary_check);
  return Base::get(context, shape, elementType, attr, layout);
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int chunk_size,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
  return Base::get(context, shape, elementType, attr, layout);
}

LogicalResult
TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                       llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
                       mlir::Attribute encoding, mlir::Attribute layout) {
  size_t rank = shape.size();

  if (rank == 0)
    return emitError() << "expected non-zero rank tensor";

  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (blockAttr) {
    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
    if (rank > 1 && memorySpaceAttr &&
        memorySpaceAttr.getValue() == MemorySpace::SLM)
      return emitError() << "SLM is only supported for 1D block tensor";
  }

  if (!elementType.isIntOrFloat())
    return emitError() << "unsupported element type " << elementType
                       << ": expected integer or float";

  // For gather and scatter ops, low-precision types are packed in 32-bit
  // units.
  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
  int chunkAlignmentFactor =
      bitWidth < uArch::generalPackedFormatBitSize
          ? uArch::generalPackedFormatBitSize / bitWidth
          : 1;
  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
  if (scatterAttr) {
    int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
    if (rank == 1 && chunkSize != 1)
      return emitError() << "expected non-contiguous elements for 1D tensor";

    // If chunk size > 1, the second dimension of the tensor shape must be
    // equal to chunk size and it must be a multiple of the
    // chunkAlignmentFactor.
    if (chunkSize > 1) {
      if (shape.back() != chunkSize)
        return emitError() << "expected last dim of tensor to match chunk size";
      if (shape.back() % chunkAlignmentFactor != 0)
        return emitError() << "expected last dim of tensor to be a multiple of "
                           << chunkAlignmentFactor;
    }
  }

  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
  if (layoutAttr) {
    if (rank != (size_t)layoutAttr.getRank())
      return emitError() << "expected layout rank to match tensor rank";

    auto laneData = layoutAttr.getLaneData();
    if (scatterAttr && laneData) {
      // Validate subgroup mapping rules for scattered tensors.
      // If chunkSize > 1, the last dimension of the tensor should be
      // distributed in units divisible by chunkAlignmentFactor.
      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
        return emitError()
               << "expected last dim of lane_data to be a multiple of: "
               << chunkAlignmentFactor;
    }

    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
      std::string shapeStr;
      llvm::raw_string_ostream stream(shapeStr);
      llvm::interleaveComma(shape, stream);
      return emitError() << "cannot distribute [" << shapeStr << "] using "
                         << layoutAttr;
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_MemDescType
//===----------------------------------------------------------------------===//
mlir::Type MemDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<MemLayoutAttr> layout;

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape, false, true))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // parse optional attributes
  if (mlir::succeeded(parser.parseOptionalComma())) {
    MemLayoutAttr attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::failed(res))
      return {};
    layout = attr;
  }

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  MLIRContext *ctxt = parser.getContext();
  return MemDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, layout.value_or(MemLayoutAttr()));
}

void MemDescType::print(AsmPrinter &printer) const {
  printer << "<";

  printer.printDimensionList(getShape());
  printer << 'x';
  printer << getElementType();

  if (auto layout = getMemLayout())
    printer << ", " << layout;

  printer << ">";
}

//===----------------------------------------------------------------------===//
// XeGPU_MemLayoutAttr
//===----------------------------------------------------------------------===//

Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {

  auto context = parser.getContext();
  llvm::SMLoc loc = parser.getCurrentLocation();

  llvm::SmallDenseSet<StringRef> seenKeys;
  SmallVector<NamedAttribute> attributes;

  auto parseElt = [&]() -> ParseResult {
    StringRef nameId;
    if (failed(parser.parseKeyword(&nameId)))
      return parser.emitError(loc, "expected valid attribute name");

    if (!seenKeys.insert(nameId).second)
      return parser.emitError(loc, "duplicate key '")
             << nameId << "' in mem layout attribute";

    if (failed(parser.parseEqual()))
      return failure();

    Attribute attr;
    if (failed(parser.parseAttribute(attr)))
      return failure();
    attributes.emplace_back(nameId, attr);
    return success();
  };

  // Parse literal '<'
  if (parser.parseLess())
    return {};

  if (failed(parser.parseCommaSeparatedList(parseElt)))
    return {};

  // Parse literal '>'
  if (parser.parseGreater())
    return {};

  return parser.getChecked<MemLayoutAttr>(
      loc, context, DictionaryAttr::get(context, attributes));
}

void MemLayoutAttr::print(AsmPrinter &printer) const {
  printer << "<";
  ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
  for (size_t i = 0; i < attrs.size(); i++) {
    printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
    if (i < attrs.size() - 1)
      printer << ", ";
  }
  printer << ">";
}
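
// Illustrative textual form (keys assumed from the getStrideAttr and
// getBlockShape accessors used below):
//   #xegpu.mem_layout<stride = [1, 16], block = [8, 16]>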

// A helper utility to perform a binary operation on OpFoldResults: constant
// index ops are materialized for attribute operands, and the corresponding
// arith op is created on the resulting values.
template <typename ArithOp>
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
                      OpBuilder &builder) {
  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
  return ArithOp::create(builder, loc, aVal, bVal).getResult();
}

// a helper utility to perform division operation on OpFoldResult and int64_t.
#define div(a, b)                                                              \
  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)

// a helper utility to perform remainder operation on OpFoldResult and int64_t.
#define rem(a, b)                                                              \
  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)

// a helper utility to perform multiply operation on OpFoldResult and int64_t.
#define mul(a, b)                                                              \
  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)

// a helper utility to perform addition operation on two OpFoldResults.
#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)

// Block the given offsets according to the block shape. Say the original
// offset is [y, x] and the block shape is [By, Bx]; then the blocked offset
// is [y/By, x/Bx, y%By, x%Bx].
SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
                                            ArrayRef<OpFoldResult> offsets,
                                            ArrayRef<int64_t> blockShape) {

  assert(offsets.size() == blockShape.size() &&
         "offsets and blockShape must have the same size");
  SmallVector<OpFoldResult> blockedOffsets;
  SmallVector<OpFoldResult> divs, rems;

  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
    divs.push_back(div(offset, block));
    rems.push_back(rem(offset, block));
  }
  blockedOffsets.append(divs.begin(), divs.end());
  blockedOffsets.append(rems.begin(), rems.end());

  return blockedOffsets;
}
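
// Worked example (hand-computed): offsets = [10, 7] with blockShape = [8, 16]
// yields [10/8, 7/16, 10%8, 7%16] = [1, 0, 2, 7].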

// Get the strides as a vector of integers for MemDesc.
SmallVector<int64_t> MemDescType::getStrideShape() {

  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());

  ArrayAttr strideAttr = getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }

  SmallVector<int64_t> innerBlkShape = getBlockShape();

  // Get the permutation from the fastest-changing to the slowest-changing
  // dim: perm[i] is the dim with the i-th smallest stride.
  SmallVector<int, 4> perm =
      llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
  llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });

  assert(strides[perm[0]] == 1 && "innermost dim must have stride 1");

  SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
  innerBlkStride[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    innerBlkStride[perm[i]] =
        innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];

  // Compute the original matrix shape using the stride info, along with the
  // number of blocks in each dimension. The shape of the highest dim cannot
  // be derived from the stride info, but this does not affect the stride
  // computation for a blocked layout.
  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
    BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
  }

  int64_t innerBlkSize = 1;
  for (auto s : innerBlkShape)
    innerBlkSize *= s;

  SmallVector<int64_t> outerBlkStride(matrixShape.size());
  outerBlkStride[perm[0]] = innerBlkSize;
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    outerBlkStride[perm[i + 1]] =
        outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
  }

  // combine the inner and outer strides
  SmallVector<int64_t> blockedStrides;
  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());

  return blockedStrides;
}
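
// Worked example (hand-computed): for a 32x32 row-major MemDesc with
// strides [32, 1] and block shape [8, 16], perm = [1, 0] and the inner block
// strides are [16, 1]. There are 2 blocks along the columns
// (BlkShapeOrig[1] = 32/16) and the inner block size is 8*16 = 128, so the
// combined result is blockedStrides = [256, 128, 16, 1].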

// Calculate the linear offset using the blocked offsets and strides.
Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
                                    ArrayRef<OpFoldResult> offsets) {

  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
  SmallVector<int64_t> blockShape = getBlockShape();
  SmallVector<int64_t> strides = getStrideShape();
  SmallVector<OpFoldResult> blockedOffsets;

  // A block shape equal to the matrix shape means no blocking.
  if (llvm::equal(blockShape, matrixShape)) {
    // Remove the outer dims from the strides.
    strides.erase(strides.begin(), strides.begin() + matrixShape.size());
  } else {
    assert(offsets.size() == blockShape.size() &&
           "offsets and blockShape must have the same size");
    // Say the original offset is [y, x] and the block shape is [By, Bx];
    // then the blocked offset is [y/By, x/Bx, y%By, x%Bx].

    SmallVector<OpFoldResult> divs, rems;

    for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
      divs.push_back(div(offset, block));
      rems.push_back(rem(offset, block));
    }
    blockedOffsets.append(divs.begin(), divs.end());
    blockedOffsets.append(rems.begin(), rems.end());
    offsets = blockedOffsets;
  }

  // Start with the initial value as the matrix descriptor's base offset.
  Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
  for (size_t i = 0; i < offsets.size(); ++i) {
    OpFoldResult mulResult = mul(offsets[i], strides[i]);
    Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
    linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
  }

  return linearOffset;
}
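
// Worked example (hand-computed, continuing the blocked example above):
// offsets = [10, 20] with blockShape = [8, 16] become [1, 1, 2, 4], and with
// blockedStrides = [256, 128, 16, 1] the linear offset is
// 1*256 + 1*128 + 2*16 + 4*1 = 420.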

} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>