#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
void XeGPUDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>
      >();
  addOperations<
#define GET_OP_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
      >();
}

#define GET_OP_INTERFACE_CLASSES
#include "mlir/Dialect/XeGPU/IR/XeGPUOpInterface.cpp.inc"
// Builds the coordinates of all sub-tiles owned by one subgroup/lane,
// identified by its delinearized id, under a round-robin distribution.
static SmallVector<SmallVector<Value>>
genCoordinates(OpBuilder &builder, Location loc,
               SmallVector<Value> delinearizedId,
               ArrayRef<int64_t> subShapesLayout, ArrayRef<int64_t> subShape,
               ArrayRef<int64_t> srcShape) {
  SmallVector<SmallVector<Value>> coordinates;
  auto indexCst = [&](int64_t v) -> Value {
    return builder.createOrFold<arith::ConstantIndexOp>(loc, v);
  };

  // distUnit[i] = min(srcShape[i], subShapesLayout[i] * subShape[i])
  SmallVector<int64_t> distUnit = llvm::map_to_vector(
      llvm::zip_equal(srcShape,
                      computeElementwiseMul(subShapesLayout, subShape)),
      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

  // Offset of this id within one distribution unit.
  SmallVector<Value> distUnitLocalOffset = llvm::map_to_vector(
      llvm::zip(delinearizedId, subShape), [&](const auto &t) -> Value {
        return builder.createOrFold<arith::MulIOp>(loc, std::get<0>(t),
                                                   indexCst(std::get<1>(t)));
      });

  // Walk all distribution units covering srcShape; the final coordinate
  // wraps around srcShape (round-robin).
  for (SmallVector<int64_t> unitOffs :
       StaticTileOffsetRange(srcShape, distUnit)) {
    SmallVector<Value> base = llvm::map_to_vector(unitOffs, indexCst);
    SmallVector<Value> adds =
        llvm::map_to_vector(llvm::zip_equal(base, distUnitLocalOffset),
                            [&](const auto &t) -> Value {
                              return builder.createOrFold<arith::AddIOp>(
                                  loc, std::get<0>(t), std::get<1>(t));
                            });
    SmallVector<Value> mods = llvm::map_to_vector(
        llvm::zip_equal(adds, srcShape), [&](const auto &t) -> Value {
          return builder.createOrFold<arith::RemUIOp>(
              loc, std::get<0>(t), indexCst(std::get<1>(t)));
        });
    coordinates.push_back(mods);
  }
  return coordinates;
}
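// Worked example (illustrative; not from the original source): with
// srcShape = [256, 128], subShapesLayout = [8, 4] and subShape = [16, 16],
// the distribution unit is [min(256, 8*16), min(128, 4*16)] = [128, 64], so
// a 2x2 grid of units covers the source and each id owns four 16x16 tiles;
// the tile for id (r, c) in the unit based at (u0, u1) starts at
// ((u0 + r*16) % 256, (u1 + c*16) % 128).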
// Checks if the given shape can be evenly distributed under the provided
// layout attribute.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
                                         xegpu::DistributeLayoutAttr attr) {
  assert(attr && "Layout attribute is missing.");

  // Checks whether `shape` can be evenly divided by the pair `layout` and
  // `data`, returning the per-compute-unit shape on success. With
  // round-robin (`rr`) enabled, shape[i] may be smaller than
  // layout[i] * data[i], in which case units cycle over the data.
  auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape,
                           SmallVector<int64_t> layout,
                           SmallVector<int64_t> data, bool rr = true)
      -> std::optional<SmallVector<int64_t>> {
    llvm::SmallVector<int64_t> newShape(shape);
    if (layout.size()) {
      if (layout.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(shape, layout);
      if (ratio.has_value()) {
        newShape = ratio.value();
      } else {
        return std::nullopt;
      }
    }
    if (data.size()) {
      if (data.size() != shape.size())
        return std::nullopt;
      auto ratio = computeShapeRatio(newShape, data);
      if (!ratio.has_value() && rr)
        ratio = computeShapeRatio(data, newShape);
      if (!ratio.has_value())
        return std::nullopt;
      newShape = data;
    }
    return newShape;
  };

  // Check sg_layout and sg_data.
  auto maybeSgShape = tryDistribute(shape, attr.getEffectiveSgLayoutAsInt(),
                                    attr.getEffectiveSgDataAsInt());
  if (!maybeSgShape)
    return false;
  auto sgShape = maybeSgShape.value();

  // Check inst_data; it has no layout and does not use round-robin.
  auto maybeInstShape =
      tryDistribute(sgShape, {}, attr.getEffectiveInstDataAsInt(), false);
  if (!maybeInstShape)
    return false;
  auto instShape = maybeInstShape.value();

  // Check lane_layout and lane_data.
  auto maybeLaneShape =
      tryDistribute(instShape, attr.getEffectiveLaneLayoutAsInt(),
                    attr.getEffectiveLaneDataAsInt(), false);
  return maybeLaneShape.has_value();
}
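// Example (illustrative): shape = [128, 128] with sg_layout = [4, 4] and
// sg_data = [16, 16] distributes evenly: each subgroup gets a 16x16 tile and
// the 64x64 distribution unit tiles the shape exactly; inst_data and
// lane_layout/lane_data are then checked against the per-subgroup [16, 16].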
BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
                                             xegpu::MemorySpace memory_space,
                                             int array_length,
                                             bool boundary_check) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto lengthAttr =
      IntegerAttr::get(IntegerType::get(context, 64), array_length);
  auto boundaryAttr = BoolAttr::get(context, boundary_check);
  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
}

bool BlockTensorDescAttr::hasDefaultsOnly() {
  return getMemorySpace().getValue() == xegpu::MemorySpace::Global &&
         getArrayLength().getInt() == 1 && getBoundaryCheck().getValue();
}
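// hasDefaultsOnly() lets the printer elide the encoding in the common case,
// e.g. (illustrative) printing !xegpu.tensor_desc<8x16xf16> rather than
// spelling out #xegpu.block_tdesc_attr<memory_space = global,
// array_length = 1, boundary_check = true>.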
ScatterTensorDescAttr
ScatterTensorDescAttr::get(mlir::MLIRContext *context,
                           xegpu::MemorySpace memory_space, int chunk_size) {
  auto scopeAttr = MemorySpaceAttr::get(context, memory_space);
  auto chunkSizeAttr =
      IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
  return Base::get(context, scopeAttr, chunkSizeAttr);
}
LogicalResult ScatterTensorDescAttr::verify(
    llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
    MemorySpaceAttr memory_space, IntegerAttr chunk_size) {
  int64_t chunkSize = chunk_size.getInt();
  if (chunkSize <= 0)
    return emitError() << "invalid chunk size";
  return success();
}
LogicalResult
LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                   DenseI32ArrayAttr sg_layout, DenseI32ArrayAttr sg_data,
                   DenseI32ArrayAttr inst_data, DenseI32ArrayAttr lane_layout,
                   DenseI32ArrayAttr lane_data, DenseI32ArrayAttr order) {
  // A valid layout must specify at least one level of distribution.
  if (!sg_layout && !inst_data && !lane_layout)
    return emitError()
           << "expected at least one of sg_layout, inst_data or lane_layout";

  // All present levels must agree on the rank.
  if (sg_layout && inst_data && sg_layout.size() != inst_data.size()) {
    return emitError()
           << "expected sg_layout and inst_data to have the same rank";
  }

  if (sg_layout && lane_layout && sg_layout.size() != lane_layout.size()) {
    return emitError()
           << "expected sg_layout and lane_layout to have the same rank";
  }

  if (inst_data && lane_layout && inst_data.size() != lane_layout.size()) {
    return emitError() << "expected inst_data and lane_layout to have the same "
                          "rank, got inst_data "
                       << inst_data.size() << ", lane_layout "
                       << lane_layout.size();
  }

  // sg_data is optional, but its presence requires sg_layout.
  if (sg_data) {
    if (!sg_layout)
      return emitError() << "expected sg_layout being used with sg_data";
    if (sg_data.size() != sg_layout.size())
      return emitError()
             << "expected sg_data and sg_layout to have the same rank";
  }

  // lane_data is optional, but its presence requires lane_layout.
  if (lane_data) {
    if (!lane_layout)
      return emitError() << "expected lane_layout being used with lane_data";
    if (lane_data.size() != lane_layout.size())
      return emitError()
             << "expected lane_data and lane_layout to have the same rank";
  }

  // order is optional, but it must match the rank of sg_layout/lane_layout.
  if (order) {
    if (!sg_layout && !lane_layout)
      return emitError()
             << "expected sg_layout/lane_layout being used with order";
    if (sg_layout && order.size() != sg_layout.size())
      return emitError()
             << "expected order and sg_layout to have the same rank";
    if (lane_layout && order.size() != lane_layout.size())
      return emitError()
             << "expected order and lane_layout to have the same rank";
  }

  return success();
}
FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
  SmallVector<int64_t> sgLayoutInt;
  if (isForWorkgroup()) {
    sgLayoutInt = getEffectiveSgLayoutAsInt();
  } else if (isForSubgroup()) {
    sgLayoutInt = getEffectiveLaneLayoutAsInt();
  } else {
    return failure();
  }

  // Use the given order, or default to reversed row-major order (the last
  // dimension varies fastest).
  DenseI32ArrayAttr orderAttr = getOrder();
  SmallVector<int64_t> order;
  if (orderAttr && !orderAttr.empty()) {
    order = llvm::to_vector(llvm::map_range(
        orderAttr.asArrayRef(),
        [](int32_t idx) { return static_cast<int64_t>(idx); }));
  } else {
    order = llvm::to_vector(
        llvm::reverse(llvm::seq<int64_t>(0, sgLayoutInt.size())));
  }

  if (order.size() != sgLayoutInt.size()) {
    return failure();
  }

  SmallVector<Value> result(sgLayoutInt.size());
  Value remaining = linearId;

  // Process dims from fastest-varying to slowest-varying: each coordinate is
  // remaining % dimSize, and the quotient carries over to the next dim.
  for (size_t i = 0; i < order.size(); ++i) {
    int64_t dimIdx = order[i];
    int64_t dimSize = sgLayoutInt[dimIdx];
    Value dimSizeVal =
        builder.createOrFold<arith::ConstantIndexOp>(loc, dimSize);
    result[dimIdx] =
        builder.createOrFold<arith::RemUIOp>(loc, remaining, dimSizeVal);
    // The quotient is not needed after the last dimension.
    if (i < order.size() - 1) {
      remaining =
          builder.createOrFold<arith::DivUIOp>(loc, remaining, dimSizeVal);
    }
  }
  return result;
}
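// Worked example (illustrative): sgLayoutInt = [4, 8] with the default
// order [1, 0] decomposes linearId = 27 as 27 % 8 = 3 for dim 1, then
// (27 / 8) % 4 = 3 for dim 0, yielding ids [3, 3].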
FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                     Value linearId, ArrayRef<int64_t> shape) {
  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  } else {
    return failure();
  }

  // Derive the per-unit shape from shape and layout when not given.
  if (subShape.empty()) {
    auto derivedShape = computeShapeRatio(shape, layout);
    if (!derivedShape.has_value())
      return failure();
    subShape = derivedShape.value();
  }

  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();
  SmallVector<Value> ids = *maybeIds;

  return genCoordinates(builder, loc, ids, layout, subShape, shape);
}
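// E.g. (illustrative): shape = [64, 64] with sg_layout = [2, 2] and no
// sg_data derives subShape = [32, 32]; each subgroup then receives exactly
// one 32x32 tile starting at (id0 * 32, id1 * 32).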
bool LayoutAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
  // A LayoutAttr can never equal a SliceAttr.
  if (dyn_cast<xegpu::SliceAttr>(other))
    return false;
  return *this == dyn_cast<xegpu::LayoutAttr>(other);
}
// Returns a copy of this layout with the data sizes of the given unit
// dimensions reset to 1 at every level.
LayoutAttr LayoutAttr::setUnitDimData(llvm::SetVector<int64_t> unitDims) {
  auto sgDataOpt = getSgData();
  auto instDataOpt = getInstData();
  auto laneDataOpt = getLaneData();

  SmallVector<int32_t> sgData;
  SmallVector<int32_t> instData;
  SmallVector<int32_t> laneData;

  if (sgDataOpt)
    sgData = llvm::to_vector(sgDataOpt.asArrayRef());
  if (instDataOpt)
    instData = llvm::to_vector(instDataOpt.asArrayRef());
  if (laneDataOpt)
    laneData = llvm::to_vector(laneDataOpt.asArrayRef());

  for (auto dim : unitDims) {
    if (dim < static_cast<int64_t>(sgData.size()))
      sgData[dim] = 1;
    if (dim < static_cast<int64_t>(instData.size()))
      instData[dim] = 1;
    if (dim < static_cast<int64_t>(laneData.size()))
      laneData[dim] = 1;
  }

  // Rebuild the layout, keeping absent fields absent.
  auto arrayOrNull = [&](ArrayRef<int32_t> v) {
    return v.empty() ? DenseI32ArrayAttr()
                     : DenseI32ArrayAttr::get(getContext(), v);
  };
  return LayoutAttr::get(getContext(), getSgLayout(), arrayOrNull(sgData),
                         arrayOrNull(instData), getLaneLayout(),
                         arrayOrNull(laneData), getOrder());
}
// Returns a copy of this layout with the layout sizes of the given unit
// dimensions reset to 1.
LayoutAttr LayoutAttr::setUnitDimLayout(llvm::SetVector<int64_t> unitDims) {
  auto sgLayoutOpt = getSgLayout();
  auto laneLayoutOpt = getLaneLayout();

  SmallVector<int32_t> sgLayout;
  SmallVector<int32_t> laneLayout;

  if (sgLayoutOpt)
    sgLayout = llvm::to_vector(sgLayoutOpt.asArrayRef());
  if (laneLayoutOpt)
    laneLayout = llvm::to_vector(laneLayoutOpt.asArrayRef());

  for (auto dim : unitDims) {
    if (dim < static_cast<int64_t>(sgLayout.size()))
      sgLayout[dim] = 1;
    if (dim < static_cast<int64_t>(laneLayout.size()))
      laneLayout[dim] = 1;
  }

  auto arrayOrNull = [&](ArrayRef<int32_t> v) {
    return v.empty() ? DenseI32ArrayAttr()
                     : DenseI32ArrayAttr::get(getContext(), v);
  };
  return LayoutAttr::get(getContext(), arrayOrNull(sgLayout), getSgData(),
                         getInstData(), arrayOrNull(laneLayout),
                         getLaneData(), getOrder());
}
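// E.g. (illustrative): applying setUnitDimLayout({0}) to a layout with
// sg_layout = [4, 8] resets dim 0 of sg_layout (and of lane_layout, if
// present) to 1, producing sg_layout = [1, 8]; data fields are untouched.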
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                  xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
  if (!dims)
    return emitError() << "expected dims attribute";

  // Each sliced dim must be valid for the parent and appear only once.
  int64_t rank = parent.getRank();
  llvm::SmallDenseSet<int64_t> seen;
  for (int64_t dim : dims.asArrayRef()) {
    if (dim < 0 || dim >= rank)
      return emitError() << "invalid dim (" << dim << ") in slice attribute.";
    if (!seen.insert(dim).second)
      return emitError() << "repeated dim (" << dim << ") in slice attribute.";
  }
  return success();
}
SliceAttr SliceAttr::flatten() const {
  xegpu::DistributeLayoutAttr parent = getParent();
  SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});

  while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
    parent = sliceAttr.getParent();
    slicedDims.push_back(sliceAttr.getDims());
  }

  auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
  SmallVector<int64_t> indices =
      llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));

  // Apply the slices from the outermost parent inward to find the dims that
  // survive all of them.
  SmallVector<int64_t> remainingDims(indices);
  for (auto dim : llvm::reverse(slicedDims))
    remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
                                        dim.asArrayRef());

  // The flattened sliced dims are all indices minus the remaining ones.
  SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
      llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));

  return xegpu::SliceAttr::get(
      getContext(), layoutAttr,
      DenseI64ArrayAttr::get(getContext(), flattenedDims));
}
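// Worked example (illustrative): for a rank-3 parent layout L,
// slice<slice<L, dims = [1]>, dims = [1]> flattens to
// slice<L, dims = [1, 2]>, since removing parent dim 1 first makes the
// original dim 2 appear as dim 1 of the intermediate slice.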
FailureOr<SmallVector<Value>>
SliceAttr::delinearizeId(OpBuilder &builder, Location loc, Value linearId) {
  SliceAttr attr = flatten();
  auto parent = dyn_cast<LayoutAttr>(attr.getParent());
  return parent.delinearizeId(builder, loc, linearId);
}
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::computeDistributedCoords(OpBuilder &builder, Location loc,
                                    Value linearId, ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
  if (!isForWorkgroup())
    return failure();

  SmallVector<int64_t> layout;
  SmallVector<int64_t> subShape;
  if (isForWorkgroup()) {
    layout = getEffectiveSgLayoutAsInt();
    subShape = getEffectiveSgDataAsInt();
  } else if (isForSubgroup()) {
    layout = getEffectiveLaneLayoutAsInt();
    subShape = getEffectiveLaneDataAsInt();
  } else {
    return failure();
  }

  // Derive the per-unit shape from shape and layout when not given.
  if (subShape.empty()) {
    auto derivedShape = computeShapeRatio(shape, layout);
    if (!derivedShape.has_value())
      return failure();
    subShape = derivedShape.value();
  }

  auto maybeIds = delinearizeId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();

  // The delinearized ids live in the parent layout's space; drop the ids of
  // the sliced-away dims.
  ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
  SmallVector<Value> sgIds =
      XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);

  return genCoordinates(builder, loc, sgIds, layout, subShape, shape);
}
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
  auto flattenedThis = flatten();

  // If other is a LayoutAttr, this is a slice of it iff the flattened
  // parent is exactly that layout.
  if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
    return flattenedThis.getParent() == otherLayout;

  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
  // Both must slice the same parent layout.
  if (flattenedThis.getParent() != flattenedOther.getParent())
    return false;

  // Every dim sliced by other must also be sliced by this.
  llvm::SmallDenseSet<int64_t> thisDims(
      flattenedThis.getDims().asArrayRef().begin(),
      flattenedThis.getDims().asArrayRef().end());
  return llvm::all_of(flattenedOther.getDims().asArrayRef(),
                      [&](int64_t dim) { return thisDims.contains(dim); });
}
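// E.g. (illustrative): slice<L, dims = [0, 1]> is a slice of
// slice<L, dims = [0]>, because both flatten onto the same parent L and
// every dim sliced by the latter is also sliced by the former.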
xegpu::SliceAttr SliceAttr::dropSliceDims(ArrayRef<int64_t> sliceDimsToDrop) {
  if (sliceDimsToDrop.empty())
    return *this;

  SmallVector<int64_t> sliceDims{getDims().asArrayRef()};
  for (auto dim : sliceDimsToDrop) {
    auto foundIt = std::find(sliceDims.begin(), sliceDims.end(), dim);
    assert(foundIt != sliceDims.end() &&
           "Expected to find the specified reduction dim in slice dims");
    sliceDims.erase(foundIt);
  }

  auto sliceWithoutDims = xegpu::SliceAttr::get(
      getContext(), getParent(),
      DenseI64ArrayAttr::get(getContext(), sliceDims));
  return sliceWithoutDims;
}
bool SliceAttr::isEqualTo(const xegpu::DistributeLayoutAttr &other) {
  // A SliceAttr can never equal a LayoutAttr.
  if (dyn_cast<xegpu::LayoutAttr>(other))
    return false;

  auto flattenedThis = flatten();
  auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
  return ((flattenedThis.getParent() == flattenedOther.getParent()) &&
          (flattenedThis.getDims() == flattenedOther.getDims()));
}
// Maps dims from a sliced (reduced) space back to the parent layout's
// dimension space, accounting for the dims removed by the slice.
static SetVector<int64_t>
mapSlicedDimsToParentSpace(const SetVector<int64_t> &dimsToMap,
                           ArrayRef<int64_t> sliceDims) {
  // Conservatively size the parent space from the largest dim seen.
  int64_t maxDim = 0;
  maxDim =
      std::max(maxDim, *std::max_element(sliceDims.begin(), sliceDims.end()));
  maxDim =
      std::max(maxDim, *std::max_element(dimsToMap.begin(), dimsToMap.end()));
  int64_t parentSpaceRank = maxDim + sliceDims.size() + 1;

  // Collect the parent dims that survive the slice, in increasing order.
  llvm::SmallDenseSet<int64_t> slicedDimsSet(sliceDims.begin(),
                                             sliceDims.end());
  SmallVector<int64_t> remainingDims;
  for (int64_t i = 0; i < parentSpaceRank; ++i) {
    if (!slicedDimsSet.contains(i))
      remainingDims.push_back(i);
  }

  // Dim d of the sliced space is remainingDims[d] in the parent space.
  SetVector<int64_t> adjustUnitDims;
  for (auto dim : dimsToMap) {
    int64_t mappedDim = remainingDims[dim];
    adjustUnitDims.insert(mappedDim);
  }
  return adjustUnitDims;
}
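// Worked example (illustrative): dimsToMap = {1} with sliceDims = [0]: the
// parent dims surviving the slice are [1, 2, ...], so dim 1 of the sliced
// space maps back to parent dim 2.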
xegpu::SliceAttr SliceAttr::setUnitDimData(llvm::SetVector<int64_t> unitDims) {
  DistributeLayoutAttr parentLayout = getParent();
  // Translate the unit dims from the sliced space to the parent's space.
  SetVector<int64_t> adjustUnitDims =
      mapSlicedDimsToParentSpace(unitDims, getDims().asArrayRef());
  return SliceAttr::get(getContext(),
                        parentLayout.setUnitDimData(adjustUnitDims),
                        getDims());
}

xegpu::SliceAttr
SliceAttr::setUnitDimLayout(llvm::SetVector<int64_t> unitDims) {
  DistributeLayoutAttr parentLayout = getParent();
  SetVector<int64_t> adjustUnitDims =
      mapSlicedDimsToParentSpace(unitDims, getDims().asArrayRef());
  return SliceAttr::get(
      getContext(), parentLayout.setUnitDimLayout(adjustUnitDims), getDims());
}
LogicalResult
RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
                  IntegerAttr startOfRange, IntegerAttr endOfRange) {
  // The range must be non-empty: start < end.
  if (startOfRange.getInt() >= endOfRange.getInt())
    return emitError() << "'end' : " << endOfRange.getInt()
                       << " must be greater than 'start' : "
                       << startOfRange.getInt();
  return success();
}
mlir::Type TensorDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<mlir::Attribute> encoding;
  mlir::FailureOr<mlir::Attribute> layout;

  // Parse literal '<'.
  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // Parse the optional encoding and layout attributes.
  while (mlir::succeeded(parser.parseOptionalComma())) {
    mlir::Attribute attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::succeeded(res)) {
      if (mlir::isa<LayoutAttr>(attr)) {
        layout = attr;
        continue;
      }
      if (mlir::isa<BlockTensorDescAttr, ScatterTensorDescAttr>(attr)) {
        encoding = attr;
        continue;
      }
    }
    return {};
  }

  // Parse literal '>'.
  if (parser.parseGreater())
    return {};

  MLIRContext *ctxt = parser.getContext();
  return TensorDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, encoding.value_or(BlockTensorDescAttr::get(ctxt)),
      layout.value_or(mlir::Attribute()));
}
void TensorDescType::print(AsmPrinter &printer) const {
  printer << "<";

  auto shape = getShape();
  for (int64_t dim : shape) {
    if (mlir::ShapedType::isDynamic(dim))
      printer << '?';
    else
      printer << dim;
    printer << 'x';
  }
  printer << getElementType();

  // Elide the encoding when it is a BlockTensorDescAttr with only defaults.
  auto encoding = getEncoding();
  auto blockAttr = llvm::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (encoding && (!blockAttr || !blockAttr.hasDefaultsOnly()))
    printer << ", " << encoding;

  if (auto layout = getLayout())
    printer << ", " << layout;

  printer << ">";
}
TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int array_length,
                                   bool boundary_check,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = BlockTensorDescAttr::get(context, memory_space, array_length,
                                       boundary_check);
  return Base::get(context, shape, elementType, attr, layout);
}

TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                   mlir::Type elementType, int chunk_size,
                                   MemorySpace memory_space,
                                   mlir::Attribute layout) {
  auto context = elementType.getContext();
  auto attr = ScatterTensorDescAttr::get(context, memory_space, chunk_size);
  return Base::get(context, shape, elementType, attr, layout);
}
LogicalResult
TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                       llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
                       mlir::Attribute encoding, mlir::Attribute layout) {
  size_t rank = shape.size();
  if (rank == 0)
    return emitError() << "expected non-zero rank tensor";

  auto blockAttr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(encoding);
  if (blockAttr) {
    MemorySpaceAttr memorySpaceAttr = blockAttr.getMemorySpace();
    if (rank > 1 && memorySpaceAttr &&
        memorySpaceAttr.getValue() == MemorySpace::SLM)
      return emitError() << "SLM is only supported for 1D block tensor";
  }

  if (!elementType.isIntOrFloat())
    return emitError() << "unsupported element type " << elementType
                       << ": expected integer or float";

  // Low-precision types are packed into wider units for gather/scatter, so
  // the last dim of a chunked descriptor must respect the packing factor.
  unsigned bitWidth = elementType.getIntOrFloatBitWidth();
  int chunkAlignmentFactor =
      bitWidth < generalPackedFormatBitSize
          ? generalPackedFormatBitSize / bitWidth
          : 1;

  auto scatterAttr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(encoding);
  if (scatterAttr) {
    int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
    if (rank == 1 && chunkSize != 1)
      return emitError() << "expected non-contiguous elements for 1D tensor";

    // If chunk size > 1, the last dim must match it and be properly aligned.
    if (chunkSize > 1) {
      if (shape.back() != chunkSize)
        return emitError() << "expected last dim of tensor to match chunk size";
      if (shape.back() % chunkAlignmentFactor != 0)
        return emitError() << "expected last dim of tensor to be a multiple of "
                           << chunkAlignmentFactor;
    }
  }

  auto layoutAttr = llvm::dyn_cast_if_present<LayoutAttr>(layout);
  if (layoutAttr) {
    if (rank != (size_t)layoutAttr.getRank())
      return emitError() << "expected layout rank to match tensor rank";

    auto laneData = layoutAttr.getLaneData();
    if (scatterAttr && laneData) {
      // Packed load/store alignment applies when chunk size > 1.
      int64_t chunkSize = scatterAttr.getChunkSizeAsInt();
      if (chunkSize > 1 && laneData[rank - 1] % chunkAlignmentFactor)
        return emitError()
               << "expected last dim of lane_data to be a multiple of: "
               << chunkAlignmentFactor;
    }

    if (!XeGPUDialect::isEvenlyDistributable(shape, layoutAttr)) {
      std::string shapeStr;
      llvm::raw_string_ostream stream(shapeStr);
      llvm::interleaveComma(shape, stream);
      return emitError() << "cannot distribute [" << shapeStr << "] using "
                         << layoutAttr;
    }
  }
  return success();
}
mlir::Type MemDescType::parse(AsmParser &parser) {
  llvm::SmallVector<int64_t> shape;
  mlir::Type elementType;
  mlir::FailureOr<MemLayoutAttr> layout;

  if (parser.parseLess())
    return {};

  auto shapeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseDimensionList(shape))) {
    parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
    return {};
  }

  auto elemTypeLoc = parser.getCurrentLocation();
  if (mlir::failed(parser.parseType(elementType))) {
    parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
    return {};
  }

  // Parse the optional mem layout attribute.
  if (mlir::succeeded(parser.parseOptionalComma())) {
    MemLayoutAttr attr;
    ParseResult res = parser.parseAttribute(attr);
    if (mlir::failed(res))
      return {};
    layout = attr;
  }

  if (parser.parseGreater())
    return {};

  MLIRContext *ctxt = parser.getContext();
  return MemDescType::getChecked(
      [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
      elementType, layout.value_or(MemLayoutAttr()));
}
void MemDescType::print(AsmPrinter &printer) const {
  printer << "<";
  printer.printDimensionList(getShape());
  printer << 'x' << getElementType();
  if (auto layout = getMemLayout())
    printer << ", " << layout;
  printer << ">";
}
Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
  MLIRContext *context = parser.getContext();
  llvm::SMLoc loc = parser.getCurrentLocation();

  if (parser.parseLess())
    return {};

  llvm::SmallDenseSet<StringRef> seenKeys;
  SmallVector<NamedAttribute> attributes;

  // Parse one `key = value` entry, rejecting duplicate keys.
  auto parseElt = [&]() -> ParseResult {
    StringRef nameId;
    if (parser.parseOptionalKeyword(&nameId))
      return parser.emitError(loc, "expected valid attribute name");
    if (!seenKeys.insert(nameId).second)
      return parser.emitError(loc, "duplicate key '")
             << nameId << "' in mem layout attribute";
    if (parser.parseEqual())
      return failure();
    Attribute attr;
    if (parser.parseAttribute(attr))
      return failure();
    attributes.emplace_back(nameId, attr);
    return success();
  };

  if (parser.parseCommaSeparatedList(parseElt) || parser.parseGreater())
    return {};

  return parser.getChecked<MemLayoutAttr>(
      loc, context, DictionaryAttr::get(context, attributes));
}
void MemLayoutAttr::print(AsmPrinter &printer) const {
  printer << "<";
  ArrayRef<NamedAttribute> attrs = getAttrs().getValue();
  for (size_t i = 0; i < attrs.size(); i++) {
    printer << attrs[i].getName().str() << " = " << attrs[i].getValue();
    if (i < attrs.size() - 1)
      printer << ", ";
  }
  printer << ">";
}
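// E.g. (illustrative): a mem layout carrying block and stride entries
// prints as #xegpu.mem_layout<block = [8, 16], stride = [1, 32]>.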
template <typename ArithOp>
OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
                      OpBuilder &builder) {
  Value aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
  Value bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
  return ArithOp::create(builder, loc, aVal, bVal).getResult();
}

// Convenience wrappers; each macro captures `loc` and `builder` from the
// enclosing scope and folds the integer operand into an index attribute.
#define div(a, b)                                                              \
  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
#define rem(a, b)                                                              \
  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
#define mul(a, b)                                                              \
  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
// Splits each offset into an inter-block coordinate (offset / block) and an
// intra-block coordinate (offset % block); the result lists all divs first,
// then all rems.
SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
                                            ArrayRef<OpFoldResult> offsets,
                                            ArrayRef<int64_t> blockShape) {
  assert(offsets.size() == blockShape.size() &&
         "offsets and blockShape must have the same size");
  SmallVector<OpFoldResult> blockedOffsets;
  SmallVector<OpFoldResult> divs, rems;

  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
    divs.push_back(div(offset, block));
    rems.push_back(rem(offset, block));
  }
  blockedOffsets.append(divs.begin(), divs.end());
  blockedOffsets.append(rems.begin(), rems.end());

  return blockedOffsets;
}
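// E.g. (illustrative): offsets = [18, 5] with blockShape = [8, 16] gives
// divs = [2, 0] and rems = [2, 5], i.e. blockedOffsets = [2, 0, 2, 5].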
// Returns the strides of the blocked matrix as a flat vector: the outer
// (per-block) strides followed by the inner (within-block) strides.
SmallVector<int64_t> MemDescType::getStrideShape() {
  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());

  ArrayAttr strideAttr = getStrideAttr();
  SmallVector<int64_t> strides;
  for (Attribute attr : strideAttr.getValue()) {
    strides.push_back(cast<IntegerAttr>(attr).getInt());
  }

  SmallVector<int64_t> innerBlkShape = getBlockShape();

  // Permutation that orders the dims by stride, fastest-varying first.
  SmallVector<int> perm =
      llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
  llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });

  assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");

  // Strides within one inner block.
  SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
  innerBlkStride[perm[0]] = 1;
  for (size_t i = 1; i < perm.size(); ++i)
    innerBlkStride[perm[i]] =
        innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];

  // Recover the unblocked matrix shape and the per-dim block counts from
  // the externally provided strides.
  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
    BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
  }

  int64_t innerBlkSize = 1;
  for (auto s : innerBlkShape)
    innerBlkSize *= s;

  // Strides between blocks: a whole inner block is contiguous.
  SmallVector<int64_t> outerBlkStride(matrixShape.size());
  outerBlkStride[perm[0]] = innerBlkSize;
  for (size_t i = 0; i < perm.size() - 1; ++i) {
    outerBlkStride[perm[i + 1]] =
        outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
  }

  // Combined blocked strides: outer strides, then inner strides.
  SmallVector<int64_t> blockedStrides;
  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());

  return blockedStrides;
}
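// Worked example (illustrative): a row-major 32x32 matrix (strides [32, 1])
// blocked by [8, 16]: perm = [1, 0], innerBlkStride = [16, 1],
// innerBlkSize = 8*16 = 128, outerBlkStride = [256, 128], so the result is
// [256, 128, 16, 1]: block (i, j) starts at i*256 + j*128 and element
// (a, b) inside it adds a*16 + b.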
Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
                                    ArrayRef<OpFoldResult> offsets) {
  ArrayRef<int64_t> matrixShape = getShape();
  SmallVector<int64_t> blockShape = getBlockShape();
  SmallVector<int64_t> strides = getStrideShape();
  SmallVector<OpFoldResult> blockedOffsets;

  // If the shape is not blocked, keep the original offsets and drop the
  // redundant outer-block strides.
  if (llvm::equal(blockShape, matrixShape)) {
    strides.erase(strides.begin(), strides.begin() + matrixShape.size());
  } else {
    assert(offsets.size() == blockShape.size() &&
           "offsets and blockShape must have the same size");
    // Split each offset into block coordinates and in-block coordinates.
    SmallVector<OpFoldResult> divs, rems;
    for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
      divs.push_back(div(offset, block));
      rems.push_back(rem(offset, block));
    }
    blockedOffsets.append(divs.begin(), divs.end());
    blockedOffsets.append(rems.begin(), rems.end());
    offsets = blockedOffsets;
  }

  // linearOffset = sum_i offsets[i] * strides[i]
  Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
  for (size_t i = 0; i < offsets.size(); ++i) {
    OpFoldResult mulResult = mul(offsets[i], strides[i]);
    Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
    linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
  }

  return linearOffset;
}
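// Continuing the example above (illustrative): offsets [18, 5] on the
// blocked 32x32 matrix become [2, 0, 2, 5] after blocking, giving
// 2*256 + 0*128 + 2*16 + 5 = 549.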
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.cpp.inc>
#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.cpp.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.cpp.inc>