17 #include "llvm/ADT/TypeSwitch.h"
18 #include "llvm/Support/Debug.h"
25 void XeGPUDialect::initialize() {
27 #define GET_TYPEDEF_LIST
28 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
32 #include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
35 #define GET_ATTRDEF_LIST
36 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
44 static SmallVector<SmallVector<Value>>
54 llvm::zip(sgId, sizePerSg), [&](
const auto &t) ->
Value {
64 [](
const auto &t) {
return std::min(std::get<0>(t), std::get<1>(t)); });
69 llvm::map_to_vector(unitOffs, [&](int64_t d) ->
Value {
74 llvm::zip_equal(base, localOffsets), [&](
const auto &t) ->
Value {
75 return builder.
createOrFold<arith::AddIOp>(loc, std::get<0>(t),
80 llvm::zip_equal(adds, sizePerWg), [&](
const auto &t) ->
Value {
86 offsets.push_back(mods);
94 xegpu::DistributeLayoutAttr attr) {
95 assert(attr &&
"Layout attribute is missing.");
112 if (layout.size() != shape.size())
115 if (!ratio.has_value())
117 newShape = ratio.value();
121 if (data.size() != shape.size())
124 if (!ratio.has_value() && rr)
126 if (!ratio.has_value())
137 tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt());
140 auto sgShape = maybeSgShape.value();
143 auto maybeInstShape =
144 tryDistribute(sgShape, {}, attr.getInstDataAsInt(),
false);
147 auto instShape = maybeInstShape.value();
150 auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(),
151 attr.getLaneDataAsInt(),
false);
152 return maybeLaneShape.has_value();
159 xegpu::MemorySpace memory_space,
161 bool boundary_check) {
166 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
169 bool BlockTensorDescAttr::hasDefaultsOnly() {
170 return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
171 getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
177 ScatterTensorDescAttr
179 xegpu::MemorySpace memory_space,
int chunk_size) {
183 return Base::get(context, scopeAttr, chunkSizeAttr);
188 MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
189 int64_t chunkSize = chunk_size.getInt();
191 return emitError() <<
"invalid chunk size";
208 if (!sg_layout && !inst_data && !lane_layout) {
210 <<
"expected at least one of sg_layout, inst_data or lane_layout";
216 if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
218 <<
"expected sg_layout and inst_data to have the same rank";
221 if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
223 <<
"expected sg_layout and lane_layout to have the same rank";
226 if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
228 <<
"expected inst_data and lane_layout to have the same rank";
235 return emitError() <<
"expected sg_layout being used with sg_data";
236 if (sg_data.size() != sg_layout.size())
238 <<
"expected sg_data and sg_layout to have the same rank";
245 return emitError() <<
"expected lane_layout being used with lane_data";
246 if (lane_data.size() != lane_layout.size())
248 <<
"expected lane_data and lane_layout to have the same rank";
252 if (!sg_layout && !lane_layout)
254 <<
"expected sg_layout/lane_layout being used with order";
256 if (sg_layout && order.size() != sg_layout.size())
258 <<
"expected order and sg_layout to have the same rank";
260 if (lane_layout && order.size() != lane_layout.size())
262 <<
"expected order and lane_layout to have the same rank";
268 FailureOr<SmallVector<Value>>
269 LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
273 if (!isForWorkgroup())
277 auto hasDefaultOrder = [&]() {
280 llvm::reverse(order.asArrayRef())));
282 if (!hasDefaultOrder())
283 return mlir::emitError(loc,
"order attribute is currently not supported.");
285 auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value {
286 return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
295 FailureOr<SmallVector<SmallVector<Value>>>
296 LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
297 ArrayRef<int64_t> shape) {
298 if (!isForWorkgroup())
301 SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
302 SmallVector<int64_t> sgShape = getSgDataAsInt();
303 if (sgShape.empty()) {
305 sgShape = derivedShape.value();
311 auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
314 SmallVector<Value> sgIds = *maybeIds;
326 if (!parent || !dims)
327 return emitError() <<
"expected parent layout and dims attribute";
329 int64_t rank = parent.getRank();
332 llvm::SmallDenseSet<int64_t> seen;
333 for (int64_t dim : dims.asArrayRef()) {
334 if (dim < 0 || dim >= rank)
335 return emitError() <<
"invalid dim (" << dim <<
") in slice attribute.";
336 if (!seen.insert(dim).second)
337 return emitError() <<
"repeated dim (" << dim <<
") in slice attribute.";
342 SliceAttr SliceAttr::flatten()
const {
343 xegpu::DistributeLayoutAttr parent = getParent();
344 SmallVector<DenseI64ArrayAttr> slicedDims({
getDims()});
346 while (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
347 parent = sliceAttr.getParent();
348 slicedDims.push_back(sliceAttr.getDims());
351 auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
352 SmallVector<int64_t> indices =
353 llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
356 SmallVector<int64_t> remainingDims(indices);
357 for (
auto dim : llvm::reverse(slicedDims))
362 SmallVector<int64_t> flattendDims = XeGPUDialect::slice(
370 FailureOr<SmallVector<Value>>
371 SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
373 SliceAttr attr = flatten();
374 auto parent = dyn_cast<LayoutAttr>(attr.getParent());
375 return parent.delinearizeSubgroupId(builder, loc, linearId);
381 FailureOr<SmallVector<SmallVector<Value>>>
382 SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
383 ArrayRef<int64_t> shape) {
384 assert(getRank() ==
static_cast<int64_t
>(shape.size()) &&
"invalid shape.");
385 if (!isForWorkgroup())
388 SmallVector<int64_t> sgLayout = getSgLayoutAsInt();
389 SmallVector<int64_t> sgShape = getSgDataAsInt();
390 if (sgShape.empty()) {
392 sgShape = derivedShape.value();
398 auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
404 ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
405 SmallVector<Value> sgIds =
406 XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
418 IntegerAttr startOfRange, IntegerAttr endOfRange) {
419 if (startOfRange.getInt() >= endOfRange.getInt())
420 return emitError() <<
"'end' : " << endOfRange.getInt()
421 <<
" must be greater than 'start' : "
422 << startOfRange.getInt();
434 mlir::FailureOr<mlir::Attribute> encoding;
435 mlir::FailureOr<mlir::Attribute> layout;
438 if (parser.parseLess())
441 auto shapeLoc = parser.getCurrentLocation();
443 parser.emitError(shapeLoc,
"failed to parse parameter 'shape'");
447 auto elemTypeLoc = parser.getCurrentLocation();
449 parser.emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
454 while (mlir::succeeded(parser.parseOptionalComma())) {
456 ParseResult res = parser.parseAttribute(attr);
457 if (mlir::succeeded(res)) {
458 if (mlir::isa<LayoutAttr>(attr)) {
462 if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
471 if (parser.parseGreater())
475 return TensorDescType::getChecked(
476 [&]() {
return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
485 for (int64_t dim : shape) {
486 if (mlir::ShapedType::isDynamic(dim))
495 auto encoding = getEncoding();
496 auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
497 if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
498 printer <<
", " << encoding;
500 if (
auto layout = getLayout())
501 printer <<
", " << layout;
509 MemorySpace memory_space,
514 return Base::get(context, shape, elementType, attr, layout);
519 MemorySpace memory_space,
523 return Base::get(context, shape, elementType, attr, layout);
530 size_t rank = shape.size();
533 return emitError() <<
"expected non-zero rank tensor";
535 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
537 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
538 if (rank > 1 && memorySpaceAttr &&
539 memorySpaceAttr.getValue() == MemorySpace::SLM)
540 return emitError() <<
"SLM is only supported for 1D block tensor";
545 int chunkAlignmentFactor =
549 auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
551 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
552 if (rank == 1 && chunkSize != 1)
553 return emitError() <<
"expected non-contiguous elements for 1D tensor";
559 if (shape.back() != chunkSize)
560 return emitError() <<
"expected last dim of tensor to match chunk size";
561 if (shape.back() % chunkAlignmentFactor != 0)
562 return emitError() <<
"expected last dim of tensor to be a multiple of "
563 << chunkAlignmentFactor;
567 auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
569 if (rank != (
size_t)layoutAttr.getRank())
570 return emitError() <<
"expected layout rank to match tensor rank";
572 auto laneData = layoutAttr.getLaneData();
573 if (scatterAttr && laneData) {
577 int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
578 if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
580 <<
"expected last dim of lane_data to be a multiple of: "
581 << chunkAlignmentFactor;
584 if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
585 std::string shapeStr;
586 llvm::raw_string_ostream stream(shapeStr);
587 llvm::interleaveComma(shape, stream);
588 return emitError() <<
"cannot distribute [" << shapeStr <<
"] using "
601 mlir::FailureOr<MemLayoutAttr> layout;
604 if (parser.parseLess())
607 auto shapeLoc = parser.getCurrentLocation();
608 if (
mlir::failed(parser.parseDimensionList(shape,
false,
true))) {
609 parser.emitError(shapeLoc,
"failed to parse parameter 'shape'");
613 auto elemTypeLoc = parser.getCurrentLocation();
615 parser.emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
620 if (mlir::succeeded(parser.parseOptionalComma())) {
622 ParseResult res = parser.parseAttribute(attr);
629 if (parser.parseGreater())
633 return MemDescType::getChecked(
634 [&]() {
return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
635 elementType, layout.value_or(MemLayoutAttr()));
641 printer.printDimensionList(
getShape());
645 if (
auto layout = getMemLayout())
646 printer <<
", " << layout;
657 auto context = parser.getContext();
658 llvm::SMLoc loc = parser.getCurrentLocation();
660 llvm::SmallDenseSet<StringRef> seenKeys;
661 SmallVector<NamedAttribute> attributes;
663 auto parseElt = [&]() -> ParseResult {
665 if (
failed(parser.parseKeyword(&nameId)))
666 return parser.emitError(loc,
"expected valid attribute name");
668 if (!seenKeys.insert(nameId).second)
669 return parser.emitError(loc,
"duplicate key '")
670 << nameId <<
" in mem layout attribute";
672 if (
failed(parser.parseEqual()))
676 if (
failed(parser.parseAttribute(attr)))
678 attributes.emplace_back(nameId, attr);
683 if (parser.parseLess())
686 if (
failed(parser.parseCommaSeparatedList(parseElt)))
690 if (parser.parseGreater())
693 return parser.getChecked<MemLayoutAttr>(
699 ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
700 for (
size_t i = 0; i < attrs.size(); i++) {
701 printer << attrs[i].getName().str() <<
" = " << attrs[i].getValue();
702 if (i < attrs.size() - 1)
711 #include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
712 #define GET_ATTRDEF_CLASSES
713 #include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
714 #define GET_TYPEDEF_CLASSES
715 #include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static Value min(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
static BoolAttr get(MLIRContext *context, bool value)
This class represents a diagnostic that is inflight and set to be reported.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tileSize within a larger shape.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< T > content)
Builder from ArrayRef<T>.
FailureOr< SmallVector< Value > > delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef< Value > basis, bool hasOuterBound=true)
Generate the IR to delinearize linearIndex given the basis and return the multi-index.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
constexpr unsigned packedSizeInBitsForGatherScatter
static SmallVector< SmallVector< Value > > genOffsetsComputingInsts(OpBuilder &builder, Location loc, SmallVector< Value > sgId, ArrayRef< int64_t > sgLayout, ArrayRef< int64_t > sizePerSg, ArrayRef< int64_t > sizePerWg)
Generates instructions to compute offsets for a subgroup identified by its multidimensional indices (sgId).
Include the generated interface declarations.
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
bool isIdentityPermutation(ArrayRef< int64_t > permutation)
Returns true if permutation is an identity permutation.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.