17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/Debug.h"
25 void XeGPUDialect::initialize() {
27 #define GET_TYPEDEF_LIST
28 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
32 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
35 #define GET_ATTRDEF_LIST
36 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
54 llvm::zip(sgId, sizePerSg), [&](
const auto &t) ->
Value {
64 [](
const auto &t) {
return std::min(std::get<0>(t), std::get<1>(t)); });
69 llvm::map_to_vector(unitOffs, [&](int64_t d) ->
Value {
74 llvm::zip_equal(base, localOffsets), [&](
const auto &t) ->
Value {
75 return builder.
createOrFold<arith::AddIOp>(loc, std::get<0>(t),
80 llvm::zip_equal(adds, sizePerWg), [&](
const auto &t) ->
Value {
86 offsets.push_back(mods);
94 xegpu::DistributeLayoutAttr attr) {
95 assert(attr &&
"Layout attribute is missing.");
112 if (layout.size() != shape.size())
115 if (ratio.has_value()) {
116 newShape = ratio.value();
124 if (data.size() != shape.size())
127 if (!ratio.has_value() && rr)
129 if (!ratio.has_value())
139 auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
140 attr.getEffectiveSgDataAsInt());
143 auto sgShape = maybeSgShape.value();
146 auto maybeInstShape =
147 tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(),
false);
150 auto instShape = maybeInstShape.value();
153 auto maybeLaneShape =
154 tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
155 attr.getEffectiveLaneDataAsInt(),
false);
156 return maybeLaneShape.has_value();
163 xegpu::MemorySpace memory_space,
165 bool boundary_check) {
170 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
173 bool BlockTensorDescAttr::hasDefaultsOnly() {
174 return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
175 getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
181 ScatterTensorDescAttr
183 xegpu::MemorySpace memory_space,
int chunk_size) {
187 return Base::get(context, scopeAttr, chunkSizeAttr);
192 MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
193 int64_t chunkSize = chunk_size.getInt();
195 return emitError() <<
"invalid chunk size";
212 if (!sg_layout && !inst_data && !lane_layout) {
214 <<
"expected at least one of sg_layout, inst_data or lane_layout";
220 if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
222 <<
"expected sg_layout and inst_data to have the same rank";
225 if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
227 <<
"expected sg_layout and lane_layout to have the same rank";
230 if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
231 return emitError() <<
"expected inst_data and lane_layout to have the same "
232 "rank, got inst_data "
233 << inst_data.size() <<
", lane_layout "
234 << lane_layout.size();
241 return emitError() <<
"expected sg_layout being used with sg_data";
242 if (sg_data.size() != sg_layout.size())
244 <<
"expected sg_data and sg_layout to have the same rank";
251 return emitError() <<
"expected lane_layout being used with lane_data";
252 if (lane_data.size() != lane_layout.size())
254 <<
"expected lane_data and lane_layout to have the same rank";
258 if (!sg_layout && !lane_layout)
260 <<
"expected sg_layout/lane_layout being used with order";
262 if (sg_layout && order.size() != sg_layout.size())
264 <<
"expected order and sg_layout to have the same rank";
266 if (lane_layout && order.size() != lane_layout.size())
268 <<
"expected order and lane_layout to have the same rank";
274 FailureOr<SmallVector<Value>>
279 if (!isForWorkgroup())
283 auto hasDefaultOrder = [&]() {
288 if (!hasDefaultOrder())
289 return mlir::emitError(loc,
"order attribute is currently not supported.");
292 llvm::map_to_vector(getEffectiveSgLayoutAsInt(), [&](int64_t d) ->
Value {
293 return builder.
createOrFold<arith::ConstantIndexOp>(loc, d);
302 FailureOr<SmallVector<SmallVector<Value>>>
305 if (!isForWorkgroup())
310 if (sgShape.empty()) {
312 sgShape = derivedShape.value();
318 auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
333 if (!parent || !dims)
334 return emitError() <<
"expected parent layout and dims attribute";
336 int64_t rank = parent.getRank();
339 llvm::SmallDenseSet<int64_t> seen;
341 if (dim < 0 || dim >= rank)
342 return emitError() <<
"invalid dim (" << dim <<
") in slice attribute.";
343 if (!seen.insert(dim).second)
344 return emitError() <<
"repeated dim (" << dim <<
") in slice attribute.";
349 SliceAttr SliceAttr::flatten()
const {
350 xegpu::DistributeLayoutAttr parent = getParent();
353 while (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
354 parent = sliceAttr.getParent();
355 slicedDims.push_back(sliceAttr.getDims());
358 auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
360 llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
364 for (
auto dim : llvm::reverse(slicedDims))
377 FailureOr<SmallVector<Value>>
380 SliceAttr attr = flatten();
381 auto parent = dyn_cast<LayoutAttr>(attr.getParent());
382 return parent.delinearizeSubgroupId(builder, loc, linearId);
388 FailureOr<SmallVector<SmallVector<Value>>>
391 assert(getRank() ==
static_cast<int64_t
>(shape.size()) &&
"invalid shape.");
392 if (!isForWorkgroup())
397 if (sgShape.empty()) {
399 sgShape = derivedShape.value();
405 auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
419 bool SliceAttr::isSliceOf(
const xegpu::DistributeLayoutAttr &other) {
420 auto flattenedThis = flatten();
423 if (
auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
424 return flattenedThis.getParent() == otherLayout;
426 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
428 if (flattenedThis.getParent() != flattenedOther.getParent())
432 llvm::SmallDenseSet<int64_t> thisDims(
433 flattenedThis.getDims().asArrayRef().begin(),
434 flattenedThis.getDims().asArrayRef().end());
435 return llvm::all_of(flattenedOther.getDims().asArrayRef(),
436 [&](int64_t dim) { return thisDims.contains(dim); });
445 IntegerAttr startOfRange, IntegerAttr endOfRange) {
446 if (startOfRange.getInt() >= endOfRange.getInt())
447 return emitError() <<
"'end' : " << endOfRange.getInt()
448 <<
" must be greater than 'start' : "
449 << startOfRange.getInt();
461 mlir::FailureOr<mlir::Attribute> encoding;
462 mlir::FailureOr<mlir::Attribute> layout;
470 parser.
emitError(shapeLoc,
"failed to parse parameter 'shape'");
476 parser.
emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
484 if (mlir::succeeded(res)) {
485 if (mlir::isa<LayoutAttr>(attr)) {
489 if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
502 return TensorDescType::getChecked(
512 for (int64_t dim : shape) {
513 if (mlir::ShapedType::isDynamic(dim))
522 auto encoding = getEncoding();
523 auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
524 if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
525 printer <<
", " << encoding;
527 if (
auto layout = getLayout())
528 printer <<
", " << layout;
536 MemorySpace memory_space,
541 return Base::get(context, shape, elementType, attr, layout);
546 MemorySpace memory_space,
550 return Base::get(context, shape, elementType, attr, layout);
557 size_t rank = shape.size();
560 return emitError() <<
"expected non-zero rank tensor";
562 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
564 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
565 if (rank > 1 && memorySpaceAttr &&
566 memorySpaceAttr.getValue() == MemorySpace::SLM)
567 return emitError() <<
"SLM is only supported for 1D block tensor";
572 int chunkAlignmentFactor =
576 auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
578 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
579 if (rank == 1 && chunkSize != 1)
580 return emitError() <<
"expected non-contiguous elements for 1D tensor";
586 if (shape.back() != chunkSize)
587 return emitError() <<
"expected last dim of tensor to match chunk size";
588 if (shape.back() % chunkAlignmentFactor != 0)
589 return emitError() <<
"expected last dim of tensor to be a multiple of "
590 << chunkAlignmentFactor;
594 auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
596 if (rank != (
size_t)layoutAttr.getRank())
597 return emitError() <<
"expected layout rank to match tensor rank";
599 auto laneData = layoutAttr.getLaneData();
600 if (scatterAttr && laneData) {
604 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
605 if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
607 <<
"expected last dim of lane_data to be a multiple of: "
608 << chunkAlignmentFactor;
611 if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
612 std::string shapeStr;
613 llvm::raw_string_ostream stream(shapeStr);
614 llvm::interleaveComma(shape, stream);
615 return emitError() <<
"cannot distribute [" << shapeStr <<
"] using "
628 mlir::FailureOr<MemLayoutAttr> layout;
636 parser.
emitError(shapeLoc,
"failed to parse parameter 'shape'");
642 parser.
emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
660 return MemDescType::getChecked(
662 elementType, layout.value_or(MemLayoutAttr()));
672 if (
auto layout = getMemLayout())
673 printer <<
", " << layout;
687 llvm::SmallDenseSet<StringRef> seenKeys;
690 auto parseElt = [&]() -> ParseResult {
693 return parser.
emitError(loc,
"expected valid attribute name");
695 if (!seenKeys.insert(nameId).second)
696 return parser.
emitError(loc,
"duplicate key '")
697 << nameId <<
" in mem layout attribute";
705 attributes.emplace_back(nameId, attr);
727 for (
size_t i = 0; i < attrs.size(); i++) {
728 printer << attrs[i].getName().str() <<
" = " << attrs[i].getValue();
729 if (i < attrs.size() - 1)
738 template <
typename ArithOp>
743 return ArithOp::create(builder, loc, aVal, bVal).getResult();
748 genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
752 genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
756 genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
759 #define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
768 assert(offsets.size() == blockShape.size() &&
769 "offsets and blockShape must have the same size");
773 for (
auto [offset, block] : llvm::zip(offsets, blockShape)) {
774 divs.push_back(
div(offset, block));
775 rems.push_back(
rem(offset, block));
777 blockedOffsets.append(divs.begin(), divs.end());
778 blockedOffsets.append(rems.begin(), rems.end());
780 return blockedOffsets;
788 ArrayAttr strideAttr = getStrideAttr();
790 for (
Attribute attr : strideAttr.getValue()) {
791 strides.push_back(cast<IntegerAttr>(attr).getInt());
799 llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
800 llvm::sort(perm, [&](
int a,
int b) {
return strides[a] < strides[b]; });
802 assert(strides[perm[0]] == 1 &&
"inner most dim must have stride 1");
805 innerBlkStride[perm[0]] = 1;
806 for (
size_t i = 1; i < perm.size(); ++i)
807 innerBlkStride[perm[i]] =
808 innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
816 for (
size_t i = 0; i < perm.size() - 1; ++i) {
817 matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
818 BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
821 int64_t innerBlkSize = 1;
822 for (
auto s : innerBlkShape)
826 outerBlkStride[perm[0]] = innerBlkSize;
827 for (
size_t i = 0; i < perm.size() - 1; ++i) {
828 outerBlkStride[perm[i + 1]] =
829 outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
834 blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
835 blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
837 return blockedStrides;
850 if (llvm::equal(blockShape, matrixShape)) {
852 strides.erase(strides.begin(), strides.begin() + matrixShape.size());
854 assert(offsets.size() == blockShape.size() &&
855 "offsets and blockShape must have the same size");
861 for (
auto [offset, block] : llvm::zip(offsets, blockShape)) {
862 divs.push_back(
div(offset, block));
863 rems.push_back(
rem(offset, block));
865 blockedOffsets.append(divs.begin(), divs.end());
866 blockedOffsets.append(rems.begin(), rems.end());
867 offsets = blockedOffsets;
872 for (
size_t i = 0; i < offsets.size(); ++i) {
875 linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
884 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
885 #define GET_ATTRDEF_CLASSES
886 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
887 #define GET_TYPEDEF_CLASSES
888 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
void printDimensionList(ArrayRef< int64_t > shape)
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents a single result from folding an operation.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
constexpr unsigned generalPackedFormatBitSize
static SmallVector< SmallVector< Value > > genOffsetsComputingInsts(OpBuilder &builder, Location loc, SmallVector< Value > sgId, ArrayRef< int64_t > sgLayout, ArrayRef< int64_t > sizePerSg, ArrayRef< int64_t > sizePerWg)
Generates instructions to compute offsets for a subgroup identified by its multidimensional indices (...
SmallVector< OpFoldResult > getBlockedOffsets(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > offsets, ArrayRef< int64_t > blockShape)
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, OpBuilder &builder)
Include the generated interface declarations.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...