#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}

#define GET_OP_INTERFACE_CLASSES
#include "mlir/Dialect/XeGPU/IR/XeGPUOpInterface.cpp.inc"
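
// genCoordinates: given a delinearized subgroup/lane id, the layout of
// sub-shapes, the sub-shape itself, and the source shape, returns the
// coordinates of every sub-shape instance assigned to that id. Sketch of the
// math, assuming the standard XeGPU round-robin distribution (local names
// such as `distUnit` below are reconstructed, not verbatim):
//   distUnit[i]    = min(srcShape[i], subShapesLayout[i] * subShape[i])
//   localOffset[i] = delinearizedId[i] * subShape[i]
//   coordinate[i]  = (distUnitBase[i] + localOffset[i]) % srcShape[i]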
static SmallVector<SmallVector<Value>>
genCoordinates(OpBuilder &builder, Location loc,
               SmallVector<Value> delinearizedId,
               ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
               ArrayRef<int64_t> srcShape) {
  SmallVector<SmallVector<Value>> coordinates;

  // Distribution unit per dim: min(srcShape[i], layout[i] * subShape[i]).
  SmallVector<int64_t> distUnit = llvm::map_to_vector(
      llvm::zip_equal(srcShape,
                      computeElementwiseMul(subShapesLayout, subShape)),
      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
      llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
      llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
                          [&](const auto &t) -> Value {
                            return builder.createOrFold<arith::AddIOp>(
                                loc, std::get<0>(t), std::get<1>(t));
                          });
      llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {

    coordinates.push_back(mods);
  }
  return coordinates;
}
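
// XeGPUDialect::isEvenlyDistributable checks whether `shape` divides evenly
// through each level of the layout: workgroup -> subgroup (sg_layout /
// sg_data), subgroup -> instruction (inst_data), and instruction -> lane
// (lane_layout / lane_data). Worked example (a sketch, not from the source):
// shape [128, 128] with sg_layout = [4, 8] and sg_data = [32, 16] is evenly
// distributable since 4 * 32 == 128 and 8 * 16 == 128. In the reconstruction
// below, the guard bodies (`return std::nullopt;`) and the round-robin retry
// are assumptions, not verbatim source.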
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
                                         xegpu::DistributeLayoutAttr attr) {
  assert(attr && "Layout attribute is missing.");
  // tryDistribute checks one level of the hierarchy (lambda partially
  // reconstructed, as labeled above).
  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
                           SmallVector<int64_t> layout,
                           SmallVector<int64_t> data, bool rr = true)
      -> std::optional<SmallVector<int64_t>> {
    SmallVector<int64_t> newShape(shape);
    if (!layout.empty()) {
      if (layout.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(shape, layout);
      if (ratio.has_value()) {
        newShape = ratio.value();
      } else {
        return std::nullopt;
      }
    }
    if (!data.empty()) {
      if (data.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(newShape, data);
      if (!ratio.has_value() && rr)
        ratio = computeShapeRatio(data, newShape);
      if (!ratio.has_value())
        return std::nullopt;
    }
    return newShape;
  };

  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
                                    attr.getEffectiveSgDataAsInt());
  if (!maybeSgShape)
    return false;
  auto sgShape = maybeSgShape.value();

  auto maybeInstShape =
      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
  if (!maybeInstShape)
    return false;
  auto instShape = maybeInstShape.value();

  auto maybeLaneShape =
      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
                    attr.getEffectiveLaneDataAsInt(), false);
  return maybeLaneShape.has_value();
}
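
// BlockTensorDescAttr bundles the block-descriptor knobs: memory space,
// array_length, and boundary_check. hasDefaultsOnly() reports the default
// combination (global memory, array_length == 1, boundary_check == true),
// which the type printer further below elides, so a plain
// !xegpu.tensor_desc<8x16xf16> carries an implicit default block encoding.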
BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
                                             xegpu::MemorySpace memory_space,
                                             int array_length,
                                             bool boundary_check) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto lengthAttr =
      IntegerAttr::get(IntegerType::get(context, 64), array_length);
  auto boundaryAttr = BoolAttr::get(context, boundary_check);
  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
}

bool BlockTensorDescAttr::hasDefaultsOnly() {
  return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
         getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
}
ScatterTensorDescAttr
ScatterTensorDescAttr::get(mlir::MLIRContext *context,
                           xegpu::MemorySpace memory_space, int chunk_size) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto chunkSizeAttr =
      IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
  return Base::get(context, scopeAttr, chunkSizeAttr);
}
LogicalResult ScatterTensorDescAttr::verify(
    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
  int64_t chunkSize = chunk_size.getInt();
  if (chunkSize <= 0) // assumption: the exact validity check may be stricter
    return emitError() << "invalid chunk size";
  return success();
}
LogicalResult
LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {
  if (!sg_layout && !inst_data && !lane_layout)
    return emitError()
           << "expected at least one of sg_layout, inst_data or lane_layout";

  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError() << "expected inst_data and lane_layout to have the same "
                          "rank, got inst_data "
                       << inst_data.size() << ", lane_layout "
                       << lane_layout.size();
  }

  if (sg_data) {
    if (!sg_layout)
      return emitError() << "expected sg_layout being used with sg_data";
    if (sg_data.size() != sg_layout.size())
      return emitError()
             << "expected sg_data and sg_layout to have the same rank";
  }

  if (lane_data) {
    if (!lane_layout)
      return emitError() << "expected lane_layout being used with lane_data";
    if (lane_data.size() != lane_layout.size())
      return emitError()
             << "expected lane_data and lane_layout to have the same rank";
  }

  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";
    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";
    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}
FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
  SmallVector<int64_t> sgLayoutInt;
  if (isForWorkgroup()) {
    sgLayoutInt = getEffectiveSgLayoutAsInt();
  } else if (isForSubgroup()) {
    sgLayoutInt = getEffectiveLaneLayoutAsInt();
  } else {
    return failure();
  }

  auto orderAttr = getOrder();
  SmallVector<int64_t> order;
  if (orderAttr && !orderAttr.empty()) {
    order = llvm::map_to_vector(orderAttr.asArrayRef(), [](int32_t idx) {
      return static_cast<int64_t>(idx);
    });
  } else {
    // Default order: innermost dimension varies fastest.
    order = llvm::to_vector(
        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
  }

  if (order.size() != sgLayoutInt.size()) {
    return failure();
  }

  SmallVector<Value> result(sgLayoutInt.size());
  Value remaining = linearId;
  // Peel off one dimension at a time: id % size, then id /= size.
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    int64_t dimSize = sgLayoutInt[dimIdx];
    Value dimSizeVal =
        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
    result[dimIdx] =
        builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
    if (i < order.size() - 1) {
      remaining =
          builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
    }
  }
  return result;
}
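
// Worked example (sketch): for sg_layout = [4, 8] with the default order
// [1, 0], a linear id of 13 yields result[1] = 13 % 8 = 5, then
// remaining = 13 / 8 = 1 and result[0] = 1 % 4 = 1, i.e. grid position (1, 5).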
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                     Value linearId, ArrayRef<int64_t> shape) {
  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  } else {
    return failure();
  }

  if (subShape.empty()) {
    // No explicit data size: derive it as shape / layout.
    auto derivedShape = computeShapeRatio(shape, layout);
    if (!derivedShape)
      return failure();
    subShape = derivedShape.value();
  }

  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();
  SmallVector<Value> ids = *maybeIds;

  return genCoordinates(builder, loc, ids, layout, subShape, shape);
}
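
// Typical use (a sketch; `sgId` and `tensorShape` are illustrative names,
// not from this file):
//   FailureOr<SmallVector<SmallVector<Value>>> coords =
//       layoutAttr.computeDistributedCoords(builder, loc, sgId, tensorShape);
//   if (failed(coords))
//     return failure(); // the layout does not evenly cover tensorShape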
bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
  if (dyn_cast<xegpu::SliceAttr>(other))
    return false;
  return *this == dyn_cast<xegpu::LayoutAttr>(other);
}
DistributeLayoutAttr
LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims) const {
  auto sgDataOpt = getSgData();
  auto instDataOpt = getInstData();
  auto laneDataOpt = getLaneData();

  SmallVector<int32_t> sgData;
  SmallVector<int32_t> instData;
  SmallVector<int32_t> laneData;

  if (sgDataOpt)
    sgData = llvm::to_vector(sgDataOpt.asArrayRef());
  if (instDataOpt)
    instData = llvm::to_vector(instDataOpt.asArrayRef());
  if (laneDataOpt)
    laneData = llvm::to_vector(laneDataOpt.asArrayRef());

  // Force each requested dimension to size 1 in every data array.
  for (auto dim : unitDims) {
    if (dim < static_cast<int64_t>(sgData.size()))
      sgData[dim] = 1;
    if (dim < static_cast<int64_t>(instData.size()))
      instData[dim] = 1;
    if (dim < static_cast<int64_t>(laneData.size()))
      laneData[dim] = 1;
  }

  return LayoutAttr::get(
      getContext(), getSgLayout(),
      sgData.empty() ? DenseI32ArrayAttr()
                     : DenseI32ArrayAttr::get(getContext(), sgData),
      instData.empty() ? DenseI32ArrayAttr()
                       : DenseI32ArrayAttr::get(getContext(), instData),
      getLaneLayout(),
      laneData.empty() ? DenseI32ArrayAttr()
                       : DenseI32ArrayAttr::get(getContext(), laneData),
      getOrder());
}
DistributeLayoutAttr
LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) const {
  auto sgLayoutOpt = getSgLayout();
  auto laneLayoutOpt = getLaneLayout();

  SmallVector<int32_t> sgLayout;
  SmallVector<int32_t> laneLayout;

  if (sgLayoutOpt)
    sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
  if (laneLayoutOpt)
    laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());

  // Force each requested dimension to size 1 in both layout arrays.
  for (auto dim : unitDims) {
    if (dim < static_cast<int64_t>(sgLayout.size()))
      sgLayout[dim] = 1;
    if (dim < static_cast<int64_t>(laneLayout.size()))
      laneLayout[dim] = 1;
  }

  return LayoutAttr::get(
      getContext(),
      sgLayout.empty() ? DenseI32ArrayAttr()
                       : DenseI32ArrayAttr::get(getContext(), sgLayout),
      getSgData(), getInstData(),
      laneLayout.empty() ? DenseI32ArrayAttr()
                         : DenseI32ArrayAttr::get(getContext(), laneLayout),
      getLaneData(), getOrder());
}
DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
                                            int64_t instData,
                                            int64_t laneData) {
  SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
  SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
  SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();

  // A value of -1 leaves the corresponding level unchanged.
  if (dim < static_cast<int64_t>(sgDataVec.size()) && sgData != -1)
    sgDataVec[dim] = sgData;
  if (dim < static_cast<int64_t>(instDataVec.size()) && instData != -1)
    instDataVec[dim] = instData;
  if (dim < static_cast<int64_t>(laneDataVec.size()) && laneData != -1)
    laneDataVec[dim] = laneData;

  SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
  SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
  SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());

  return LayoutAttr::get(
      getContext(), getSgLayout(),
      sgDataVec32.empty() ? DenseI32ArrayAttr()
                          : DenseI32ArrayAttr::get(getContext(), sgDataVec32),
      instDataVec32.empty()
          ? DenseI32ArrayAttr()
          : DenseI32ArrayAttr::get(getContext(), instDataVec32),
      getLaneLayout(),
      laneDataVec32.empty()
          ? DenseI32ArrayAttr()
          : DenseI32ArrayAttr::get(getContext(), laneDataVec32),
      getOrder());
}
DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
  SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
  SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
  SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
  SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
  SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();

  auto orderAttr = getOrder();
  SmallVector<int32_t> orderVec;
  if (orderAttr && !orderAttr.empty()) {
    orderVec = llvm::to_vector(llvm::map_range(
        orderAttr.asArrayRef(),
        [](int32_t idx) { return static_cast<int32_t>(idx); }));
  }

  // The dimensions being collapsed must be adjacent, both positionally and
  // (when an order is given) in the order array.
  SmallVector<int64_t> sortedDimGroup = dimGroup;
  llvm::sort(sortedDimGroup);
  int64_t dimBeforeCurrent = -1;
  for (auto dimIdx : sortedDimGroup) {
    if (dimBeforeCurrent >= 0) {
      if (!orderVec.empty()) {
        int64_t orderBefore = orderVec[dimBeforeCurrent];
        int64_t orderCurrent = orderVec[dimIdx];
        if (orderBefore != (orderCurrent - 1))
          llvm::report_fatal_error(
              "dimensions being collapsed must be adjacent in order");
      }
      if (dimIdx != (dimBeforeCurrent + 1))
        llvm::report_fatal_error("dimensions being collapsed must be adjacent");
    }
    dimBeforeCurrent = dimIdx;
  }

  int firstDim = sortedDimGroup.front();

  // Collapse sg_layout/sg_data by multiplying the grouped entries together.
  if (!sgLayout.empty()) {
    int64_t collapsedSglayout = 1, collapsedSgData = 1;
    for (auto dimIdx : dimGroup) {
      collapsedSglayout *= sgLayout[dimIdx];
      collapsedSgData *= sgData[dimIdx];
    }
    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
      sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
      sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
    }
    sgLayout.insert(sgLayout.begin() + firstDim, collapsedSglayout);
    sgData.insert(sgData.begin() + firstDim, collapsedSgData);
  }

  if (!instData.empty()) {
    int64_t collapsedInstData = 1;
    for (auto dimIdx : dimGroup)
      collapsedInstData *= instData[dimIdx];
    for (auto dimIdx : llvm::reverse(sortedDimGroup))
      instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
    instData.insert(instData.begin() + firstDim, collapsedInstData);
  }

  if (!laneLayout.empty()) {
    int64_t collapsedLaneLayout = 1, collapsedLaneData = 1;
    for (auto dimIdx : dimGroup) {
      collapsedLaneLayout *= laneLayout[dimIdx];
      collapsedLaneData *= laneData[dimIdx];
    }
    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
      laneLayout.erase(laneLayout.begin() + dimIdx,
                       laneLayout.begin() + dimIdx + 1);
      laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
    }
    laneLayout.insert(laneLayout.begin() + firstDim, collapsedLaneLayout);
    laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
  }

  // Recompute a compact order over the surviving dimensions.
  SmallVector<int32_t> collapsedOrder;
  if (!orderVec.empty()) {
    for (auto dimIdx : llvm::reverse(sortedDimGroup)) {
      if (dimIdx != firstDim)
        orderVec.erase(orderVec.begin() + dimIdx,
                       orderVec.begin() + dimIdx + 1);
    }
    SmallVector<size_t> indices =
        llvm::to_vector(llvm::seq<size_t>(0, orderVec.size()));
    llvm::sort(indices,
               [&](size_t a, size_t b) { return orderVec[a] < orderVec[b]; });
    collapsedOrder = llvm::to_vector(llvm::map_range(
        indices, [&](size_t i) { return static_cast<int32_t>(i); }));
  }

  SmallVector<int32_t> sgLayout32(sgLayout.begin(), sgLayout.end());
  SmallVector<int32_t> sgData32(sgData.begin(), sgData.end());
  SmallVector<int32_t> instData32(instData.begin(), instData.end());
  SmallVector<int32_t> laneLayout32(laneLayout.begin(), laneLayout.end());
  SmallVector<int32_t> laneData32(laneData.begin(), laneData.end());

  auto collapsedLayout = xegpu::LayoutAttr::get(
      getContext(),
      sgLayout32.empty() ? DenseI32ArrayAttr()
                         : DenseI32ArrayAttr::get(getContext(), sgLayout32),
      sgData32.empty() ? DenseI32ArrayAttr()
                       : DenseI32ArrayAttr::get(getContext(), sgData32),
      instData32.empty() ? DenseI32ArrayAttr()
                         : DenseI32ArrayAttr::get(getContext(), instData32),
      laneLayout32.empty() ? DenseI32ArrayAttr()
                           : DenseI32ArrayAttr::get(getContext(), laneLayout32),
      laneData32.empty() ? DenseI32ArrayAttr()
                         : DenseI32ArrayAttr::get(getContext(), laneData32),
      collapsedOrder.empty()
          ? DenseI32ArrayAttr()
          : DenseI32ArrayAttr::get(getContext(), collapsedOrder));

  return collapsedLayout;
}
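
// Worked example (sketch): collapsing dims {0, 1} of sg_layout = [2, 4, 8],
// sg_data = [16, 16, 1] multiplies the grouped entries and reinserts them at
// the first collapsed position, giving sg_layout = [8, 8] and
// sg_data = [256, 1].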
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
  if (!dims)
    return emitError() << "expected dims attribute";

  llvm::SmallDenseSet<int64_t> seen;
  for (int64_t dim : dims.asArrayRef()) {
    if (dim < 0 || dim >= parent.getRank()) // bounds check assumed
      return emitError() << "invalid dim (" << dim << ") in slice attribute.";
    if (!seen.insert(dim).second)
      return emitError() << "repeated dim (" << dim << ") in slice attribute.";
  }
  return success();
}
SliceAttr SliceAttr::flatten() const {
  xegpu::DistributeLayoutAttr parent = getParent();
  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});

  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
    parent = sliceAttr.getParent();
    slicedDims.push_back(sliceAttr.getDims());
  }

  auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
  SmallVector<int64_t> indices =
      llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));

  // Apply the nested slices innermost-first to find the surviving dims, then
  // express every sliced-away dim in the base layout's space.
  SmallVector<int64_t> remainingDims(indices);
  for (auto dim : llvm::reverse(slicedDims))
    remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
                                        dim.asArrayRef());

  SmallVector<int64_t> flattendDims = XeGPUDialect::slice(
      llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));

  return xegpu::SliceAttr::get(
      getContext(), layoutAttr,
      DenseI64ArrayAttr::get(getContext(), flattendDims));
}
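
// Worked example (sketch): with a rank-3 base layout L,
//   slice(slice(L, dims = [1]), dims = [0])
// first removes base dim 1, leaving [0, 2]; the outer slice then removes
// position 0 of that space (base dim 0), so flatten() yields
//   slice(L, dims = [0, 1]).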
FailureOr<SmallVector<Value>>
SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
  SliceAttr attr = flatten();
  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
  return parent.delinearizeId(builder, loc, linearId);
}
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                    Value linearId, ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
  if (!isForWorkgroup())
    return failure();

  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  }

  if (subShape.empty()) {
    auto derivedShape = computeShapeRatio(shape, layout);
    if (!derivedShape)
      return failure();
    subShape = derivedShape.value();
  }

  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();

  // Keep only the ids of dimensions that survive the slice.
  ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
  SmallVector<Value> sgIds =
      XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);

  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
}
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
  auto flattenedThis = flatten();

  // A plain LayoutAttr is the target iff it is this slice's flattened parent.
  if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
    return flattenedThis.getParent() == otherLayout;

  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
  if (flattenedThis.getParent() != flattenedOther.getParent())
    return false;

  // Every dim sliced by `other` must also be sliced by this attribute.
  llvm::SmallDenseSet<int64_t> thisDims(
      flattenedThis.getDims().asArrayRef().begin(),
      flattenedThis.getDims().asArrayRef().end());
  return llvm::all_of(flattenedOther.getDims().asArrayRef(),
                      [&](int64_t dim) { return thisDims.contains(dim); });
}
bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
  if (dyn_cast<xegpu::LayoutAttr>(other))
    return false;

  auto flattenedThis = flatten();
  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
  return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
          (flattenedThis.getDims() == flattenedOther.getDims()));
}
xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
  if (sliceDimsToDrop.empty())
    return *this;

  SmallVector<int64_t> sliceDims{getDims().asArrayRef()};
  for (auto dim : sliceDimsToDrop) {
    auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), dim);
    assert(foundIt != sliceDims.end() &&
           "Expected to find the specified reduction dim in slice dims");
    sliceDims.erase(foundIt);
  }

  auto sliceWithoutDims = xegpu::SliceAttr::get(
      getContext(), getParent(),
      DenseI64ArrayAttr::get(getContext(), sliceDims));
  return sliceWithoutDims;
}
static SmallVector<int64_t>
mapSlicedDimsToParentSpace(const SmallVector<int64_t> &dimsToMap,
                           ArrayRef<int64_t> sliceDims) {
  // Upper bound on the parent rank: the largest referenced dim plus the
  // number of sliced-out dims.
  int64_t maxDim = 0;
  maxDim =
      std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
  maxDim =
      std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
  int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;

  // Parent dims that survive the slice, in increasing order.
  llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
                                             sliceDims.end());
  SmallVector<int64_t> remainingDims;
  for (int64_t i = 0; i < parentSpaceRank; ++i) {
    if (!slicedDimsSet.contains(i))
      remainingDims.push_back(i);
  }

  // The i-th dim of the sliced space is the i-th surviving parent dim.
  SmallVector<int64_t> adjustUnitDims;
  for (auto dim : dimsToMap) {
    int64_t mappedDim = remainingDims[dim];
    adjustUnitDims.push_back(mappedDim);
  }
  return adjustUnitDims;
}
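
// Worked example (sketch): with sliceDims = [1], dim 0 of the sliced space
// maps to parent dim 0 and dim 1 maps to parent dim 2, since parent dim 1
// was sliced away.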
DistributeLayoutAttr SliceAttr::setUnitDimData(SmallVector<int64_t> unitDims) {
  DistributeLayoutAttr parentLayout = getParent();
  // Map the unit dims from the sliced space into the parent layout's space.
  SmallVector<int64_t> adjustUnitDims =
      mapSlicedDimsToParentSpace(unitDims, getDims().asArrayRef());
  return SliceAttr::get(getContext(),
                        parentLayout.setUnitDimData(adjustUnitDims), getDims());
}

DistributeLayoutAttr SliceAttr::setUnitDimLayout(SmallVector<int64_t> unitDims) {
  DistributeLayoutAttr parentLayout = getParent();
  SmallVector<int64_t> adjustUnitDims =
      mapSlicedDimsToParentSpace(unitDims, getDims().asArrayRef());
  return SliceAttr::get(
      getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
}
DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
                                           int64_t instData, int64_t laneData) {
  ArrayRef<int64_t> sliceDims = getDims().asArrayRef();
  auto parent = getParent();

  SmallVector<int64_t> dimSet;
  dimSet.push_back(dim);
  SmallVector<int64_t> adjustDims =
      mapSlicedDimsToParentSpace(dimSet, sliceDims);
  return SliceAttr::get(
      getContext(),
      parent.setDimData(adjustDims[0], sgData, instData, laneData), getDims());
}
DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
  // Map the dims to collapse into the parent's space, collapse there, and
  // re-wrap the result in a slice.
  SmallVector<int64_t> sliceDims = llvm::to_vector(getDims().asArrayRef());
  SmallVector<int64_t> dimsInParentSpace =
      mapSlicedDimsToParentSpace(dimGroup, sliceDims);
  auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
  return SliceAttr::get(getContext(), collapsedParent, getDims());
}
LogicalResult
RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                  IntegerAttr startOfRange, IntegerAttr endOfRange) {
  if (startOfRange.getInt() >= endOfRange.getInt())
    return emitError() << "'end' : " << endOfRange.getInt()
                       << " must be greater than 'start' : "
                       << startOfRange.getInt();
  return success();
}
mlir::Type TensorDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;
  if (parser.parseLess())
    return {};
  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }
  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }
  // Optional trailing encoding (block/scatter) and layout attributes.
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    auto res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<LayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    return {};
  }
  if (parser.parseGreater())
    return {};
  auto ctxt = parser.getContext();
  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
      layout.value_or(mlir::Attribute()));
}
void TensorDescType::print(AsmPrinter &printer) const {
  printer << "<";

  auto shape = getShape();
  for (int64_t dim : shape) {
    if (mlir::ShapedType::isDynamic(dim))
      printer << '?';
    else
      printer << dim;
    printer << 'x';
  }
  printer << getElementType();

  // Print the encoding only when it differs from the implicit default.
  auto encoding = getEncoding();
  auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
    printer << ", " << encoding;

  if (auto layout = getLayout())
    printer << ", " << layout;

  printer << ">";
}
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int array_length,
                                   bool boundary_check,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
                                       boundary_check);
  return Base::get(context, shape, elementType, attr, layout);
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int chunk_size,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
  return Base::get(context, shape, elementType, attr, layout);
}
LogicalResult
TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                       llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
                       mlir::Attribute encoding, mlir::Attribute layout) {
  size_t rank = shape.size();
  if (rank == 0)
    return emitError() << "expected non-zero rank tensor";

  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (blockAttr) {
    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
    if (rank > 1 && memorySpaceAttr &&
        memorySpaceAttr.getValue() == MemorySpace::SLM)
      return emitError() << "SLM is only supported for 1D block tensor";
  }

  if (!elementType.isIntOrFloat())
    return emitError() << "unsupported element type " << elementType
                       << ": expected integer or float";

  // Chunked (scattered) accesses must keep the packed unit aligned.
  int chunkAlignmentFactor =
      generalPackedFormatBitSize / elementType.getIntOrFloatBitWidth();
  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
  if (scatterAttr) {
    int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
    if (rank == 1 && chunkSize != 1)
      return emitError() << "expected non-contiguous elements for 1D tensor";

    if (chunkSize > 1) {
      if (shape.back() != chunkSize)
        return emitError() << "expected last dim of tensor to match chunk size";
      if (shape.back() % chunkAlignmentFactor != 0)
        return emitError() << "expected last dim of tensor to be a multiple of "
                           << chunkAlignmentFactor;
    }
  }

  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
  if (layoutAttr) {
    if (rank != (size_t)layoutAttr.getRank())
      return emitError() << "expected layout rank to match tensor rank";

    auto laneData = layoutAttr.getLaneData();
    if (scatterAttr && laneData) {
      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
        return emitError()
               << "expected last dim of lane_data to be a multiple of: "
               << chunkAlignmentFactor;
    }

    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
      std::string shapeStr;
      llvm::raw_string_ostream stream(shapeStr);
      llvm::interleaveComma(shape, stream);
      return emitError() << "cannot distribute [" << shapeStr << "] using "
                         << layoutAttr;
    }
  }
  return success();
}
mlir::Type MemDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<MemLayoutAttr> layout;
  if (parser.parseLess())
    return {};
  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }
  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }
  if (mlir::succeeded(parser.parseOptionalComma())) {
    MemLayoutAttr attr;
    auto res = parser.parseAttribute(attr);
    if (mlir::failed(res))
      return {};
    layout = attr;
  }
  if (parser.parseGreater())
    return {};
  auto ctxt = parser.getContext();
  return MemDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, layout.value_or(MemLayoutAttr()));
}
void MemDescType::print(AsmPrinter &printer) const {
  printer << "<";
  printer.printDimensionList(getShape());
  printer << 'x' << getElementType();
  if (auto layout = getMemLayout())
    printer << ", " << layout;
  printer << ">";
}
Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
  auto context = parser.getContext();
  llvm::SMLoc loc = parser.getCurrentLocation();

  llvm::SmallDenseSet<StringRef> seenKeys;
  SmallVector<NamedAttribute> attributes;

  auto parseElt = [&]() -> ParseResult {
    StringRef nameId;
    if (parser.parseOptionalKeyword(&nameId))
      return parser.emitError(loc, "expected valid attribute name");
    if (!seenKeys.insert(nameId).second)
      return parser.emitError(loc, "duplicate key '")
             << nameId << " in mem layout attribute";
    if (parser.parseEqual())
      return failure();
    Attribute attr;
    if (parser.parseAttribute(attr))
      return failure();
    attributes.emplace_back(nameId, attr);
    return success();
  };

  // <key = value, key = value, ...>
  if (parser.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater,
                                     parseElt))
    return {};

  return parser.getChecked<MemLayoutAttr>(
      loc, context, DictionaryAttr::get(context, attributes));
}

void MemLayoutAttr::print(AsmPrinter &printer) const {
  printer << "<";
  ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
  for (size_t i = 0; i < attrs.size(); i++) {
    printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
    if (i < attrs.size() - 1)
      printer << ", ";
  }
  printer << ">";
}
template <typename ArithOp>
static OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
                             OpBuilder &builder) {
  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
  return ArithOp::create(builder, loc, aVal, bVal).getResult();
}

#define div(a, b)                                                              \
  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
#define rem(a, b)                                                              \
  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
#define mul(a, b)                                                              \
  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
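
// Example (sketch): with these macros, div(offset, block) expands to
// genBinOp<arith::DivSIOp>(offset, builder.getIndexAttr(block), loc, builder),
// which materializes both operands as index Values and emits arith.divsi.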
static SmallVector<OpFoldResult>
getBlockedOffsets(OpBuilder &builder, Location loc,
                  ArrayRef<OpFoldResult> offsets,
                  ArrayRef<int64_t> blockShape) {
  assert(offsets.size() == blockShape.size() &&
         "offsets and blockShape must have the same size");

  // Split each offset into a block index (offset / block) and an intra-block
  // offset (offset % block); the blocked offsets are [divs..., rems...].
  SmallVector<OpFoldResult> divs, rems;
  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
    divs.push_back(div(offset, block));
    rems.push_back(rem(offset, block));
  }

  SmallVector<OpFoldResult> blockedOffsets;
  blockedOffsets.append(divs.begin(), divs.end());
  blockedOffsets.append(rems.begin(), rems.end());
  return blockedOffsets;
}
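
// Worked example (sketch): offsets (10, 7) with blockShape (8, 4) become
// (10/8, 7/4, 10%8, 7%4) = (1, 1, 2, 3): block (1, 1), element (2, 3).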
SmallVector<int64_t> MemDescType::getStrideShape() {
  SmallVector<int64_t> matrixShape(getShape());

  ArrayAttr strideAttr = getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }

  SmallVector<int64_t> innerBlkShape = getBlockShape();

  // Permutation ordering dims by increasing stride; perm[0] is the
  // contiguous (stride-1) dimension.
  SmallVector<int> perm =
      llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
  llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });

  assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");

  SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
  innerBlkStride[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    innerBlkStride[perm[i]] =
        innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];

  // Reconstruct the original matrix shape from the strides, and derive how
  // many blocks tile each dimension.
  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
    BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
  }

  int64_t innerBlkSize = 1;
  for (auto s : innerBlkShape)
    innerBlkSize *= s;

  SmallVector<int64_t> outerBlkStride(matrixShape.size());
  outerBlkStride[perm[0]] = innerBlkSize;
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    outerBlkStride[perm[i + 1]] =
        outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
  }

  // Blocked strides = [outer block strides..., inner block strides...].
  SmallVector<int64_t> blockedStrides;
  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
  return blockedStrides;
}
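
// Worked example (sketch): a row-major 32x32 matrix with 8x16 blocks has
// strides [32, 1], inner block strides [16, 1], 4x2 blocks of 128 elements
// each, and outer block strides [256, 128]; getStrideShape() then returns
// [256, 128, 16, 1].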
Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
                                    ArrayRef<OpFoldResult> offsets) {
  SmallVector<int64_t> matrixShape(getShape());
  SmallVector<int64_t> blockShape = getBlockShape();
  SmallVector<int64_t> strides = getStrideShape();
  SmallVector<OpFoldResult> blockedOffsets;

  // If the matrix is not blocked (block == full shape), drop the outer-block
  // strides and use the offsets as-is; otherwise split the offsets into
  // block and intra-block coordinates first.
  if (llvm::equal(blockShape, matrixShape)) {
    strides.erase(strides.begin(), strides.begin() + matrixShape.size());
  } else {
    assert(offsets.size() == blockShape.size() &&
           "offsets and blockShape must have the same size");

    SmallVector<OpFoldResult> divs, rems;
    for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
      divs.push_back(div(offset, block));
      rems.push_back(rem(offset, block));
    }
    blockedOffsets.append(divs.begin(), divs.end());
    blockedOffsets.append(rems.begin(), rems.end());
    offsets = blockedOffsets;
  }

  // linearOffset = sum_i offsets[i] * strides[i]
  Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
  for (size_t i = 0; i < offsets.size(); ++i) {
    OpFoldResult mulResult = mul(offsets[i], strides[i]);
    Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
    linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
  }

  return linearOffset;
}
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>