25#include "llvm/Support/Casting.h"
26#include "llvm/Support/FormatVariadic.h"
35 for (
const auto &vals : values)
36 llvm::append_range(
result, vals);
42 auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout());
45 if (!layout || !layout.isForSubgroup())
50 auto tdescShape = tdescTy.getShape();
51 auto elementType = tdescTy.getElementType();
56 int64_t sgSize = llvm::product_of(laneLayout);
60 for (
auto [tdescDim, laneDim, laneDataDim] :
61 llvm::zip_equal(tdescShape, laneLayout, laneData)) {
62 assert((tdescDim % (laneDim * laneDataDim) == 0) &&
63 "tensor descriptor shape is not distributable");
64 tensorSize *= tdescDim;
67 tensorSize *= tdescTy.getArrayLength();
69 return VectorType::get({tensorSize / sgSize}, elementType);
74 xegpu::LayoutAttr layout) {
75 int64_t rank = originalType.getRank();
77 if (rank < 1 || rank > 3)
84 arrayLength =
shape[0];
87 auto helperTdescTy = xegpu::TensorDescType::get(
88 shape, originalType.getElementType(), arrayLength,
90 xegpu::MemorySpace::Global, layout);
96 VectorType originalType) {
99 assert((isa<xegpu::LayoutAttr>(layout) || isa<xegpu::SliceAttr>(layout)) &&
100 "Expecting a valid layout.");
102 int64_t vectorRank = originalType.getRank();
103 int64_t layoutRank = layout.getRank();
104 assert(vectorRank >= layoutRank &&
"Vector rank must be >= layout rank.");
108 int64_t offset = vectorRank - layoutRank;
112 auto distributedShapeOrFailure =
113 layout.computeDistributedShape(trailingShape);
114 if (
failed(distributedShapeOrFailure))
118 fullShape.begin() + offset);
119 resultShape.append(distributedShapeOrFailure->begin(),
120 distributedShapeOrFailure->end());
121 return VectorType::get(resultShape, originalType.getElementType());
125 const StringRef prefix(
"layout_operand_");
126 unsigned idx =
const_cast<OpOperand &
>(operand).getOperandNumber();
127 return llvm::formatv(
"{0}{1}", prefix, idx).str();
131 const StringRef prefix =
"layout_result_";
132 return llvm::formatv(
"{0}{1}", prefix,
result.getResultNumber()).str();
140 dyn_cast_if_present<xegpu::TensorDescType>(value.
getType()))
141 return tdescTy.getLayoutAttr();
143 if (
auto result = dyn_cast<OpResult>(value)) {
145 assert(defOp &&
"result must have a defining op");
147 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
148 auto layout = anchorOp.getAnchorLayout();
153 if (defOp->
hasAttr(layoutName)) {
155 defOp->
getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
160 if (
auto arg = dyn_cast<BlockArgument>(value)) {
161 auto *parentOp = arg.getOwner()->getParentOp();
162 if (
auto loop = dyn_cast_if_present<LoopLikeOpInterface>(parentOp)) {
163 OpOperand *tiedInit = loop.getTiedLoopInit(arg);
171xegpu::DistributeLayoutAttr
174 unsigned idx =
const_cast<OpOperand &
>(opr).getOperandNumber();
176 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
177 if (
auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
179 return dpasOp.getLayoutAAttr();
180 }
else if (idx == 1) {
181 return dpasOp.getLayoutBAttr();
182 }
else if (idx == 2) {
183 return dpasOp.getLayoutCdAttr();
186 if (
auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
187 return convertOp.getInputLayoutAttr();
189 auto layout = anchorOp.getAnchorLayout();
196 if (isa<xegpu::StoreNdOp, xegpu::StoreMatrixOp>(op) && (idx < 2))
199 if (isa<xegpu::StoreScatterOp>(op)) {
200 xegpu::StoreScatterOp store(op);
201 int chunkSize = store.getChunkSize().value_or(1);
202 if (layout && idx >= 2 && chunkSize > 1)
203 return layout.dropDims(llvm::to_vector(
204 llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
207 if (isa<xegpu::LoadGatherOp>(op)) {
208 xegpu::LoadGatherOp
load(op);
209 int chunkSize =
load.getChunkSize().value_or(1);
210 if (layout && idx >= 1 && chunkSize > 1)
211 return layout.dropDims(llvm::to_vector(
212 llvm::seq<int64_t>(layout.getRank() - 1, layout.getRank())));
219 auto layout = op->
getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
228xegpu::DistributeLayoutAttr
231 const std::string &name) {
232 xegpu::DistributeLayoutAttr candidate = layout;
234 if (
auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
235 if (
auto perm = loadOp.getLayoutAttr())
244xegpu::DistributeLayoutAttr
247 const std::string &name) {
248 xegpu::DistributeLayoutAttr candidate = layout;
249 unsigned idx =
const_cast<OpOperand &
>(operand).getOperandNumber();
251 if (
auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
253 if (
auto perm = storeOp.getLayoutAttr())
265 const mlir::xegpu::DistributeLayoutAttr layout) {
268 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
269 if (anchorOp.getAnchorLayout() == layout)
271 anchorOp.setAnchorLayout(layout);
287 const DistributeLayoutAttr layout) {
289 unsigned idx =
const_cast<OpOperand &
>(operand).getOperandNumber();
294 if (
auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(owner)) {
295 if (
auto dpasOp = dyn_cast<xegpu::DpasOp>(owner)) {
297 return dpasOp.setLayoutAAttr(layout);
298 }
else if (idx == 1) {
299 return dpasOp.setLayoutBAttr(layout);
300 }
else if (idx == 2) {
301 return dpasOp.setLayoutCdAttr(layout);
304 if (
auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(owner)) {
305 return convertOp.setInputLayoutAttr(layout);
311 if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
314 anchorOp.setAnchorLayout(layout);
318 anchorOp.setAnchorLayout(layout);
332template <
typename T,
typename>
333xegpu::DistributeLayoutAttr
335 Operation *op = operandOrResult.getOwner();
339 auto layout = op->
getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
346template xegpu::DistributeLayoutAttr
348template xegpu::DistributeLayoutAttr
351template <
typename T,
typename>
353 const xegpu::DistributeLayoutAttr layout) {
354 Operation *owner = operandOrResult.getOwner();
356 if (owner->
hasAttrOfType<xegpu::DistributeLayoutAttr>(name)) {
366 const mlir::xegpu::DistributeLayoutAttr layout);
370 const mlir::xegpu::DistributeLayoutAttr layout);
375 auto vecTy = dyn_cast<VectorType>(value.
getType());
383 int64_t srcShapeRank = srcShape.size();
387 int64_t rankDiff = srcShapeRank - targetShapeRank;
388 std::fill(adjustedTargetShape.begin(), adjustedTargetShape.begin() + rankDiff,
390 llvm::copy(
shape, adjustedTargetShape.begin() + rankDiff);
396 Value slice = vector::ExtractStridedSliceOp::create(
397 builder, loc, value, offsets, adjustedTargetShape, staticStrides);
400 if (srcShapeRank > targetShapeRank) {
401 auto targetTy = VectorType::get(
shape, vecTy.getElementType());
402 slice = vector::ShapeCastOp::create(builder, loc, targetTy, slice);
413 VectorType inputTy = dyn_cast<VectorType>(values[0].
getType());
414 assert(llvm::all_of(values.
getTypes(),
415 [&](
Type type) { return type == inputTy; }) &&
416 "values must be of the same VectorType");
418 Type elemTy = inputTy.getElementType();
421 VectorType resultTy = VectorType::get(
shape, elemTy);
426 for (
auto [src, offsets] :
429 result = vector::InsertStridedSliceOp::create(builder, loc, src,
result,
430 offsets, staticStrides);
441 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
447 converter.addConversion([](
Type type) ->
Type {
return type; });
448 converter.addConversion([](VectorType type) ->
Type {
449 return RankedTensorType::get(type.getShape(), type.getElementType());
451 converter.addSourceMaterialization(materializeCast);
452 converter.addTargetMaterialization(materializeCast);
454 mlir::ConversionTarget
target(*context);
455 target.addLegalOp<UnrealizedConversionCastOp>();
460 (
void)mlir::applyPartialConversion(op,
target, std::move(patterns));
466 op->
walk([](UnrealizedConversionCastOp castOp) {
467 if (castOp.getNumOperands() != 1 || castOp.getNumResults() != 1)
470 Value input = castOp.getInputs()[0];
472 auto inputTy = dyn_cast<VectorType>(input.
getType());
473 auto resultTy = dyn_cast<RankedTensorType>(
result.getType());
476 if (!inputTy || !resultTy)
479 xegpu::DistributeLayoutAttr layout =
484 RankedTensorType newTy = resultTy.cloneWithEncoding(layout);
489 if (
auto loop = dyn_cast<LoopLikeOpInterface>(use.getOwner())) {
495 if (
auto whileOp = dyn_cast<scf::WhileOp>(use.getOwner())) {
496 unsigned idx = use.getOperandNumber();
505 op->
walk([](scf::YieldOp yieldOp) {
508 unsigned idx = r.getResultNumber();
509 Type resultTy = r.getType();
510 Type yieldTy = yieldOp.getResults()[idx].getType();
511 if (isa<RankedTensorType>(resultTy) && yieldTy != resultTy)
524 class UnrealizedConversionCastOpPattern
525 :
public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
526 using OpConversionPattern<
527 mlir::UnrealizedConversionCastOp>::OpConversionPattern;
530 matchAndRewrite(mlir::UnrealizedConversionCastOp op,
532 ConversionPatternRewriter &rewriter)
const override {
533 auto inputs = op.getOperands();
534 auto outputs = op.getOutputs();
536 if (inputs.size() != 1 || outputs.size() != 1)
539 auto inputTy = inputs[0].getType();
540 auto outputTy = outputs[0].getType();
542 if (isa<VectorType>(inputTy) && isa<RankedTensorType>(outputTy)) {
543 rewriter.replaceOpWithMultiple(op, adaptor.getInputs());
547 if (isa<RankedTensorType>(inputTy) && isa<VectorType>(outputTy)) {
549 auto newOp = UnrealizedConversionCastOp::create(rewriter, op.getLoc(),
551 rewriter.replaceOp(op, newOp);
558 converter.addSourceMaterialization(materializeCast);
561 return UnrealizedConversionCastOp::create(builder, loc, type, inputs)
565 mlir::ConversionTarget
target(*context);
566 target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
567 [](UnrealizedConversionCastOp op) {
568 auto isTensorTy = [](
Type type) {
569 return isa<RankedTensorType>(type);
575 patterns.insert<UnrealizedConversionCastOpPattern>(context);
578 (
void)mlir::applyPartialConversion(op,
target, std::move(patterns));
588 auto targetAttrs = gpuModuleOp.getTargets();
590 for (
auto &attr : *targetAttrs) {
591 auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
593 return xevmAttr.getChip().str();
605 assert(
lhs.size() ==
rhs.size() &&
"lhs and rhs must have the same size");
607 for (
auto [l, r] : llvm::zip_equal(
lhs,
rhs)) {
610 results.push_back(builder.
createOrFold<arith::AddIOp>(loc, lval, rval));
633 a = a.slice(a.size() -
b.size());
641 static_assert(std::is_integral<T>::value,
"T must be an integer type");
644 if (!candidateMultiples.empty())
646 SmallVector<T>(candidateMultiples.begin(), candidateMultiples.end());
647 for (T candidate : candidates) {
648 for (T multiple : multiples) {
649 int value =
static_cast<int>(candidate * multiple);
650 if (value != 0 && dim % value == 0 && value > largest)
658 vector::CombiningKind kind, uint32_t size) {
660 Value laneVal = vector::ReductionOp::create(builder, loc, kind, input);
662 for (uint64_t i = 1; i < size; i <<= 1) {
664 gpu::ShuffleOp::create(builder, loc, laneVal, i, size,
665 gpu::ShuffleMode::XOR)
667 laneVal = makeArithReduction(builder, loc, kind, laneVal, shuffled);
674 vector::CombiningKind kind,
677 VectorType sourceType = src.
getType();
678 int64_t sourceRank = sourceType.getRank();
681 assert(sourceRank >= 2 &&
"expected at least a 2D source vector");
682 for (
int64_t i = 0; i < sourceRank - 2; ++i)
683 assert(sourceType.getShape()[i] == 1 &&
684 "expected leading dimensions to be unit");
685 int64_t rowIdx = sourceRank - 2;
686 int64_t columnIdx = sourceRank - 1;
687 int64_t sourceH = sourceType.getShape()[rowIdx];
688 int64_t sourceW = sourceType.getShape()[columnIdx];
689 int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
691 TypedAttr zeroAttr = rewriter.
getZeroAttr(sourceType.getElementType());
692 Value reductionResult = arith::ConstantOp::create(
693 rewriter, loc,
acc.getType(),
704 for (
int i = 0; i < nSlices; ++i) {
710 if (reductionDim == columnIdx) {
711 sliceOffsets[rowIdx] = i;
712 sliceSizes[columnIdx] = sourceW;
714 sliceOffsets[columnIdx] = i;
715 sliceSizes[rowIdx] = sourceH;
718 vector::ExtractStridedSliceOp extractOp =
719 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
720 sliceSizes, strides);
724 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
726 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
728 VectorType::get({nSliceElements}, sourceType.getElementType()),
729 extractOp.getResult());
739 accIdx[accRank - 1] = i;
740 Value accExtract = vector::ExtractOp::create(rewriter, loc,
acc, accIdx);
741 Value reduction = vector::ReductionOp::create(
742 rewriter, loc, kind, slice.getResult(), accExtract);
743 reductionResult = vector::InsertOp::create(rewriter, loc, reduction,
744 reductionResult, accIdx);
748 return reductionResult;
753 vector::CombiningKind kind,
int64_t reductionDim,
int64_t reductionSize,
755 VectorType sourceType = src.
getType();
756 int64_t sourceRank = sourceType.getRank();
759 assert(sourceRank >= 2 &&
"expected at least a 2D source vector");
760 for (
int64_t i = 0; i < sourceRank - 2; ++i)
761 assert(sourceType.getShape()[i] == 1 &&
762 "expected leading dimensions to be unit");
763 int64_t rowIdx = sourceRank - 2;
764 int64_t columnIdx = sourceRank - 1;
765 int64_t sourceH = sourceType.getShape()[rowIdx];
766 int64_t sourceW = sourceType.getShape()[columnIdx];
769 TypedAttr zeroAttr = rewriter.
getZeroAttr(sourceType.getElementType());
770 Value reductionResult = arith::ConstantOp::create(
771 rewriter, loc,
acc.getType(),
778 int nSlices = (reductionDim == rowIdx) ? sourceW : sourceH;
783 for (
int i = 0; i < nSlices; ++i) {
789 if (reductionDim == columnIdx) {
790 sliceOffsets[rowIdx] = i;
791 sliceSizes[columnIdx] = sourceW;
793 sliceOffsets[columnIdx] = i;
794 sliceSizes[rowIdx] = sourceH;
797 vector::ExtractStridedSliceOp extractOp =
798 vector::ExtractStridedSliceOp::create(rewriter, loc, src, sliceOffsets,
799 sliceSizes, strides);
800 int64_t nSliceElements = extractOp.getResult().getType().getNumElements();
801 vector::ShapeCastOp slice = vector::ShapeCastOp::create(
803 VectorType::get({nSliceElements}, sourceType.getElementType()),
804 extractOp.getResult());
807 accIdx[accRank - 1] = i;
808 Value accExtract = vector::ExtractOp::create(rewriter, loc,
acc, accIdx);
813 reductionResult = vector::InsertOp::create(rewriter, loc, fullReduce,
814 reductionResult, accIdx);
816 return reductionResult;
821 vector::CombiningKind kind) {
822 auto vecTy = dyn_cast<VectorType>(type);
823 Type elemTy = vecTy ? vecTy.getElementType() : type;
828 return arith::ConstantOp::create(
830 return arith::ConstantOp::create(builder, loc, cast<TypedAttr>(scalarAttr));
834 case vector::CombiningKind::ADD:
835 case vector::CombiningKind::XOR:
836 case vector::CombiningKind::OR:
837 case vector::CombiningKind::MAXUI:
840 case vector::CombiningKind::MUL:
841 case vector::CombiningKind::AND:
844 case vector::CombiningKind::MINSI:
845 if (
auto intTy = dyn_cast<IntegerType>(elemTy))
847 elemTy, APInt::getSignedMaxValue(intTy.getWidth())));
850 case vector::CombiningKind::MINUI:
851 if (
auto intTy = dyn_cast<IntegerType>(elemTy))
853 builder.
getIntegerAttr(elemTy, APInt::getMaxValue(intTy.getWidth())));
856 case vector::CombiningKind::MAXSI:
857 if (
auto intTy = dyn_cast<IntegerType>(elemTy))
859 elemTy, APInt::getSignedMinValue(intTy.getWidth())));
862 case vector::CombiningKind::MINNUMF:
863 case vector::CombiningKind::MINIMUMF:
864 if (
auto floatTy = dyn_cast<FloatType>(elemTy))
866 elemTy, APFloat::getInf(floatTy.getFloatSemantics())));
869 case vector::CombiningKind::MAXNUMF:
870 case vector::CombiningKind::MAXIMUMF:
871 if (
auto floatTy = dyn_cast<FloatType>(elemTy))
873 elemTy, APFloat::getInf(floatTy.getFloatSemantics(),
true)));
889 auto laneData = layout.getEffectiveLaneDataAsInt();
890 if (laneData.size() != 2)
892 return laneData[0] != 1;
904 auto laneLayout = layout.getEffectiveLaneLayoutAsInt();
905 if (laneLayout.size() != 2)
920 for (
size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx)
921 if (srcIdx < src.size() && src[srcIdx] == dst[dstIdx])
923 else if (dst[dstIdx] == 1)
924 expandedUnitDims.push_back(dstIdx);
927 return srcIdx == src.size();
944 splitDimGroups.clear();
945 for (
size_t dstIdx = 0; dstIdx < dst.size(); ++dstIdx) {
946 if (srcIdx >= src.size())
948 accumulatedSize *= dst[dstIdx];
949 currentDstDims.push_back(dstIdx);
951 if (accumulatedSize == src[srcIdx]) {
953 splitDimGroups.push_back(currentDstDims);
957 currentDstDims.clear();
958 }
else if (accumulatedSize > src[srcIdx]) {
962 return srcIdx == src.size();
xegpu::DistributeLayoutAttr maybePickPermanentLayout(xegpu::DistributeLayoutAttr layout, const OpResult &result, mlir::Operation *owner, const std::string &name)
Attributes are known-constant values of operations.
This class represents an argument of a Block.
IntegerAttr getIntegerAttr(Type type, int64_t value)
FloatAttr getFloatAttr(Type type, double value)
TypedAttr getZeroAttr(Type type)
TypedAttr getOneAttr(Type type)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
AttrClass getAttrOfType(StringAttr name)
bool hasAttrOfType(NameT &&name)
bool hasAttr(StringAttr name)
Return true if the operation has an attribute with the provided name, false otherwise.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
result_range getOpResults()
MLIRContext * getContext()
Return the context this operation is associated with.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
void setType(Type newType)
Mutate the type of this Value to be of the specified type.
Type getType() const
Return the type of this value.
static WalkResult advance()
Operation * getOwner() const
Return the owner of this operand.
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, Value v1, Value acc, arith::FastMathFlagsAttr fastmath=nullptr, Value mask=nullptr)
Returns the result value of reducing two scalar/vector values with the corresponding arith operation.
Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc, ValueRange values, ArrayRef< int64_t > shape)
Create a vector of shape from a set of values using vector.insert_stride_slice.
bool requirePacked(const DistributeLayoutAttr layout)
Helper function to check if the layout is packed.
void setTemporaryLayout(const T &operandOrResult, const DistributeLayoutAttr layout)
Value createReductionNeutralValue(OpBuilder &builder, Location loc, Type type, vector::CombiningKind kind)
Creates a constant filled with the neutral (identity) value for the given reduction kind.
void setDistributeLayoutAttr(const OpResult &Result, const DistributeLayoutAttr layout)
[to-be-deprecated] Sets the DistributeLayoutAttr for a given OpResult; users should use setAnchorLayout...
Value subgroupReduction(Location loc, OpBuilder &builder, Value input, vector::CombiningKind kind, uint32_t size)
Given an input value representing per-lane data, this function returns the result after performing a ...
bool matchUnitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< int64_t > &expandedUnitDims)
int getLargestDivisor(T dim, ArrayRef< T > candidates, ArrayRef< T > candidateMultiples={})
Helper Function to find a proper instruction multiple for the user-supplied sg-level data shape (dive...
FailureOr< VectorType > getDistVecTypeBasedOnLaneLayout(DistributeLayoutAttr layout, VectorType originalType)
Helper function to get distributed vector type for a source vector type according to the lane_layout.
Value lowerToVectorReductions(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, Location loc, PatternRewriter &rewriter)
Given src and acc arguments from a vector::MultiDimReductionOp, lower to a set of vector::Reduc...
bool requireTranspose(const DistributeLayoutAttr layout, const uArch::uArch *uArch)
Helper function to check if the layout requires a transpose effect.
bool matchSplitDimExpansion(ArrayRef< int64_t > src, ArrayRef< int64_t > dst, SmallVector< SmallVector< int64_t > > &splitDimGroups)
void doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter)
Do type conversion for SCF structural ops, e.g., scf.for, using SCF structural type conversion patterns...
DistributeLayoutAttr getDistributeLayoutAttr(const Value value)
Retrieves the DistributeLayoutAttr associated with a given Value.
std::string getTemporaryLayoutName(const OpOperand &operand)
Return the attribute name for the OpOperand to attach DistributeLayoutAttr.
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
SmallVector< Value > extractVectorsWithShapeFromValue(OpBuilder &builder, Location loc, Value value, ArrayRef< int64_t > shape)
Extract a set of small vectors from a value with a given shape using vector.extract_stride_slice.
DistributeLayoutAttr getTemporaryLayout(const T &operandOrResult)
get and set distribute layout attribute for non-anchor operations (and offsets/masks of load/store op...
Value lowerCrossLaneReductionToShuffles(TypedValue< VectorType > src, TypedValue< VectorType > acc, vector::CombiningKind kind, int64_t reductionDim, int64_t reductionSize, Location loc, PatternRewriter &rewriter)
Lowers cross-lane reductions to shuffle operations on a 2D vector.
SmallVector< Value > flattenValues(ArrayRef< ValueRange > values)
Flatten a set of ValueRange into a single SmallVector<Value>
SmallVector< OpFoldResult > addWithRightAligned(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with automatic alignment.
SmallVector< OpFoldResult > addElementwise(OpBuilder &builder, Location loc, ArrayRef< OpFoldResult > lhs, ArrayRef< OpFoldResult > rhs)
Generates element-wise addition ops of two arrays with same length.
FailureOr< VectorType > getDistributedVectorType(xegpu::TensorDescType tdescTy)
If tensor descriptor has a layout attribute it is used in SIMT mode.
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
std::optional< SmallVector< int64_t > > computeShapeRatio(ArrayRef< int64_t > shape, ArrayRef< int64_t > subShape)
Return the multi-dimensional integral ratio of subShape to the trailing dimensions of shape.
virtual int getSubgroupSize() const =0
StringRef getName() const