18#include "llvm/ADT/SmallVectorExtras.h"
19#include "llvm/ADT/TypeSwitch.h"
20#include "llvm/Support/Debug.h"
27void XeGPUDialect::initialize() {
29#define GET_TYPEDEF_LIST
30#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
34#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
37#define GET_ATTRDEF_LIST
38#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
41#define GET_OP_INTERFACE_CLASSES
42#include "mlir/Dialect/XeGPU/IR/XeGPUOpInterface.cpp.inc"
61 llvm::zip_equal(srcShape,
63 [](
const auto &t) {
return std::min(std::get<0>(t), std::get<1>(t)); });
67 llvm::zip(delinearizedId, subShape), [&](
const auto &t) ->
Value {
83 llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
84 [&](
const auto &t) ->
Value {
86 loc, std::get<0>(t), std::get<1>(t));
90 llvm::zip_equal(adds, srcShape), [&](
const auto &t) ->
Value {
96 coordinates.push_back(mods);
106 for (
size_t i = 0; i <
shape.size(); ++i)
107 distUnitShape[i] = std::min(
shape[i], layout[i] * subShape[i]);
111 for (
size_t i = 0; i <
shape.size(); ++i)
112 localOffset[i] = canonicalIds[i] * subShape[i];
119 for (
size_t i = 0; i <
shape.size(); ++i)
120 coord[i] = (unitOffs[i] + localOffset[i]) %
shape[i];
121 coordinates.push_back(coord);
139 for (
size_t i = 0; i < start.size(); ++i)
140 coord[i] = start[i] + off[i];
141 expanded.push_back(std::move(coord));
159 const xegpu::DistributeLayoutAttr &other,
164 self.getEffectiveLaneDataAsInt() != other.getEffectiveLaneDataAsInt();
167 selfSubShape = self.getEffectiveLaneDataAsInt();
168 otherSubShape = other.getEffectiveLaneDataAsInt();
170 for (
int64_t id : llvm::seq<int64_t>(0, size)) {
171 auto coords = self.computeStaticDistributedCoords(
id,
shape);
172 auto otherCoords = other.computeStaticDistributedCoords(
id,
shape);
177 if (coords != otherCoords)
/// Reports whether the memory space of `memrefTy` designates shared memory.
/// The memory-space attribute is tested against each recognized encoding in
/// turn; the first encoding whose attribute kind matches decides the result.
184bool XeGPUDialect::isSharedMemory(
const MemRefType &memrefTy) {
185 Attribute attr = memrefTy.getMemorySpace();
// Integer memory space: only the value 3 is treated as shared.
188 if (
auto intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr))
189 return intAttr.getInt() == 3;
// MemorySpaceAttr: shared iff its value is MemorySpace::SLM.
190 if (
auto memrefSpace = llvm::dyn_cast_if_present<MemorySpaceAttr>(attr))
191 return memrefSpace.getValue() == MemorySpace::SLM;
// xevm::AddrSpaceAttr: shared iff its value is AddrSpace::SHARED.
192 if (
auto xevmSpace = llvm::dyn_cast_if_present<xevm::AddrSpaceAttr>(attr))
193 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
// Anything else (including an absent attribute, which the *_if_present casts
// let fall through) is decided by the GPU dialect's workgroup-space check.
194 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
201 xegpu::MemorySpace memory_space,
203 bool boundary_check) {
204 auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
206 IntegerAttr::get(IntegerType::get(context, 64), array_length);
208 return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
/// Returns true only when all three fields hold the values tested below:
/// memory space is Global, array length is 1, and the boundary-check flag is
/// set. NOTE(review): presumably these are the attribute's default values, as
/// the name suggests — confirm against the attribute definition.
211bool BlockTensorDescAttr::hasDefaultsOnly() {
212 return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
213 getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
/// Verifier for LayoutAttr. The visible checks enforce that the optional
/// fields are used in consistent combinations and agree in rank; each failed
/// check reports through `emitError`.
220LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()>
emitError,
// Case where none of sg_layout, inst_data and lane_layout is provided.
// NOTE(review): the action taken for this case is not visible in this view.
226 if (!sg_layout && !inst_data && !lane_layout)
// Pairwise rank agreement among sg_layout, inst_data and lane_layout; a pair
// is only compared when both of its members are present.
232 if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
234 <<
"expected sg_layout and inst_data to have the same rank";
237 if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
239 <<
"expected sg_layout and lane_layout to have the same rank";
242 if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
243 return emitError() <<
"expected inst_data and lane_layout to have the same "
244 "rank, got inst_data "
245 << inst_data.size() <<
", lane_layout "
246 << lane_layout.size();
// sg_layout and sg_data must be both present or both absent, and when present
// they must have the same rank.
249 if ((sg_layout && !sg_data) || (!sg_layout && sg_data))
250 return emitError() <<
"sg_layout and sg_data must be used together";
251 if (sg_layout && sg_data && sg_layout.size() != sg_data.size())
253 <<
"expected sg_data and sg_layout to have the same rank";
// The same pairing rule applies to lane_layout and lane_data.
255 if ((lane_layout && !lane_data) || (!lane_layout && lane_data))
256 return emitError() <<
"lane_layout and lane_data must be used together";
257 if (lane_layout && lane_data && lane_layout.size() != lane_data.size())
259 <<
"expected lane_data and lane_layout to have the same rank";
// Checks involving `order`: it needs sg_layout or lane_layout alongside it and
// must match the rank of whichever of them is present.
// NOTE(review): presumably these run only when `order` is provided; the
// enclosing condition is not visible in this view.
262 if (!sg_layout && !lane_layout)
264 <<
"expected sg_layout/lane_layout being used with order";
266 if (sg_layout && order.size() != sg_layout.size())
268 <<
"expected order and sg_layout to have the same rank";
270 if (lane_layout && order.size() != lane_layout.size())
272 <<
"expected order and lane_layout to have the same rank";
/// Splits `linearId` into per-dimension index Values by emitting arith ops
/// through `builder` at `loc`.
/// The layout that gets delinearized is the effective subgroup layout when
/// isForWorkgroup() holds and the effective lane layout when isForSubgroup()
/// holds (the variable keeps the name sgLayoutInt in both cases).
278FailureOr<SmallVector<Value>>
279LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
281 SmallVector<int64_t> sgLayoutInt;
282 if (isForWorkgroup()) {
283 sgLayoutInt = getEffectiveSgLayoutAsInt();
284 }
else if (isForSubgroup()) {
285 sgLayoutInt = getEffectiveLaneLayoutAsInt();
// Dimension visiting order: when the order attribute is present and
// non-empty, its int32 entries are widened to int64.
293 SmallVector<int64_t> order;
294 if (orderAttr && !orderAttr.empty()) {
295 order = llvm::map_to_vector(orderAttr.
asArrayRef(), [](int32_t idx) {
296 return static_cast<int64_t>(idx);
// NOTE(review): presumably the fallback when no order is given — the
// dimension indices in reverse (rank-1 down to 0); the enclosing else is not
// visible in this view.
300 order = llvm::to_vector(
301 llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
// Guard against an order whose length differs from the layout rank.
// NOTE(review): the handling of the mismatch is not visible in this view.
304 if (order.size() != sgLayoutInt.size()) {
308 SmallVector<Value>
result(sgLayoutInt.size());
309 Value remaining = linearId;
// Walk the dimensions in `order`. Each iteration builds an index constant for
// the dimension size and a RemUIOp of `remaining` by it; every iteration but
// the last also builds a DivUIOp of `remaining` by it.
// NOTE(review): presumably the remainder is stored in result[dimIdx] and the
// quotient becomes the new `remaining`; those assignments are on lines not
// visible here.
332 for (
size_t i = 0; i < order.size(); ++i) {
333 int64_t dimIdx = order[i];
334 int64_t dimSize = sgLayoutInt[dimIdx];
337 builder.
createOrFold<arith::ConstantIndexOp>(loc, dimSize);
344 builder.
createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
351 if (i < order.size() - 1) {
353 builder.
createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
/// Builds the distributed coordinates for `linearId` over `shape`.
/// Picks the (layout, subShape) pair for the level this attribute describes,
/// delinearizes `linearId` with delinearizeId(), and passes the resulting ids
/// to genCoordinates().
362FailureOr<SmallVector<SmallVector<Value>>>
363LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
364 Value linearId, ArrayRef<int64_t> shape) {
365 SmallVector<int64_t> layout;
366 SmallVector<int64_t> subShape;
// Workgroup level uses sg_layout/sg_data; subgroup level uses
// lane_layout/lane_data.
367 if (isForWorkgroup()) {
368 layout = getEffectiveSgLayoutAsInt();
369 subShape = getEffectiveSgDataAsInt();
370 }
else if (isForSubgroup()) {
371 layout = getEffectiveLaneLayoutAsInt();
372 subShape = getEffectiveLaneDataAsInt();
// The per-unit data shape is required from here on.
376 assert(!subShape.empty() &&
"sgdata or lanedata cannot be empty for "
377 "distributed coordinates computation");
// NOTE(review): maybeIds is a FailureOr; the failure check that presumably
// precedes the dereference below is not visible in this view.
380 auto maybeIds = delinearizeId(builder, loc, linearId);
383 SmallVector<Value> ids = *maybeIds;
385 return genCoordinates(builder, loc, ids, layout, subShape, shape);
388bool LayoutAttr::isEqualTo(
const xegpu::DistributeLayoutAttr &other) {
389 if (dyn_cast<xegpu::SliceAttr>(other))
392 return *
this == dyn_cast<xegpu::LayoutAttr>(other);
/// Compile-time (plain integer) counterpart of the coordinate computation:
/// works on an int64_t `linearId` rather than on IR Values.
/// NOTE(review): the final coordinate generation and return are on lines not
/// visible in this view; only the setup and delinearization are documented.
398SmallVector<SmallVector<int64_t>>
399LayoutAttr::computeStaticDistributedCoords(int64_t linearId,
400 ArrayRef<int64_t> shape) {
401 SmallVector<int64_t> layoutVec;
402 SmallVector<int64_t> subShape;
403 SmallVector<int64_t> instData;
// Workgroup level reads sg_layout/sg_data; subgroup level reads
// lane_layout/lane_data and additionally inst_data.
404 if (isForWorkgroup()) {
405 layoutVec = getEffectiveSgLayoutAsInt();
406 subShape = getEffectiveSgDataAsInt();
407 }
else if (isForSubgroup()) {
408 instData = getEffectiveInstDataAsInt();
409 layoutVec = getEffectiveLaneLayoutAsInt();
410 subShape = getEffectiveLaneDataAsInt();
// NOTE(review): the body of this branch is not visible here; it is only
// entered when inst_data is available.
412 if (!instData.empty()) {
416 assert(!subShape.empty() &&
"sgdata or lanedata cannot be empty");
// Delinearize linearId: visiting dimensions in the effective order, each
// dimension takes `remaining` modulo its layout extent, and `remaining` is
// then divided by that extent.
419 SmallVector<int64_t> order = getEffectiveOrderAsInt();
420 SmallVector<int64_t> delinearizedId(layoutVec.size());
421 int64_t remaining = linearId;
422 for (
size_t i = 0; i < order.size(); ++i) {
423 int64_t dimIdx = order[i];
424 delinearizedId[dimIdx] = remaining % layoutVec[dimIdx];
425 remaining = remaining / layoutVec[dimIdx];
433LayoutAttr::setUnitDimData(SmallVector<int64_t> unitDims)
const {
434 auto sgDataOpt = getSgData();
435 auto instDataOpt = getInstData();
436 auto laneDataOpt = getLaneData();
438 SmallVector<int32_t> sgData;
439 SmallVector<int32_t> instData;
440 SmallVector<int32_t> laneData;
443 sgData = llvm::to_vector(sgDataOpt.asArrayRef());
446 instData = llvm::to_vector(instDataOpt.asArrayRef());
449 laneData = llvm::to_vector(laneDataOpt.asArrayRef());
451 for (
auto dim : unitDims) {
452 if (dim <
static_cast<int64_t
>(sgData.size()))
454 if (dim <
static_cast<int64_t
>(instData.size()))
456 if (dim <
static_cast<int64_t
>(laneData.size()))
460 return LayoutAttr::get(
474LayoutAttr::setUnitDimLayout(SmallVector<int64_t> unitDims)
const {
475 auto sgLayoutOpt = getSgLayout();
476 auto laneLayoutOpt = getLaneLayout();
478 SmallVector<int32_t> sgLayout;
479 SmallVector<int32_t> laneLayout;
482 sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
484 laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());
486 for (
auto dim : unitDims) {
487 if (dim <
static_cast<int64_t
>(sgLayout.size()))
489 if (dim <
static_cast<int64_t
>(laneLayout.size()))
493 return LayoutAttr::get(
497 getSgData(), getInstData(),
500 getLaneData(), getOrder());
505DistributeLayoutAttr LayoutAttr::setDimData(int64_t dim, int64_t sgData,
509 SmallVector<int64_t> sgDataVec = getEffectiveSgDataAsInt();
510 SmallVector<int64_t> instDataVec = getEffectiveInstDataAsInt();
511 SmallVector<int64_t> laneDataVec = getEffectiveLaneDataAsInt();
513 if (dim <
static_cast<int64_t
>(sgDataVec.size()) && sgData != -1)
514 sgDataVec[dim] = sgData;
515 if (dim <
static_cast<int64_t
>(instDataVec.size()) && instData != -1)
516 instDataVec[dim] = instData;
517 if (dim <
static_cast<int64_t
>(laneDataVec.size()) && laneData != -1)
518 laneDataVec[dim] = laneData;
520 SmallVector<int32_t> sgDataVec32(sgDataVec.begin(), sgDataVec.end());
521 SmallVector<int32_t> instDataVec32(instDataVec.begin(), instDataVec.end());
522 SmallVector<int32_t> laneDataVec32(laneDataVec.begin(), laneDataVec.end());
524 return LayoutAttr::get(
539DistributeLayoutAttr LayoutAttr::dropDims(SmallVector<int64_t> dimGroup) {
541 SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
542 SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
543 SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
544 SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
545 SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
548 SmallVector<int64_t> sortedDimGroup = dimGroup;
549 llvm::sort(sortedDimGroup);
551 for (
auto dimIdx : llvm::reverse(sortedDimGroup)) {
552 if (!sgLayout.empty()) {
553 sgLayout.erase(sgLayout.begin() + dimIdx);
554 sgData.erase(sgData.begin() + dimIdx);
556 if (!instData.empty())
557 instData.erase(instData.begin() + dimIdx);
558 if (!laneLayout.empty()) {
559 laneLayout.erase(laneLayout.begin() + dimIdx);
560 laneData.erase(laneData.begin() + dimIdx);
567 SmallVector<int64_t> newOrder;
568 if (origOrderAttr && !origOrderAttr.empty()) {
569 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
570 for (int64_t d : origOrder) {
571 if (llvm::is_contained(dimGroup, d))
574 llvm::count_if(dimGroup, [&](int64_t s) {
return s < d; });
575 newOrder.push_back(d - offset);
577 if ((sgLayout.empty() && laneLayout.empty()) || newOrder.size() == 1)
584 SmallVector<int32_t> v32(v.begin(), v.end());
587 auto droppedLayout = xegpu::LayoutAttr::get(
588 getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
589 toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
590 return droppedLayout;
596DistributeLayoutAttr LayoutAttr::collapseDims(SmallVector<int64_t> dimGroup) {
598 SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
599 SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
600 SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
601 SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
602 SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
603 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
605 SmallVector<int64_t> sortedDimGroup = dimGroup;
606 llvm::sort(sortedDimGroup);
607 int64_t dimBeforeCurrent = -1;
608 for (
auto dimIdx : sortedDimGroup) {
612 if (dimBeforeCurrent >= 0) {
613 if (getOrder() && !getOrder().empty()) {
614 int64_t orderBefore = origOrder[dimBeforeCurrent];
615 int64_t orderCurrent = origOrder[dimIdx];
616 if (orderBefore != (orderCurrent - 1))
617 llvm::report_fatal_error(
618 "dimensions being collapsed must be adjacent in order");
620 if (dimIdx != (dimBeforeCurrent + 1))
621 llvm::report_fatal_error(
622 "dimensions being collapsed must be adjacent");
625 dimBeforeCurrent = dimIdx;
628 int firstDim = sortedDimGroup.front();
633 if (!sgLayout.empty()) {
634 int64_t collapsedSglayout = 1, collapsedSgData = 1;
635 for (
auto dimIdx : dimGroup) {
636 collapsedSglayout *= sgLayout[dimIdx];
637 collapsedSgData *= sgData[dimIdx];
639 for (
auto dimIdx : llvm::reverse(sortedDimGroup)) {
640 sgLayout.erase(sgLayout.begin() + dimIdx, sgLayout.begin() + dimIdx + 1);
641 sgData.erase(sgData.begin() + dimIdx, sgData.begin() + dimIdx + 1);
643 sgLayout.insert(sgLayout.begin() + firstDim, collapsedSglayout);
644 sgData.insert(sgData.begin() + firstDim, collapsedSgData);
647 if (!instData.empty()) {
648 int64_t collapsedInstData = 1;
649 for (
auto dimIdx : dimGroup)
650 collapsedInstData *= instData[dimIdx];
651 for (
auto dimIdx : llvm::reverse(sortedDimGroup))
652 instData.erase(instData.begin() + dimIdx, instData.begin() + dimIdx + 1);
653 instData.insert(instData.begin() + firstDim, collapsedInstData);
656 if (!laneLayout.empty()) {
657 int64_t collapsedLaneLayout = 1, collapsedLaneData = 1;
658 for (
auto dimIdx : dimGroup) {
659 collapsedLaneLayout *= laneLayout[dimIdx];
660 collapsedLaneData *= laneData[dimIdx];
662 for (
auto dimIdx : llvm::reverse(sortedDimGroup)) {
663 laneLayout.erase(laneLayout.begin() + dimIdx,
664 laneLayout.begin() + dimIdx + 1);
665 laneData.erase(laneData.begin() + dimIdx, laneData.begin() + dimIdx + 1);
667 laneLayout.insert(laneLayout.begin() + firstDim, collapsedLaneLayout);
668 laneData.insert(laneData.begin() + firstDim, collapsedLaneData);
671 SmallVector<int64_t> newOrder;
673 if (orderAttr && !orderAttr.empty()) {
675 for (
auto dimIdx : llvm::reverse(sortedDimGroup)) {
676 if (dimIdx != firstDim)
677 origOrder.erase(origOrder.begin() + dimIdx);
682 llvm::to_vector(llvm::seq<size_t>(0, origOrder.size()));
686 [&](
size_t a,
size_t b) {
return origOrder[a] < origOrder[
b]; });
688 newOrder = llvm::to_vector(llvm::map_range(
689 indices, [&](
size_t i) {
return static_cast<int64_t
>(i); }));
695 SmallVector<int32_t> v32(v.begin(), v.end());
698 auto collapsedLayout = xegpu::LayoutAttr::get(
699 getContext(), toAttr(sgLayout), toAttr(sgData), toAttr(instData),
700 toAttr(laneLayout), toAttr(laneData), toAttr(newOrder));
701 return collapsedLayout;
741DistributeLayoutAttr LayoutAttr::expandDim(int64_t dim,
742 ArrayRef<int64_t> targetShape) {
743 SmallVector<int64_t> sgLayout = getEffectiveSgLayoutAsInt();
744 SmallVector<int64_t> sgData = getEffectiveSgDataAsInt();
745 SmallVector<int64_t> instData = getEffectiveInstDataAsInt();
746 SmallVector<int64_t> laneLayout = getEffectiveLaneLayoutAsInt();
747 SmallVector<int64_t> laneData = getEffectiveLaneDataAsInt();
749 int64_t origRank = getRank();
750 int64_t expCount =
static_cast<int64_t
>(targetShape.size());
751 assert(dim >= 0 && dim < origRank &&
"dim out of range");
752 assert(expCount >= 1 &&
"targetShape must have at least one dim");
753 int64_t newRank = origRank + expCount - 1;
758 int64_t origSgLayoutDim = sgLayout.empty() ? 1 : sgLayout[dim];
759 int64_t origSgDataDim = sgData.empty() ? 1 : sgData[dim];
760 int64_t origLaneLayoutDim = laneLayout.empty() ? 1 : laneLayout[dim];
761 int64_t origLaneDataDim = laneData.empty() ? 1 : laneData[dim];
762 int64_t origInstDataDim = instData.empty() ? 1 : instData[dim];
767 auto spread = [&](int64_t total, ArrayRef<int64_t> dimSizeCap,
768 bool outerToInner) -> SmallVector<int64_t> {
769 SmallVector<int64_t> out(expCount, 1);
770 int64_t remaining = total;
771 auto step = [&](int64_t i) {
774 int64_t take = std::min(remaining, dimSizeCap[i]);
775 assert(take > 0 &&
"expandDim distribution must not be zero");
776 assert(remaining % take == 0 &&
777 "expandDims must divide evenly across dims");
782 for (int64_t i = 0; i < expCount; ++i)
785 for (int64_t i = expCount - 1; i >= 0; --i)
787 assert(remaining == 1 &&
"expandDims total must fit within target shape");
793 auto splice = [&](SmallVector<int64_t> &vec, ArrayRef<int64_t> expanded) {
796 vec.erase(vec.begin() + dim);
797 vec.insert(vec.begin() + dim, expanded.begin(), expanded.end());
800 bool hasSgLayout = !sgLayout.empty();
801 bool hasSgData = !sgData.empty();
802 bool hasLaneLayout = !laneLayout.empty();
803 bool hasLaneData = !laneData.empty();
804 bool hasInstData = !instData.empty();
807 SmallVector<int64_t> expSgLayout(expCount, 1);
809 expSgLayout = spread(origSgLayoutDim, targetShape,
true);
810 splice(sgLayout, expSgLayout);
812 bool sgDataReplicated =
815 SmallVector<int64_t> dimSizeCap(targetShape.begin(), targetShape.end());
816 if (hasSgLayout && !sgDataReplicated)
817 for (int64_t i = 0; i < expCount; ++i)
818 dimSizeCap[i] /= expSgLayout[i];
819 SmallVector<int64_t> expSgData =
820 spread(origSgDataDim, dimSizeCap,
false);
821 splice(sgData, expSgData);
827 SmallVector<int64_t> perSgShape(targetShape.begin(), targetShape.end());
828 if (hasSgLayout && !sgDataReplicated)
829 for (int64_t i = 0; i < expCount; ++i)
830 perSgShape[i] /= expSgLayout[i];
833 SmallVector<int64_t> expLaneLayout(expCount, 1);
834 SmallVector<int64_t> expLaneData(expCount, 1);
836 expLaneLayout = spread(origLaneLayoutDim, perSgShape,
838 splice(laneLayout, expLaneLayout);
841 SmallVector<int64_t> dimSizeCap(perSgShape.begin(), perSgShape.end());
843 for (int64_t i = 0; i < expCount; ++i)
844 dimSizeCap[i] /= expLaneLayout[i];
845 expLaneData = spread(origLaneDataDim, dimSizeCap,
false);
846 splice(laneData, expLaneData);
857 SmallVector<int64_t> expInstData;
858 if (!hasLaneLayout || !hasLaneData) {
859 expInstData = spread(origInstDataDim, perSgShape,
false);
861 int64_t laneAtom = origLaneLayoutDim * origLaneDataDim;
862 SmallVector<int64_t> atom(expCount, 1);
863 SmallVector<int64_t> dimSizeCap(expCount, 1);
864 for (int64_t i = 0; i < expCount; ++i) {
865 atom[i] = expLaneLayout[i] * expLaneData[i];
866 dimSizeCap[i] = perSgShape[i] / atom[i];
868 expInstData = spread(origInstDataDim / laneAtom, dimSizeCap,
870 for (int64_t i = 0; i < expCount; ++i)
871 expInstData[i] *= atom[i];
873 splice(instData, expInstData);
879 SmallVector<int64_t> newOrder;
881 if (orderAttr && !orderAttr.empty()) {
882 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
883 newOrder.reserve(newRank);
884 for (int64_t o : origOrder) {
887 for (int64_t i = expCount - 1; i >= 0; --i)
888 newOrder.push_back(dim + i);
889 }
else if (o > dim) {
890 newOrder.push_back(o + expCount - 1);
892 newOrder.push_back(o);
900 SmallVector<int32_t> v32(v.begin(), v.end());
903 return xegpu::LayoutAttr::get(
getContext(), toAttr(sgLayout), toAttr(sgData),
904 toAttr(instData), toAttr(laneLayout),
905 toAttr(laneData), toAttr(newOrder));
/// Returns a LayoutAttr whose per-dimension fields are rearranged according to
/// `permutation`: entry i of each new field is the original entry at
/// permutation[i].
909DistributeLayoutAttr LayoutAttr::transposeDims(ArrayRef<int64_t> permutation) {
// Snapshot the current (effective) fields as int64 vectors.
911 SmallVector<int64_t> origSgLayout = getEffectiveSgLayoutAsInt();
912 SmallVector<int64_t> origSgData = getEffectiveSgDataAsInt();
913 SmallVector<int64_t> origInstData = getEffectiveInstDataAsInt();
914 SmallVector<int64_t> origLaneLayout = getEffectiveLaneLayoutAsInt();
915 SmallVector<int64_t> origLaneData = getEffectiveLaneDataAsInt();
916 SmallVector<int64_t> origOrder = getEffectiveOrderAsInt();
// Permuted fields are accumulated as int32, the width passed on to toAttr().
918 SmallVector<int32_t> sgLayout;
919 SmallVector<int32_t> sgData;
920 SmallVector<int32_t> instData;
921 SmallVector<int32_t> laneLayout;
922 SmallVector<int32_t> laneData;
923 SmallVector<int32_t> order;
// A field that is empty in the original stays empty; lane data is indexed
// whenever lane layout is non-empty, and sg data whenever sg layout is.
925 for (int64_t idx : permutation) {
926 if (!origLaneLayout.empty()) {
927 laneLayout.push_back(
static_cast<int32_t
>(origLaneLayout[idx]));
928 laneData.push_back(
static_cast<int32_t
>(origLaneData[idx]));
930 if (!origInstData.empty())
931 instData.push_back(
static_cast<int32_t
>(origInstData[idx]));
932 if (!origSgLayout.empty()) {
933 sgLayout.push_back(
static_cast<int32_t
>(origSgLayout[idx]));
934 sgData.push_back(
static_cast<int32_t
>(origSgData[idx]));
// The order entries are permuted the same way as the other fields.
936 order.push_back(
static_cast<int32_t
>(origOrder[idx]));
// NOTE(review): special handling when neither lane nor sg layout exists; the
// action taken is not visible in this view.
938 if (origLaneLayout.empty() && origSgLayout.empty())
944 return xegpu::LayoutAttr::get(
getContext(), toAttr(sgLayout), toAttr(sgData),
945 toAttr(instData), toAttr(laneLayout),
946 toAttr(laneData), toAttr(order));
950bool LayoutAttr::isTransposeOf(
const xegpu::DistributeLayoutAttr &other,
951 ArrayRef<int64_t> perm,
955 if (getRank() != other.getRank() ||
956 perm.size() !=
static_cast<size_t>(getRank()))
963 auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src,
964 ArrayRef<int64_t> perm) {
965 for (
const auto &ta : llvm::enumerate(perm)) {
966 if (dst[ta.index()] != src[ta.value()])
972 return checkTranspose(getEffectiveSgLayoutAsInt(),
973 other.getEffectiveSgLayoutAsInt(), perm) &&
974 checkTranspose(getEffectiveSgDataAsInt(),
975 other.getEffectiveSgDataAsInt(), perm) &&
976 checkTranspose(getEffectiveOrderAsInt(),
977 other.getEffectiveOrderAsInt(), perm);
979 return checkTranspose(getEffectiveInstDataAsInt(),
980 other.getEffectiveInstDataAsInt(), perm);
982 return checkTranspose(getEffectiveLaneLayoutAsInt(),
983 other.getEffectiveLaneLayoutAsInt(), perm) &&
984 checkTranspose(getEffectiveLaneDataAsInt(),
985 other.getEffectiveLaneDataAsInt(), perm) &&
986 checkTranspose(getEffectiveOrderAsInt(),
987 other.getEffectiveOrderAsInt(), perm);
992bool LayoutAttr::isCompatibleWith(
const xegpu::DistributeLayoutAttr &other,
993 SmallVector<int64_t> shape,
997 if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
1000 if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
1001 getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
1004 if (getEffectiveLaneLayoutAsInt() ==
1005 other.getEffectiveLaneLayoutAsInt() &&
1006 getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
1010 auto compareCoordsForAllIds = [&](int64_t size) {
1016 return compareCoordsForAllIds(wgSize);
1019 return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
1022 int64_t subgroupSize =
computeProduct(getEffectiveLaneLayoutAsInt());
1023 return compareCoordsForAllIds(subgroupSize);
1032SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()>
emitError,
1036 return emitError() <<
"expected dims attribute";
1039 llvm::SmallDenseSet<int64_t> seen;
1042 return emitError() <<
"invalid dim (" << dim <<
") in slice attribute.";
1043 if (!seen.insert(dim).second)
1044 return emitError() <<
"repeated dim (" << dim <<
") in slice attribute.";
/// Collapses a chain of nested SliceAttrs into a single SliceAttr.
/// Walks the parent chain while the parent is itself a SliceAttr, collecting
/// each level's dims, then recomputes one dims list against the innermost
/// non-slice parent.
1049SliceAttr SliceAttr::flatten()
const {
1050 xegpu::DistributeLayoutAttr parent = getParent();
// slicedDims starts with this attribute's own dims; outer levels are appended
// as the chain is walked.
1051 SmallVector<DenseI64ArrayAttr> slicedDims({
getDims()});
1053 while (
auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
1054 parent = sliceAttr.getParent();
1055 slicedDims.push_back(sliceAttr.getDims());
// NOTE(review): the dyn_cast result is used without a null check, so the
// chain is assumed to end in a LayoutAttr — confirm.
1058 auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
// indices = [0, rank) of the root layout.
1059 SmallVector<int64_t>
indices =
1060 llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
// Apply the collected dims lists in reverse collection order (outermost level
// first) to narrow down remainingDims.
// NOTE(review): the second argument of this slice() call is on a line not
// visible here; presumably it is the current level's dims.
1063 SmallVector<int64_t> remainingDims(
indices);
1064 for (
auto dim : llvm::reverse(slicedDims))
1065 remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
// NOTE(review): presumably slice(a, dims) drops the entries of `a` selected by
// `dims`, so this yields the root indices that did not survive — confirm
// against XeGPUDialect::slice.
1069 SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
1070 llvm::ArrayRef<int64_t>(
indices), llvm::ArrayRef<int64_t>(remainingDims));
1072 return xegpu::SliceAttr::get(
/// Delinearizes `linearId` by flattening this slice and forwarding the request
/// to the LayoutAttr that the flattened slice has as its parent.
1077FailureOr<SmallVector<Value>>
1078SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
1079 SliceAttr attr = flatten();
// NOTE(review): the dyn_cast result is used without a null check; this relies
// on the flattened slice always having a LayoutAttr parent — confirm.
1080 auto parent = dyn_cast<LayoutAttr>(attr.getParent());
1081 return parent.delinearizeId(builder, loc, linearId);
1087FailureOr<SmallVector<SmallVector<Value>>>
1088SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
1089 Value linearId, ArrayRef<int64_t> shape) {
1090 assert(getRank() ==
static_cast<int64_t
>(shape.size()) &&
"invalid shape.");
1092 SmallVector<int64_t> layout;
1093 SmallVector<int64_t> subShape;
1094 if (isForWorkgroup()) {
1095 layout = getEffectiveSgLayoutAsInt();
1096 subShape = getEffectiveSgDataAsInt();
1097 }
else if (isForSubgroup()) {
1098 layout = getEffectiveLaneLayoutAsInt();
1099 subShape = getEffectiveLaneDataAsInt();
1104 if (subShape.empty())
1108 auto maybeIds = delinearizeId(builder, loc, linearId);
1114 ArrayRef<int64_t> dims = flatten().getDims().
asArrayRef();
1115 SmallVector<Value> canonicalIds =
1116 XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
1118 return genCoordinates(builder, loc, canonicalIds, layout, subShape, shape);
/// Plain-integer coordinate computation for a sliced layout.
/// `linearId` is delinearized over the layout of the flattened slice's parent,
/// and the resulting ids are then passed through XeGPUDialect::slice with the
/// flattened dims.
/// NOTE(review): the final coordinate generation and return are on lines not
/// visible in this view.
1125SmallVector<SmallVector<int64_t>>
1126SliceAttr::computeStaticDistributedCoords(int64_t linearId,
1127 ArrayRef<int64_t> shape) {
// `shape` must have exactly the rank of this (sliced) layout.
1128 assert(getRank() ==
static_cast<int64_t
>(shape.size()) &&
"invalid shape.");
1130 SmallVector<int64_t> layout;
1131 SmallVector<int64_t> subShape;
1132 SmallVector<int64_t> instData;
// Workgroup level reads sg_layout/sg_data; subgroup level reads
// lane_layout/lane_data and additionally inst_data.
1133 if (isForWorkgroup()) {
1134 layout = getEffectiveSgLayoutAsInt();
1135 subShape = getEffectiveSgDataAsInt();
1136 }
else if (isForSubgroup()) {
1137 instData = getEffectiveInstDataAsInt();
1138 layout = getEffectiveLaneLayoutAsInt();
1139 subShape = getEffectiveLaneDataAsInt();
// When inst_data is available it replaces the sub-shape chosen above.
1141 if (!instData.empty()) {
1143 subShape = instData;
1146 assert(!subShape.empty() &&
"sgdata or lanedata cannot be empty");
// The id is delinearized in the parent's full-rank space, so the parent's
// layout (sg or lane, matching the parent's level) and order are used.
// NOTE(review): `parent` is used without a null check — confirm flatten()
// always yields a LayoutAttr parent.
1149 SliceAttr flattened = flatten();
1150 auto parent = dyn_cast<LayoutAttr>(flattened.getParent());
1151 SmallVector<int64_t> parentLayoutVec;
1152 if (parent.isForWorkgroup())
1153 parentLayoutVec = parent.getEffectiveSgLayoutAsInt();
1155 parentLayoutVec = parent.getEffectiveLaneLayoutAsInt();
// Mod/div walk in parent order; the division is skipped on the last
// dimension because no further digit is extracted after it.
1157 SmallVector<int64_t> order = parent.getEffectiveOrderAsInt();
1158 SmallVector<int64_t> allIds(parentLayoutVec.size());
1159 int64_t remaining = linearId;
1160 for (
size_t i = 0; i < order.size(); ++i) {
1161 int64_t dimIdx = order[i];
1162 allIds[dimIdx] = remaining % parentLayoutVec[dimIdx];
1163 if (i < order.size() - 1)
1164 remaining = remaining / parentLayoutVec[dimIdx];
// Reduce the parent-space ids using the flattened slice dims.
// NOTE(review): presumably slice() removes the ids at the sliced positions —
// confirm against XeGPUDialect::slice.
1169 ArrayRef<int64_t> dims = flattened.getDims().asArrayRef();
1170 SmallVector<int64_t> canonicalIds =
1171 XeGPUDialect::slice(ArrayRef<int64_t>(allIds), dims);
/// Tests whether this slice can be seen as a slice of `other`.
/// Both sides are compared in flattened form.
1176bool SliceAttr::isSliceOf(
const xegpu::DistributeLayoutAttr &other) {
1177 auto flattenedThis = flatten();
// `other` is a plain LayoutAttr: true iff it is this slice's flattened parent.
1180 if (
auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
1181 return flattenedThis.getParent() == otherLayout;
// Otherwise `other` is treated as a SliceAttr and flattened as well.
// NOTE(review): the dyn_cast result is used without a null check.
1183 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
// Differing root parents.
// NOTE(review): the statement taken in that case is not visible in this view.
1185 if (flattenedThis.getParent() != flattenedOther.getParent())
// Same root: true iff every dim sliced by `other` is also sliced by this
// attribute (other's dims are a subset of this one's).
1189 llvm::SmallDenseSet<int64_t> thisDims(
1190 flattenedThis.getDims().asArrayRef().begin(),
1191 flattenedThis.getDims().asArrayRef().end());
1192 return llvm::all_of(flattenedOther.getDims().asArrayRef(),
1193 [&](int64_t dim) { return thisDims.contains(dim); });
/// Equality test between this slice and `other`, performed on flattened forms
/// so that differently nested but equivalent slices can be compared.
1196bool SliceAttr::isEqualTo(
const xegpu::DistributeLayoutAttr &other) {
// `other` being a plain LayoutAttr is handled separately.
// NOTE(review): the statement taken in that case is not visible in this view.
1197 if (dyn_cast<xegpu::LayoutAttr>(other))
// NOTE(review): the dyn_cast to SliceAttr is used without a null check.
1200 auto flattenedThis = flatten();
1201 auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
// Equal iff both the root parent and the flattened dims match.
1203 return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
1204 (flattenedThis.getDims() == flattenedOther.getDims()));
1207bool SliceAttr::isCompatibleWith(
const xegpu::DistributeLayoutAttr &other,
1208 SmallVector<int64_t> shape,
1212 if (getEffectiveOrderAsInt() == other.getEffectiveOrderAsInt()) {
1215 if (getEffectiveSgLayoutAsInt() == other.getEffectiveSgLayoutAsInt() &&
1216 getEffectiveSgDataAsInt() == other.getEffectiveSgDataAsInt())
1219 if (getEffectiveLaneLayoutAsInt() ==
1220 other.getEffectiveLaneLayoutAsInt() &&
1221 getEffectiveLaneDataAsInt() == other.getEffectiveLaneDataAsInt())
1225 auto compareCoordsForAllIds = [&](int64_t size) {
1229 auto flattenedThis = flatten();
1230 auto parent = dyn_cast<LayoutAttr>(flattenedThis.getParent());
1232 int64_t wgSize =
computeProduct(parent.getEffectiveSgLayoutAsInt());
1233 return compareCoordsForAllIds(wgSize);
1236 return (getEffectiveInstDataAsInt() == other.getEffectiveInstDataAsInt());
1239 int64_t subgroupSize =
computeProduct(parent.getEffectiveLaneLayoutAsInt());
1240 return compareCoordsForAllIds(subgroupSize);
1245xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
1246 if (sliceDimsToDrop.empty())
1248 SmallVector<int64_t> sliceDims{
getDims().asArrayRef()};
1249 for (
auto dim : sliceDimsToDrop) {
1250 auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), dim);
1251 assert(foundIt != sliceDims.end() &&
1252 "Expected to find the specified reduction dim in slice dims");
1253 sliceDims.erase(foundIt);
1256 auto sliceWithoutDims = xegpu::SliceAttr::get(
1260 return sliceWithoutDims;
1268static SmallVector<int64_t>
1276 std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
1278 std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
1279 int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;
1283 llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
1286 for (
int64_t i = 0; i < parentSpaceRank; ++i) {
1287 if (!slicedDimsSet.contains(i))
1288 remainingDims.push_back(i);
1293 for (
auto dim : dimsToMap) {
1294 int64_t mappedDim = remainingDims[dim];
1295 adjustUnitDims.push_back(mappedDim);
1298 return adjustUnitDims;
1304 DistributeLayoutAttr parentLayout = getParent();
1312 parentLayout.setUnitDimData(adjustUnitDims), getDims());
1318 DistributeLayoutAttr parentLayout = getParent();
1325 return SliceAttr::get(
1326 getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
1331DistributeLayoutAttr SliceAttr::setDimData(int64_t dim, int64_t sgData,
1332 int64_t instData, int64_t laneData) {
1333 ArrayRef<int64_t> sliceDims =
getDims().asArrayRef();
1334 auto parent = getParent();
1336 SmallVector<int64_t> dimSet;
1337 dimSet.push_back(dim);
1338 SmallVector<int64_t> adjustDims =
1340 return SliceAttr::get(
1342 parent.setDimData(adjustDims[0], sgData, instData, laneData),
getDims());
1363DistributeLayoutAttr SliceAttr::dropDims(SmallVector<int64_t> dimGroup) {
1365 SmallVector<int64_t> sliceDims = llvm::to_vector(
getDims().asArrayRef());
1366 SmallVector<int64_t> dimsInParentSpace =
1369 auto droppedParent = getParent().dropDims(dimsInParentSpace);
1374 SmallVector<int64_t> newSliceDims;
1375 for (int64_t d : sliceDims) {
1377 llvm::count_if(dimsInParentSpace, [&](int64_t s) {
return s < d; });
1378 newSliceDims.push_back(d - offset);
1381 return SliceAttr::get(
getContext(), droppedParent,
1388DistributeLayoutAttr SliceAttr::collapseDims(SmallVector<int64_t> dimGroup) {
1391 SmallVector<int64_t> sliceDims = llvm::to_vector(
getDims().asArrayRef());
1392 assert(
"expect sliceDims not being collapsed" &&
1393 llvm::none_of(dimGroup, [&](int64_t dim) {
1394 return llvm::is_contained(sliceDims, dim);
1396 SmallVector<int64_t> dimsInParentSpace =
1399 auto collapsedParent = getParent().collapseDims(dimsInParentSpace);
1400 return SliceAttr::get(
getContext(), collapsedParent,
/// Expands dimension `dim` of this sliced layout into targetShape.size()
/// dimensions by expanding the corresponding dimension of the parent and
/// renumbering the slice dims to match.
1408DistributeLayoutAttr SliceAttr::expandDim(int64_t dim,
1409 ArrayRef<int64_t> targetShape) {
// Translate `dim` into the parent's dimension numbering.
// NOTE(review): the mapping call that initializes dimsInParentSpace is on a
// line not visible in this view.
1413 ArrayRef<int64_t> sliceDims =
getDims().asArrayRef();
1414 SmallVector<int64_t> dimSet = {dim};
1415 SmallVector<int64_t> dimsInParentSpace =
1417 int64_t parentDim = dimsInParentSpace[0];
// The parent performs the actual expansion.
1419 auto expandedParent = getParent().expandDim(parentDim, targetShape);
// Expansion replaces one parent dim with targetShape.size() dims, so slice
// dims numbered above parentDim move up by targetShape.size() - 1; the others
// keep their number.
1421 int64_t shift =
static_cast<int64_t
>(targetShape.size()) - 1;
1422 SmallVector<int64_t> newSliceDims;
1423 newSliceDims.reserve(sliceDims.size());
1424 for (int64_t s : sliceDims)
1425 newSliceDims.push_back(s > parentDim ? s + shift : s);
1427 return SliceAttr::get(
getContext(), expandedParent,
1434 llvm::sort(sortedSliceDims);
1436 for (
size_t i = 1; i < sortedSliceDims.size(); ++i) {
1437 assert((sortedSliceDims[i] == sortedSliceDims[i - 1] + 1) &&
1438 "slice dims non consecutive, cannot be transposed");
1442 if (sortedSliceDims.front() == 0) {
1445 for (
int64_t dim : permutation)
1446 permForParent.push_back(dim + sortedSliceDims.size());
1447 for (
int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1448 permForParent.push_back(i);
1452 for (
int64_t i = sortedSliceDims.size() - 1; i >= 0; --i)
1453 permForParent.push_back(i + permutation.size());
1454 for (
int64_t dim : permutation)
1455 permForParent.push_back(dim);
1457 return permForParent;
1463 DistributeLayoutAttr parent = getParent();
1466 auto transposedParent = parent.transposeDims(permForParent);
1467 return SliceAttr::get(
getContext(), transposedParent,
// SliceAttr::isTransposeOf (excerpt: remaining parameters and some body lines
// are not shown). Compares this slice against `other` by delegating to the
// two parent layouts.
1472bool SliceAttr::isTransposeOf(
const xegpu::DistributeLayoutAttr &other,
1476 auto otherSlice = dyn_cast<xegpu::SliceAttr>(other);
// Guard: `other` must itself be a SliceAttr with the same dims as this one.
// The statement executed when the guard fires is not shown in this excerpt.
1477 if (!otherSlice || getDims() != otherSlice.getDims())
1481 DistributeLayoutAttr parent = getParent();
1483 auto otherParent = otherSlice.getParent();
// The answer is whatever the parent reports for permForParent and `kind`.
1484 return parent.isTransposeOf(otherParent, permForParent, kind);
// RangeAttr::verify (excerpt: the return type and the trailing lines are not
// shown). Emits a diagnostic through `emitError` when the range start is
// greater than or equal to the range end.
1492RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()>
emitError,
1493 IntegerAttr startOfRange, IntegerAttr endOfRange) {
// Both empty (start == end) and inverted ranges are reported.
1494 if (startOfRange.getInt() >= endOfRange.getInt())
1495 return emitError() <<
"'end' : " << endOfRange.getInt()
1496 <<
" must be greater than 'start' : "
1497 << startOfRange.getInt();
// TensorDescType::parse (excerpt: many lines of the body are not shown).
// Visible flow: parse the shape and element type, reporting a parameter
// specific error on failure; inspect a parsed attribute by kind; then build
// the type through getChecked.
1506mlir::Type TensorDescType::parse(AsmParser &parser) {
1507 llvm::SmallVector<int64_t> shape;
1508 mlir::Type elementType;
1509 mlir::FailureOr<mlir::Attribute> encoding;
1510 mlir::FailureOr<mlir::Attribute> layout;
1518 parser.
emitError(shapeLoc,
"failed to parse parameter 'shape'");
1523 if (mlir::failed(parser.
parseType(elementType))) {
1524 parser.
emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
1530 mlir::Attribute attr;
// A successfully parsed attribute is tested against the two attribute kinds
// below. NOTE(review): the branch bodies are not shown; presumably they store
// `attr` into `layout` and `encoding` respectively. Confirm.
1532 if (mlir::succeeded(res)) {
1533 if (mlir::isa<DistributeLayoutAttr>(attr)) {
1537 if (mlir::isa<BlockTensorDescAttr>(attr)) {
// Without an encoding, BlockTensorDescAttr::get(ctxt) is used; without a
// layout, a null Attribute is used.
1550 return TensorDescType::getChecked(
1552 elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
1553 layout.value_or(mlir::Attribute()));
// TensorDescType::print (excerpt: some lines of the body are not shown).
// Visible flow: walk the shape with a check for dynamic dims, then print the
// encoding and the layout, each preceded by ", " and each only when its
// condition below holds.
1556void TensorDescType::print(AsmPrinter &printer)
const {
1560 for (int64_t dim : shape) {
1561 if (mlir::ShapedType::isDynamic(dim))
1570 auto encoding = getEncoding();
1571 auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
// The encoding is printed when it is set and is either not a
// BlockTensorDescAttr or a BlockTensorDescAttr for which hasDefaultsOnly()
// is false.
1572 if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
1573 printer <<
", " << encoding;
// The layout is printed only when it is non-null.
1575 if (
auto layout = getLayout())
1576 printer <<
", " << layout;
// TensorDescType::get (excerpt: some lines of the body are not shown).
// Convenience builder: wraps memory_space and array_length (plus further
// arguments on a line that is not shown, presumably boundary_check) into a
// BlockTensorDescAttr and forwards it as the encoding to Base::get.
1581TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
1582 mlir::Type elementType,
int array_length,
1583 bool boundary_check,
1584 MemorySpace memory_space,
1585 mlir::Attribute layout) {
// NOTE(review): `context` is defined on a line omitted from this excerpt.
1587 auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
1589 return Base::get(context, shape, elementType, attr, layout);
// TensorDescType::verify (excerpt: the return type, several conditions and
// the tail of the function are not shown). Each visible check reports its
// failure through `emitError`.
1593TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()>
emitError,
1594 llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
1595 mlir::Attribute encoding, mlir::Attribute layout) {
1596 size_t rank = shape.size();
// Rank check; the condition guarding this diagnostic is not shown.
1599 return emitError() <<
"expected non-zero rank tensor";
1601 auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
// A descriptor with rank above 1 whose memory space is SLM is rejected.
1603 MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
1604 if (rank > 1 && memorySpaceAttr &&
1605 memorySpaceAttr.getValue() == MemorySpace::SLM)
1606 return emitError() <<
"SLM is only supported for 1D block tensor";
// Element type check; the condition guarding this diagnostic is not shown.
1610 return emitError() <<
"unsupported element type " << elementType
1611 <<
": expected integer or float";
// Layout checks run only when `layout` is a DistributeLayoutAttr.
1613 if (
auto layoutAttr =
1614 mlir::dyn_cast_if_present<DistributeLayoutAttr>(layout)) {
1615 if (rank != (
size_t)layoutAttr.getRank())
1616 return emitError() <<
"expected layout rank to match tensor rank";
// If the layout cannot distribute the shape, render the shape as a
// comma-separated list for the diagnostic.
1618 if (!layoutAttr.isDistributable(SmallVector<int64_t>(shape))) {
1619 std::string shapeStr;
1620 llvm::raw_string_ostream stream(shapeStr);
1621 llvm::interleaveComma(shape, stream);
1622 return emitError() <<
"cannot distribute [" << shapeStr <<
"] using "
// MemDescType::parse (excerpt: many lines of the body are not shown).
// Visible flow: parse the shape and element type, reporting a parameter
// specific error on failure, then build the type through getChecked.
1633mlir::Type MemDescType::parse(AsmParser &parser) {
1634 llvm::SmallVector<int64_t> shape;
1635 mlir::Type elementType;
1636 mlir::FailureOr<MemLayoutAttr> layout;
1644 parser.
emitError(shapeLoc,
"failed to parse parameter 'shape'");
1649 if (mlir::failed(parser.
parseType(elementType))) {
1650 parser.
emitError(elemTypeLoc,
"failed to parse parameter 'elementType'");
// NOTE(review): `res` and the statement taken on failure are not shown.
1658 if (mlir::failed(res))
// Without a parsed layout, a default-constructed MemLayoutAttr is passed.
1668 return MemDescType::getChecked(
1670 elementType, layout.value_or(MemLayoutAttr()));
// MemDescType::print (excerpt: the lines before the layout are not shown).
// The memory layout is printed, preceded by ", ", only when getMemLayout()
// returns a non-null attribute.
1673void MemDescType::print(AsmPrinter &printer)
const {
1680 if (
auto layout = getMemLayout())
1681 printer <<
", " << layout;
// MemLayoutAttr::parse (excerpt: many lines of the body are not shown).
// Visible flow: a per-element callback reports an invalid attribute name and
// a repeated key, and collects each (name, attribute) pair; the collected
// pairs are wrapped in a DictionaryAttr for the final call.
1690Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
// seenKeys holds the names already parsed so that repeats can be rejected.
1695 llvm::SmallDenseSet<StringRef> seenKeys;
1696 SmallVector<NamedAttribute> attributes;
1698 auto parseElt = [&]() -> ParseResult {
1701 return parser.
emitError(loc,
"expected valid attribute name");
// insert().second is false when nameId was already present. The key is
// wrapped in matching single quotes in the diagnostic.
1703 if (!seenKeys.insert(nameId).second)
1704 return parser.
emitError(loc,
"duplicate key '")
1705 << nameId <<
"' in mem layout attribute";
1713 attributes.emplace_back(nameId, attr);
1729 loc, context, DictionaryAttr::get(context, attributes));
// MemLayoutAttr::print (excerpt: some lines of the body are not shown).
// Prints every entry of the attribute dictionary as `name = value`.
1732void MemLayoutAttr::print(AsmPrinter &printer)
const {
1734 ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
1735 for (
size_t i = 0; i < attrs.size(); i++) {
1736 printer << attrs[i].getName().str() <<
" = " << attrs[i].getValue();
// True for every entry except the last. The guarded statement is not shown;
// presumably it prints a separator (confirm).
1737 if (i < attrs.size() - 1)
// genBinOp and the arithmetic shorthands built on it (excerpt: the genBinOp
// signature, most of its body and the lines naming the first three shorthands
// are not shown).
1746template <
typename ArithOp>
// Creates the requested arith op on the two operand values and returns its
// result.
1751 return ArithOp::create(builder, loc, aVal, bVal).getResult();
// Three expansions that call genBinOp with DivSIOp, RemSIOp and MulIOp; each
// passes `b` through builder.getIndexAttr. They rely on `loc` and `builder`
// being in scope where they are expanded.
1756 genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
1760 genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
1764 genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
// Unlike the three above, `add` forwards `b` unchanged.
1767#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
// Excerpt of a function returning `blockedOffsets` (signature and local
// declarations are not shown; presumably getBlockedOffsets, confirm).
// Requires one block size per offset.
1776 assert(offsets.size() == blockShape.size() &&
1777 "offsets and blockShape must have the same size");
// Split every offset by its block size into a div() part and a rem() part.
1781 for (
auto [offset, block] : llvm::zip(offsets, blockShape)) {
1782 divs.push_back(
div(offset, block));
1783 rems.push_back(
rem(offset, block));
// Result order: all div() parts first, then all rem() parts.
1785 blockedOffsets.append(divs.begin(), divs.end());
1786 blockedOffsets.append(rems.begin(), rems.end());
1788 return blockedOffsets;
// Excerpt of the function that computes `blockedStrides` (signature and
// several declarations are not shown; presumably MemDescType::getStrideShape,
// confirm). The result is the outer-block strides followed by the inner-block
// strides.
1796 ArrayAttr strideAttr = getStrideAttr();
// Unpack the stride attribute into plain integers.
1798 for (
Attribute attr : strideAttr.getValue()) {
1799 strides.push_back(cast<IntegerAttr>(attr).getInt());
// perm lists the dim indices ordered by ascending stride, so perm[0] is the
// dim with the smallest stride.
1807 llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
1808 llvm::sort(perm, [&](
int a,
int b) {
return strides[a] < strides[
b]; });
1810 assert(strides[perm[0]] == 1 &&
"inner most dim must have stride 1");
// Inner-block strides: 1 for perm[0], then the running product of
// innerBlkShape along perm.
1812 SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
1813 innerBlkStride[perm[0]] = 1;
1814 for (
size_t i = 1; i < perm.size(); ++i)
1815 innerBlkStride[perm[i]] =
1816 innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
// For every dim but the last in perm order: its extent is the ratio of the
// next stride to its own, and its block count is that extent divided by the
// inner block size.
1822 SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
1823 SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
1824 for (
size_t i = 0; i < perm.size() - 1; ++i) {
1825 matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
1826 BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
// NOTE(review): the loop body is not shown; presumably innerBlkSize becomes
// the product of innerBlkShape. Confirm.
1829 int64_t innerBlkSize = 1;
1830 for (
auto s : innerBlkShape)
// Outer-block strides: innerBlkSize for perm[0], then the running product of
// the block counts along perm.
1833 SmallVector<int64_t> outerBlkStride(matrixShape.size());
1834 outerBlkStride[perm[0]] = innerBlkSize;
1835 for (
size_t i = 0; i < perm.size() - 1; ++i) {
1836 outerBlkStride[perm[i + 1]] =
1837 outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
1841 SmallVector<int64_t> blockedStrides;
1842 blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
1843 blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
1845 return blockedStrides;
// MemDescType::getLinearOffsets (excerpt: some lines of the body are not
// shown). Produces a single Value by summing offsets[i] * strides[i].
1849Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
1850 ArrayRef<OpFoldResult> offsets) {
1853 SmallVector<int64_t> blockShape = getBlockShape();
1854 SmallVector<int64_t> strides = getStrideShape();
1855 SmallVector<OpFoldResult> blockedOffsets;
// When the block shape equals the matrix shape, the first matrixShape.size()
// strides are dropped and the offsets are used as given.
1858 if (llvm::equal(blockShape, matrixShape)) {
1860 strides.erase(strides.begin(), strides.begin() + matrixShape.size());
// Blocked path (the branch header enclosing it is not shown): split each
// offset by its block size and list all div() parts before all rem() parts.
1862 assert(offsets.size() == blockShape.size() &&
1863 "offsets and blockShape must have the same size");
1867 SmallVector<OpFoldResult> divs, rems;
1869 for (
auto [offset, block] : llvm::zip(offsets, blockShape)) {
1870 divs.push_back(
div(offset, block));
1871 rems.push_back(
rem(offset, block));
1873 blockedOffsets.append(divs.begin(), divs.end());
1874 blockedOffsets.append(rems.begin(), rems.end());
// `offsets` is an ArrayRef, so this rebinds it to the storage of
// blockedOffsets, which is declared at function scope above.
1875 offsets = blockedOffsets;
// Accumulate the products into linearOffset. The initialisation of
// linearOffset and the definition of mulVal are not shown.
1880 for (
size_t i = 0; i < offsets.size(); ++i) {
1881 OpFoldResult mulResult =
mul(offsets[i], strides[i]);
1883 linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
1886 return linearOffset;
1892#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
1893#define GET_ATTRDEF_CLASSES
1894#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
1895#define GET_TYPEDEF_CLASSES
1896#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
static Type getElementType(Type type)
Determine the element type of type.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseEqual()=0
Parse a '=' token.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a ',' token if present.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emit errors in the case of failure.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
void printDimensionList(ArrayRef< int64_t > shape)
Attributes are known-constant values of operations.
static BoolAttr get(MLIRContext *context, bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
This class represents a single result from folding an operation.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tileShape within the larger shape shape, using an ordering specified by loopOrder.
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Specialization of arith.constant op that returns an integer of index type.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
auto getDims(VectorType vType)
Returns a range over the dims (size and scalability) of a VectorType.
static SmallVector< SmallVector< int64_t > > genStaticCoordinates(llvm::ArrayRef< int64_t > canonicalIds, llvm::ArrayRef< int64_t > layout, llvm::ArrayRef< int64_t > subShape, llvm::ArrayRef< int64_t > shape)
LayoutKind
Specifies the level of a layout hierarchy for comparison or propagation.
static SmallVector< int64_t > mapSlicedDimsToParentSpace(const SmallVector< int64_t > &dimsToMap, ArrayRef< int64_t > sliceDims)
SmallVector< OpFoldResult > getBlockedOffsets(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > offsets, ArrayRef< int64_t > blockShape)
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc, OpBuilder &builder)
static bool compareDistributedCoords(xegpu::DistributeLayoutAttr self, const xegpu::DistributeLayoutAttr &other, ArrayRef< int64_t > shape, xegpu::LayoutKind level, int64_t size)
Returns true if self and other distribute shape identically at level: every id in [0,...
static SmallVector< SmallVector< Value > > genCoordinates(OpBuilder &builder, Location loc, SmallVector< Value > delinearizedId, ArrayRef< int64_t > subShapesLayout, ArrayRef< int64_t > subShape, ArrayRef< int64_t > srcShape)
SmallVector< int64_t > getPermForParentLayout(ArrayRef< int64_t > sliceDims, ArrayRef< int64_t > permutation)
static SmallVector< SmallVector< int64_t > > expandBlockCoords(ArrayRef< SmallVector< int64_t > > blockStarts, ArrayRef< int64_t > subShape)
Expands per-distribution-unit block-start coordinates into the full list of element coordinates each ...
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
SmallVector< int64_t > computeElementwiseMul(ArrayRef< int64_t > v1, ArrayRef< int64_t > v2)
Return a vector containing llvm::zip_equal(v1, v2) multiplied elementwise.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
int64_t computeProduct(ArrayRef< int64_t > basis)
Self-explicit.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
bool isPermutationVector(ArrayRef< int64_t > interchange)
Method to check if an interchange vector is a permutation.