18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Debug.h"
26 void XeGPUDialect::initialize() {
28 #define GET_TYPEDEF_LIST
29 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
33 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
36 #define GET_ATTRDEF_LIST
37 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
55 llvm::zip(sgId, sizePerSg), [&](
const auto &t) ->
Value {
65 [](
const auto &t) {
return std::min(std::get<0>(t), std::get<1>(t)); });
70 llvm::map_to_vector(unitOffs, [&](int64_t d) ->
Value {
75 llvm::zip_equal(base, localOffsets), [&](
const auto &t) ->
Value {
76 return builder.
createOrFold<arith::AddIOp>(loc, std::get<0>(t),
81 llvm::zip_equal(adds, sizePerWg), [&](
const auto &t) ->
Value {
87 offsets.push_back(mods);
95 xegpu::DistributeLayoutAttr attr) {
96 assert(attr &&
"Layout attribute is missing.");
113 if (layout.size() != shape.size())
116 if (!ratio.has_value())
118 newShape = ratio.value();
122 if (data.size() != shape.size())
125 if (!ratio.has_value() && rr)
127 if (!ratio.has_value())
137 auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
138 attr.getEffectiveSgDataAsInt());
141 auto sgShape = maybeSgShape.value();
144 auto maybeInstShape =
145 tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(),
false);
148 auto instShape = maybeInstShape.value();
151 auto maybeLaneShape =
152 tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
153 attr.getEffectiveLaneDataAsInt(),
false);
154 return maybeLaneShape.has_value();
161 xegpu::MemorySpace memory_space,
163 bool boundary_check) {
168 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
171 bool BlockTensorDescAttr::hasDefaultsOnly() {
172 return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
173 getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
179 ScatterTensorDescAttr
181 xegpu::MemorySpace memory_space,
int chunk_size) {
185 return Base::get(context, scopeAttr, chunkSizeAttr);
190 MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
191 int64_t chunkSize = chunk_size.getInt();
193 return emitError() <<
"invalid chunk size";
210 if (!sg_layout && !inst_data && !lane_layout) {
212 <<
"expected at least one of sg_layout, inst_data or lane_layout";
218 if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
220 <<
"expected sg_layout and inst_data to have the same rank";
223 if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
225 <<
"expected sg_layout and lane_layout to have the same rank";
228 if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
230 <<
"expected inst_data and lane_layout to have the same rank";
237 return emitError() <<
"expected sg_layout being used with sg_data";
238 if (sg_data.size() != sg_layout.size())
240 <<
"expected sg_data and sg_layout to have the same rank";
247 return emitError() <<
"expected lane_layout being used with lane_data";
248 if (lane_data.size() != lane_layout.size())
250 <<
"expected lane_data and lane_layout to have the same rank";
254 if (!sg_layout && !lane_layout)
256 <<
"expected sg_layout/lane_layout being used with order";
258 if (sg_layout && order.size() != sg_layout.size())
260 <<
"expected order and sg_layout to have the same rank";
262 if (lane_layout && order.size() != lane_layout.size())
264 <<
"expected order and lane_layout to have the same rank";
270 FailureOr<SmallVector<Value>>
275 if (!isForWorkgroup())
279 auto hasDefaultOrder = [&]() {
284 if (!hasDefaultOrder())
285 return mlir::emitError(loc,
"order attribute is currently not supported.");
288 llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) ->
Value {
289 return builder.
createOrFold<arith::ConstantIndexOp>(loc, d);
298 FailureOr<SmallVector<SmallVector<Value>>>
301 if (!isForWorkgroup())
306 if (sgShape.empty()) {
308 sgShape = derivedShape.value();
314 auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
329 if (!parent || !dims)
330 return emitError() <<
"expected parent layout and dims attribute";
332 int64_t rank = parent.getRank();
335 llvm::SmallDenseSet<int64_t> seen;
337 if (dim < 0 || dim >= rank)
338 return emitError() <<
"invalid dim (" << dim <<
") in slice attribute.";
339 if (!seen.insert(dim).second)
340 return emitError() <<
"repeated dim (" << dim <<
") in slice attribute.";
345 SliceAttr SliceAttr::flatten()
const {
346 xegpu::DistributeLayoutAttr parent = getParent();
349 while (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
350 parent = sliceAttr.getParent();
351 slicedDims.push_back(sliceAttr.getDims());
354 auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
356 llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
360 for (
auto dim : llvm::reverse(slicedDims))
373 FailureOr<SmallVector<Value>>
376 SliceAttr attr = flatten();
377 auto parent = dyn_cast<LayoutAttr>(attr.getParent());
378 return parent.delinearizeSubgroupId(builder, loc, linearId);
384 FailureOr<SmallVector<SmallVector<Value>>>
387 assert(getRank() ==
static_cast<int64_t
>(shape.size()) &&
"invalid shape.");
388 if (!isForWorkgroup())
393 if (sgShape.empty()) {
395 sgShape = derivedShape.value();
401 auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
415 bool SliceAttr::isSliceOf(
const xegpu::DistributeLayoutAttr &other) {
416 auto flattenedThis = flatten();
419 if (
auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
420 return flattenedThis.getParent() == otherLayout;
422 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
424 if (flattenedThis.getParent() != flattenedOther.getParent())
428 llvm::SmallDenseSet<int64_t> thisDims(
429 flattenedThis.getDims().asArrayRef().begin(),
430 flattenedThis.getDims().asArrayRef().end());
431 return llvm::all_of(flattenedOther.getDims().asArrayRef(),
432 [&](int64_t dim) { return thisDims.contains(dim); });
441 IntegerAttr startOfRange, IntegerAttr endOfRange) {
442 if (startOfRange.getInt() >= endOfRange.getInt())
443 return emitError() <<
"'end' : " << endOfRange.getInt()
444 <<
" must be greater than 'start' : "
445 << startOfRange.getInt();
457 mlir::FailureOr<mlir::Attribute> encoding;
458 mlir::FailureOr<mlir::Attribute> layout;
466 parser.
emitError(shapeLoc,
"failed to parse parameter 'shape'");
472 parser.
emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
480 if (mlir::succeeded(res)) {
481 if (mlir::isa<LayoutAttr>(attr)) {
485 if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
498 return TensorDescType::getChecked(
508 for (int64_t dim : shape) {
509 if (mlir::ShapedType::isDynamic(dim))
518 auto encoding = getEncoding();
519 auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
520 if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
521 printer <<
", " << encoding;
523 if (
auto layout = getLayout())
524 printer <<
", " << layout;
532 MemorySpace memory_space,
537 return Base::get(context, shape, elementType, attr, layout);
542 MemorySpace memory_space,
546 return Base::get(context, shape, elementType, attr, layout);
553 size_t rank = shape.size();
556 return emitError() <<
"expected non-zero rank tensor";
558 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
560 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
561 if (rank > 1 && memorySpaceAttr &&
562 memorySpaceAttr.getValue() == MemorySpace::SLM)
563 return emitError() <<
"SLM is only supported for 1D block tensor";
568 int chunkAlignmentFactor =
572 auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
574 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
575 if (rank == 1 && chunkSize != 1)
576 return emitError() <<
"expected non-contiguous elements for 1D tensor";
582 if (shape.back() != chunkSize)
583 return emitError() <<
"expected last dim of tensor to match chunk size";
584 if (shape.back() % chunkAlignmentFactor != 0)
585 return emitError() <<
"expected last dim of tensor to be a multiple of "
586 << chunkAlignmentFactor;
590 auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
592 if (rank != (
size_t)layoutAttr.getRank())
593 return emitError() <<
"expected layout rank to match tensor rank";
595 auto laneData = layoutAttr.getLaneData();
596 if (scatterAttr && laneData) {
600 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
601 if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
603 <<
"expected last dim of lane_data to be a multiple of: "
604 << chunkAlignmentFactor;
607 if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
608 std::string shapeStr;
609 llvm::raw_string_ostream stream(shapeStr);
610 llvm::interleaveComma(shape, stream);
611 return emitError() <<
"cannot distribute [" << shapeStr <<
"] using "
624 mlir::FailureOr<MemLayoutAttr> layout;
632 parser.
emitError(shapeLoc,
"failed to parse parameter 'shape'");
638 parser.
emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
656 return MemDescType::getChecked(
658 elementType, layout.value_or(MemLayoutAttr()));
668 if (
auto layout = getMemLayout())
669 printer <<
", " << layout;
683 llvm::SmallDenseSet<StringRef> seenKeys;
686 auto parseElt = [&]() -> ParseResult {
689 return parser.
emitError(loc,
"expected valid attribute name");
691 if (!seenKeys.insert(nameId).second)
692 return parser.
emitError(loc,
"duplicate key '")
693 << nameId <<
" in mem layout attribute";
701 attributes.emplace_back(nameId, attr);
723 for (
size_t i = 0; i < attrs.size(); i++) {
724 printer << attrs[i].getName().str() <<
" = " << attrs[i].getValue();
725 if (i < attrs.size() - 1)
734 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
735 #define GET_ATTRDEF_CLASSES
736 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
737 #define GET_TYPEDEF_CLASSES
738 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emit errors.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printDimensionList(ArrayRef< int64_t > shape)
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tileShape within a larger shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
constexpr unsigned packedSizeInBitsForGatherScatter
static SmallVector< SmallVector< Value > > genOffsetsComputingInsts(OpBuilder &builder, Location loc, SmallVector< Value > sgId, ArrayRef< int64_t > sgLayout, ArrayRef< int64_t > sizePerSg, ArrayRef< int64_t > sizePerWg)
Generates instructions to compute offsets for a subgroup identified by its multidimensional indices (sgId).
Include the generated interface declarations.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on an operation and any nested operations.