#define GEN_PASS_DEF_XEGPUWGTOSGDISTRIBUTE
#include "mlir/Dialect/XeGPU/Transforms/Passes.h.inc"
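
/// Returns the "sg_id_range" attribute from an enclosing operation, if one is
/// set. It is used below to rebase the subgroup id and to check that the
/// layout covers exactly the subgroups in that range.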
static xegpu::RangeAttr getRangeSpecAttr(Operation *op) {
  // ... (lookup of the relevant ancestor `parent` elided)
  if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>(
          parent->getAttr("sg_id_range")))
    return attr;
  // ...
}
// Returns the subgroup-level shape and the number of distributed values that
// a workgroup-level value of `shape` maps to under `layout`.
static std::pair<SmallVector<int64_t>, int>
getSgShapeAndCount(ArrayRef<int64_t> shape,
                   xegpu::DistributeLayoutAttr layout) {
  int count = 1;
  SmallVector<int64_t> sgShape(shape);
  if (layout && layout.isForWorkgroup()) {
    SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
    if (!layout.getEffectiveSgDataAsInt().empty())
      sgShape = layout.getEffectiveSgDataAsInt();
    else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout))
      sgShape = *maybeDerivedSgData;
    SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape);
    // Clamp the distribution unit to the original shape; data may be shared
    // among subgroups.
    for (size_t i = 0; i < distUnit.size(); ++i)
      distUnit[i] = std::min(shape[i], distUnit[i]);
    count = computeProduct(shape) / computeProduct(distUnit);
  }
  return std::make_pair(sgShape, count);
}
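
/// Computes, for each subgroup-level value produced from `op`, the offsets at
/// which that value accesses the original workgroup-level data. The offsets
/// are derived from the subgroup id and the op's workgroup layout, then added
/// to the op's original offsets.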
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp,
              xegpu::PrefetchNdOp, xegpu::LoadMatrixOp,
              xegpu::StoreMatrixOp>::value>>
static LogicalResult
genOffsetsList(ConversionPatternRewriter &rewriter, OpType op,
               SmallVector<SmallVector<OpFoldResult>> &offsetsList) {
  Location loc = op.getLoc();
  SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
  if (origOffsets.empty())
    return failure();

  xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
  if (!layout || !layout.isForWorkgroup())
    return failure();

  Value sgId =
      gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

  // If an sg_id_range is specified on an enclosing op, rebase the subgroup id
  // to the start of that range.
  xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op);
  if (sgIdRange) {
    int64_t startOfRange = sgIdRange.getStart().getInt();
    int64_t endOfRange = sgIdRange.getEnd().getInt();
    if (layout.getNumSubgroups() != endOfRange - startOfRange)
      return rewriter.notifyMatchFailure(
          op, "sg_layout size must match the sg_id_range");
    if (startOfRange > 0) {
      Value startOfRangeVal =
          arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
      sgId = index::SubOp::create(rewriter, loc, sgId, startOfRangeVal);
    }
  }

  // ... (derivation of the workgroup-level shape `wgShape` elided)
  auto maybeDescOffsets =
      layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
  if (failed(maybeDescOffsets))
    return failure();

  // Add the per-subgroup coordinates to the op's original offsets.
  for (const auto &sgOffsets : *maybeDescOffsets) {
    SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned(
        rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets);
    offsetsList.push_back(std::move(newOffsets));
  }
  return success();
}
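
/// Distributes xegpu.create_nd_tdesc with explicit offsets: one descriptor of
/// the subgroup-level shape is created per entry returned by genOffsetsList.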
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();
    Type elemTy = tdescTy.getElementType();
    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::CreateNdDescOp::create(
          rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets,
          op.getMixedSizes(), op.getMixedStrides());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
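
/// Distributes xegpu.create_nd_tdesc without offsets: `count` identical
/// subgroup-level descriptors are created; offsets are expected to be
/// supplied by the consuming load/store/prefetch ops.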
struct WgToSgCreateNdOpNoOffset
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only applies when the op carries no offsets.
    if (!op.getMixedOffsets().empty())
      return failure();

    Location loc = op.getLoc();
    MLIRContext *ctx = op.getContext();
    xegpu::TensorDescType tdescTy = op.getType();
    auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Type elemTy = tdescTy.getElementType();
    ArrayRef<int64_t> wgShape = tdescTy.getShape();

    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
    xegpu::TensorDescType newTdescTy =
        xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
                                   layout.dropSgLayoutAndData());

    SmallVector<Value> newCreateNdOps(count);
    std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
      return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
                                           op.getSource(), op.getMixedSizes(),
                                           op.getMixedStrides());
    });
    rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
    return success();
  }
};
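
/// Distributes xegpu.load_nd (without offsets): one load per distributed
/// tensor descriptor, producing a vector of the subgroup-level shape.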
struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    SmallVector<Value> newLoadOps;
    for (auto src : adaptor.getTensorDesc()) {
      xegpu::TensorDescType tdescTy =
          dyn_cast<xegpu::TensorDescType>(src.getType());
      ArrayRef<int64_t> srcShape = tdescTy.getShape();
      VectorType newResTy = VectorType::get(srcShape, tdescTy.getElementType());
      auto newLoadOp = xegpu::LoadNdOp::create(rewriter, op.getLoc(), newResTy,
                                               src, op->getAttrs());
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return mlir::success();
  }
};
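
/// Distributes xegpu.store_nd (without offsets) over the distributed
/// value/descriptor pairs.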
struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getMixedOffsets().empty())
      return failure();

    for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc()))
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, t, op.getL1HintAttr(),
                               op.getL2HintAttr(), op.getL3HintAttr());

    rewriter.eraseOp(op);
    return success();
  }
};
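
/// Distributes xegpu.load_nd with offsets: the per-subgroup offsets computed
/// by genOffsetsList are passed to the new loads.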
struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
  using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    SmallVector<Value> newOps;
    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
      VectorType newResTy =
          VectorType::get(tdescTy.getShape(), tdescTy.getElementType());
      auto newOp = xegpu::LoadNdOp::create(
          rewriter, op.getLoc(), newResTy, tdesc, offsets,
          nullptr, nullptr, op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
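
/// Distributes xegpu.store_nd with offsets, analogous to the load case above.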
struct WgToSgStoreNdOpWithOffset
    : public OpConversionPattern<xegpu::StoreNdOp> {
  using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    for (auto [v, tdesc, offsets] :
         llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) {
      xegpu::StoreNdOp::create(rewriter, op.getLoc(), v, tdesc, offsets,
                               op.getL1HintAttr(), op.getL2HintAttr(),
                               op.getL3HintAttr());
    }
    rewriter.eraseOp(op);
    return success();
  }
};
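
/// Distributes xegpu.prefetch_nd with offsets.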
struct WgToSgPrefetchNdOpWithOffset
    : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    for (auto [tdesc, offsets] :
         llvm::zip(adaptor.getTensorDesc(), offsetsList)) {
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), tdesc, offsets,
                                  op.getL1HintAttr(), op.getL2HintAttr(),
                                  op.getL3HintAttr());
    }
    rewriter.eraseOp(op);
    return success();
  }
};
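
/// Distributes xegpu.update_nd_offset: the same offset update is applied to
/// every distributed tensor descriptor.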
struct WgToSgUpdateNdOffsetOp
    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
  using OpConversionPattern<xegpu::UpdateNdOffsetOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::UpdateNdOffsetOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    llvm::SmallVector<Value> newUpdateTileOffsetOps;
    for (auto tDesc : adaptor.getTensorDesc()) {
      auto newUpdateTileOffsetOp = xegpu::UpdateNdOffsetOp::create(
          rewriter, op.getLoc(), tDesc.getType(), tDesc, op.getOffsets(),
          op.getConstOffsets());
      newUpdateTileOffsetOps.push_back(newUpdateTileOffsetOp);
    }
    rewriter.replaceOpWithMultiple(op, {newUpdateTileOffsetOps});
    return success();
  }
};
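
/// Distributes xegpu.dpas: every (lhs, rhs) pair of distributed operands
/// yields one subgroup-level dpas whose result shape is
/// aVecShape[0] x bVecShape[1].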
struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern<xegpu::DpasOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType resultTy = op.getResult().getType();
    if (resultTy.getRank() != 2)
      return failure();

    auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult());
    if (!originalLayout)
      return failure();

    size_t i = 0;
    SmallVector<Value> newDpasOps;
    for (auto aVec : adaptor.getLhs()) {
      for (auto bVec : adaptor.getRhs()) {
        llvm::SmallVector<Value> operands({aVec, bVec});
        Value tmpC;
        if (op.getAcc()) {
          tmpC = adaptor.getAcc()[i++];
          operands.push_back(tmpC);
        }

        ArrayRef<int64_t> aVecShape =
            llvm::cast<VectorType>(aVec.getType()).getShape();
        ArrayRef<int64_t> bVecShape =
            llvm::cast<VectorType>(bVec.getType()).getShape();
        VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]},
                                           resultTy.getElementType());
        tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands);
        xegpu::setDistributeLayoutAttr(llvm::cast<OpResult>(tmpC),
                                       originalLayout.dropSgLayoutAndData());

        newDpasOps.push_back(tmpC);
      }
    }
    rewriter.replaceOpWithMultiple(op, {newDpasOps});
    return success();
  }
};
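
/// Distributes xegpu.prefetch_nd without offsets.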
struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
  using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
    if ((offsetSize != 0) || op.getConstOffsetsAttr())
      return failure();

    for (auto src : adaptor.getTensorDesc())
      xegpu::PrefetchNdOp::create(rewriter, op.getLoc(), TypeRange(), src,
                                  op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
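
/// Distributes vector.broadcast, provided the workgroup shape is evenly
/// distributable under the layout.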
struct WgToSgVectorBroadcastOp
    : public OpConversionPattern<vector::BroadcastOp> {
  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResult().getType();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
      return failure();

    SmallVector<Value> newBroadcastOps;
    for (auto operand : adaptor.getOperands().front()) {
      auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
                                                      newResultType, operand);
      xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
                                     layout.dropSgLayoutAndData());
      newBroadcastOps.push_back(newBroadcast.getResult());
    }
    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
    return success();
  }
};
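
/// Distributes elementwise operations (any op with elementwise-mappable
/// traits and a single vector result): the op is re-created once per
/// distributed operand tuple, with sg_layout/sg_data dropped from its layout
/// attributes.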
struct WgToSgElementwiseOp : public ConversionPattern {
  WgToSgElementwiseOp(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // ... (guards: single-result, elementwise-mappable op; elided)
    auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
    assert(resultType && "Expected result to be a VectorType");

    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op->getResult(0));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // All operand ranges must have the same number of distributed values.
    size_t numVariants = operands.empty() ? 0 : operands.front().size();
    if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
          return operandVec.size() != numVariants;
        }))
      return failure();

    SmallVector<Value> newResults;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    for (size_t i = 0; i < numVariants; ++i) {
      SmallVector<Value> opOperands;
      for (auto &operandVec : operands)
        opOperands.push_back(operandVec[i]);

      OperationState state(op->getLoc(), op->getName());
      state.addOperands(opOperands);
      state.addTypes(newResultType);
      // Copy attributes, dropping sg_layout/sg_data from layout attributes
      // that still carry lane- or inst-level information.
      for (NamedAttribute attr : op->getAttrs()) {
        if (auto layout =
                dyn_cast<xegpu::DistributeLayoutAttr>(attr.getValue())) {
          if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
              !layout.getEffectiveInstDataAsInt().empty())
            state.addAttribute(attr.getName(), layout.dropSgLayoutAndData());
        } else {
          state.addAttribute(attr.getName(), attr.getValue());
        }
      }
      Operation *newOp = rewriter.create(state);
      newResults.push_back(newOp->getResult(0));
    }

    rewriter.replaceOpWithMultiple(op, {newResults});
    return success();
  }
};
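
/// Distributes xegpu.convert_layout. Only conversions that keep the same
/// subgroup-level distribution (sg_layout, sg_data, order) are supported.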
struct WgToSgConvertLayoutOp
    : public OpConversionPattern<xegpu::ConvertLayoutOp> {
  using OpConversionPattern<xegpu::ConvertLayoutOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout());
    auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout());

    if (!input || !target || !input.isForWorkgroup() ||
        !target.isForWorkgroup())
      return rewriter.notifyMatchFailure(
          op, "Input and target layouts must have subgroup layout");

    auto inputSgLayout = input.getSgLayout();
    auto inputSgData = input.getSgData();
    auto inputOrder = input.getOrder();
    auto targetSgLayout = target.getSgLayout();
    auto targetSgData = target.getSgData();
    auto targetOrder = target.getOrder();

    // Only the trivial case is supported: input and target agree on the
    // subgroup-level distribution.
    if (inputSgLayout != targetSgLayout || inputSgData != targetSgData ||
        inputOrder != targetOrder)
      return failure();

    input = input.dropSgLayoutAndData();
    target = target.dropSgLayoutAndData();

    SmallVector<Value> newOps(adaptor.getSource().begin(),
                              adaptor.getSource().end());
    if (input && target) {
      // Keep a ConvertLayoutOp per value for the remaining fields.
      for (auto [i, src] : llvm::enumerate(adaptor.getSource())) {
        auto newOp = xegpu::ConvertLayoutOp::create(
            rewriter, op.getLoc(), src.getType(), src, input, target);
        newOps[i] = newOp;
      }
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
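
/// Folds unrealized_conversion_cast ops introduced while materializing the
/// 1:N type conversion, forwarding the distributed values directly.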
struct UnrealizedConversionCastOpPattern
    : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
  using OpConversionPattern<
      mlir::UnrealizedConversionCastOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::UnrealizedConversionCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Value> inputs = xegpu::flattenValues(adaptor.getInputs());

    auto inputTy = dyn_cast<VectorType>(inputs[0].getType());
    auto outputTy = dyn_cast<VectorType>(op->getOpResult(0).getType());

    if (!inputTy || !outputTy || !llvm::all_equal(op->getResultTypes()) ||
        !llvm::all_equal(ValueRange(inputs).getTypes()))
      return failure();

    // N:1 cast whose inputs already match the result types: forward the
    // flattened inputs.
    if (op.getNumOperands() == 1 &&
        llvm::equal(ValueRange(inputs).getTypes(), op->getResultTypes())) {
      rewriter.replaceOp(op, inputs);
      return success();
    }

    // 1:N cast: forward the distributed inputs to the single result.
    if (op.getNumResults() == 1 /* && ... (additional check elided) */) {
      rewriter.replaceOpWithMultiple(op, {inputs});
      return success();
    }

    return mlir::failure();
  }
};
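
/// Distributes arith.constant of vector type. Splat constants are simply
/// shrunk to the subgroup shape; non-splat constants are only supported for
/// strided index vectors, which are rebuilt from a base tile plus a
/// per-subgroup offset.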
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
    auto vecType = dyn_cast<VectorType>(op.getType());
    if (!vecAttr || !vecType)
      return failure();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    ArrayRef<int64_t> wgShape = vecType.getShape();
    SmallVector<int64_t> sgShape;
    int count;
    std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

    auto newType = VectorType::get(sgShape, vecType.getElementType());
    Location loc = op.getLoc();
    auto eltType = vecType.getElementType();

    auto setLayout = [&](Value val) {
      xegpu::setDistributeLayoutAttr(llvm::cast<OpResult>(val),
                                     layout.dropSgLayoutAndData());
    };
    if (vecAttr.isSplat()) {
      // Splat: every subgroup gets the same, smaller constant.
      Attribute singleVal = vecAttr.getSplatValue<Attribute>();
      auto sgAttr = DenseElementsAttr::get(newType, singleVal);
      auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
      setLayout(cstOp->getResult(0));
      rewriter.replaceOp(op, cstOp);
      return success();
    } else if (sgShape == wgShape) {
      // Already subgroup-sized: keep the constant, just update the layout.
      auto newConstOp =
          arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
      setLayout(newConstOp->getResult(0));
      rewriter.replaceOp(op, newConstOp);
      return success();
    }

    // Non-splat constants: only strided index vectors (1-D or 2-D) are
    // supported; they are rebuilt per subgroup from a base tile plus offset.
    if (!eltType.isIndex())
      return rewriter.notifyMatchFailure(
          op, "Unsupported element type for non-splat constant op.");

    if (wgShape.size() > 2)
      return rewriter.notifyMatchFailure(
          op, "Only 1D & 2D vector constant supported");
    SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
    int64_t rowStride = 0, colStride = 0;
    int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
    int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];

    // Derive the column/row strides from the first row and column.
    if (cols > 1)
      colStride = cast<IntegerAttr>(values[1]).getInt() -
                  cast<IntegerAttr>(values[0]).getInt();
    if (rows > 1)
      rowStride = cast<IntegerAttr>(values[cols]).getInt() -
                  cast<IntegerAttr>(values[0]).getInt();

    // The constant must be strided: neighboring columns differ by the same
    // column stride and neighboring rows by the same row stride.
    for (int64_t r = 0; r < rows; ++r) {
      for (int64_t c = 0; c < cols; ++c) {
        int64_t idx = r * cols + c;
        if (c > 0 && cols > 1) {
          int64_t prevIdx = r * cols + (c - 1);
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != colStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant column stride in constant op.");
        }
        if (r > 0 && rows > 1) {
          int64_t prevIdx = (r - 1) * cols + c;
          int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
                         cast<IntegerAttr>(values[prevIdx]).getInt();
          if (diff != rowStride)
            return rewriter.notifyMatchFailure(
                op, "Non-constant row stride in constant op.");
        }
      }
    }
    // Base tile: the values of the first subgroup-sized tile.
    SmallVector<Attribute> baseTileValues;
    int baseTileCols = sgShape[sgShape.size() - 1];
    int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
    for (int64_t r = 0; r < baseTileRows; ++r) {
      for (int64_t c = 0; c < baseTileCols; ++c) {
        baseTileValues.push_back(values[r * cols + c]);
      }
    }

    auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
                                           baseTileValues);
    auto baseConstVec = arith::ConstantOp::create(rewriter, loc, tileAttr);

    // Per-subgroup offsets within the workgroup-level constant.
    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector<Value, 2> strideConsts;
    strideConsts.push_back(
        arith::ConstantIndexOp::create(rewriter, loc, colStride));
    if (rows > 1)
      strideConsts.insert(
          strideConsts.begin(),
          arith::ConstantIndexOp::create(rewriter, loc, rowStride));

    SmallVector<Value> newConstOps;
    for (auto offsets : *sgOffsets) {
      // offset = sum_i offsets[i] * stride[i], broadcast to the tile shape.
      Value mulOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
      for (size_t i = 0; i < strideConsts.size(); ++i) {
        Value mul =
            arith::MulIOp::create(rewriter, loc, rewriter.getIndexType(),
                                  offsets[i], strideConsts[i]);
        mulOffset = arith::AddIOp::create(
            rewriter, loc, rewriter.getIndexType(), mulOffset, mul);
      }
      auto bcastOffset = vector::BroadcastOp::create(
          rewriter, loc, baseConstVec.getType(), mulOffset);
      auto finalConst =
          arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
      setLayout(baseConstVec);
      setLayout(bcastOffset);
      setLayout(finalConst);
      newConstOps.push_back(finalConst);
    }
    rewriter.replaceOpWithMultiple(op, {newConstOps});
    return success();
  }
};
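
/// Distributes xegpu.load (gather) with offsets; the offsets and mask
/// operands must already be distributed to the subgroup shape.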
struct WgToSgLoadGatherOpWithOffset
    : public OpConversionPattern<xegpu::LoadGatherOp> {
  using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();
    ArrayRef<int64_t> wgShape = resultType.getShape();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;

    // The offsets and mask must already be distributed to the same shape.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    SmallVector<Value> newLoadOps;
    auto chunkSizeAttr =
        rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1));
    VectorType newTy = VectorType::get(sgShape, resultType.getElementType());
    for (auto [offsets, mask] :
         llvm::zip(adaptor.getOffsets(), adaptor.getMask())) {
      auto newLayout = layout.dropSgLayoutAndData();
      auto newLoadOp = xegpu::LoadGatherOp::create(
          rewriter, loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
          newLayout);
      newLoadOps.push_back(newLoadOp);
    }
    rewriter.replaceOpWithMultiple(op, {newLoadOps});
    return success();
  }
};
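
/// Distributes xegpu.store (scatter) with offsets; mirrors the gather case.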
struct WgToSgStoreScatterOpWithOffset
    : public OpConversionPattern<xegpu::StoreScatterOp> {
  using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getOffsets())
      return failure();

    Location loc = op.getLoc();
    VectorType valueType = dyn_cast<VectorType>(op.getValue().getType());
    if (!valueType)
      return failure();

    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getOperand(0));
    if (!layout || !layout.isForWorkgroup())
      return failure();

    // The offsets and mask must already be distributed to the same shape.
    auto offsetsVecType =
        dyn_cast<VectorType>(adaptor.getOffsets().front().getType());
    auto maskVecType =
        dyn_cast<VectorType>(adaptor.getMask().front().getType());
    if (!offsetsVecType || !maskVecType ||
        offsetsVecType.getShape() != maskVecType.getShape()) {
      return rewriter.notifyMatchFailure(op,
                                         "offsets have not been distributed");
    }

    auto chunkSizeOpt = op.getChunkSize();
    int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1;
    auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
    for (auto [val, offs, mask] : llvm::zip(
             adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
      auto store = xegpu::StoreScatterOp::create(
          rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
          op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
          layout.dropSgLayoutAndData());
      // Propagate the subgroup layout to the operands, skipping the
      // destination operand.
      for (OpOperand &operand : store->getOpOperands()) {
        if (operand.getOperandNumber() == 1)
          continue;
        xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
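
/// Distributes xegpu.load_matrix using per-subgroup offsets from
/// genOffsetsList.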
struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
  using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    ArrayRef<int64_t> wgShape = op.getDataShape();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
    assert(valueTy && "the value type must be vector type!");
    Type elemTy = valueTy.getElementType();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResTy = VectorType::get(sgShape, elemTy);
    SmallVector<Value> newOps;
    for (auto offsets : offsetsList) {
      auto newOp = xegpu::LoadMatrixOp::create(rewriter, op.getLoc(), newResTy,
                                               op.getMemDesc(), offsets,
                                               layout.dropSgLayoutAndData());
      newOps.push_back(newOp);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
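
/// Distributes xegpu.store_matrix using per-subgroup offsets from
/// genOffsetsList.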
struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
  using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<SmallVector<OpFoldResult>> offsetsList;
    if (failed(genOffsetsList(rewriter, op, offsetsList)))
      return failure();

    xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
    for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList))
      xegpu::StoreMatrixOp::create(rewriter, op.getLoc(), v, op.getMemDesc(),
                                   offsets, layout.dropSgLayoutAndData());
    rewriter.eraseOp(op);
    return success();
  }
};
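
/// Distributes vector.step: a subgroup-sized step vector is created and each
/// subgroup's start offset is broadcast and added to it.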
struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();
    std::optional<SmallVector<int64_t>> sgShape =
        getSgShapeAndCount(wgShape, layout).first;
    if (!sgShape)
      return failure();

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    auto steps = vector::StepOp::create(rewriter, loc, newTy);
    SmallVector<Value> newOps;
    for (auto offsets : *sgOffsets) {
      auto bcastOffset =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      auto finalSteps =
          arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
      xegpu::setDistributeLayoutAttr(steps->getResult(0),
                                     layout.dropSgLayoutAndData());
      xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
                                     layout.dropSgLayoutAndData());
      xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
                                     layout.dropSgLayoutAndData());
      newOps.push_back(finalSteps);
    }

    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
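
/// Distributes vector.shape_cast. Only casts that merely add or drop unit
/// dimensions are supported, and rank-changing casts require the lower-rank
/// layout to be a slice of the higher-rank one.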
struct WgToSgVectorShapeCastOp
    : public OpConversionPattern<vector::ShapeCastOp> {
  using OpConversionPattern<vector::ShapeCastOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ShapeCastOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
    if (!resultType)
      return failure();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());

    auto srcType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
    if (!srcType)
      return failure();

    // Source and destination may only differ in unit dimensions.
    auto onlyUnitDims = [](ArrayRef<int64_t> src, ArrayRef<int64_t> dst) {
      SmallVector<int64_t> srcNonUnit, dstNonUnit;
      for (int64_t d : src)
        if (d != 1)
          srcNonUnit.push_back(d);
      for (int64_t d : dst)
        if (d != 1)
          dstNonUnit.push_back(d);
      return srcNonUnit == dstNonUnit;
    };

    if (!onlyUnitDims(srcType.getShape(), sgShape))
      return failure();

    // For rank-changing casts, the lower-rank layout must be a slice of the
    // higher-rank one.
    int64_t sourceRank = srcType.getRank();
    int64_t resultRank = sgShape.size();
    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getSource());
    if (sourceRank < resultRank && !sourceLayout.isSliceOf(layout))
      return failure();
    if (sourceRank > resultRank && !layout.isSliceOf(sourceLayout))
      return failure();

    SmallVector<Value> newShapeCastOps;
    for (auto src : adaptor.getSource()) {
      auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
                                                      newResultType, src);
      xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
                                     layout.dropSgLayoutAndData());
      newShapeCastOps.push_back(newShapeCast.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newShapeCastOps});
    return success();
  }
};
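
/// Distributes vector.multi_reduction. Each reduced dimension must be owned
/// entirely by a single subgroup (sg_layout == 1 and sg_data == srcShape in
/// that dimension), so the reduction stays local to the subgroup.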
struct WgToSgMultiDimReductionOp
    : public OpConversionPattern<vector::MultiDimReductionOp> {
  using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType srcType = op.getSourceVectorType();
    VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
    if (!dstType)
      return failure();

    auto srcShape = srcType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    auto reductionDims = llvm::to_vector(op.getReductionDims());

    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
                                        .getParent()
                                        .getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
                                      .getParent()
                                      .getEffectiveSgDataAsInt();

    // Each reduced dimension must be fully owned by one subgroup.
    for (int64_t dim : reductionDims) {
      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
        return rewriter.notifyMatchFailure(
            op,
            "sgLayout in each reduced dimension must be 1 and sgData in the "
            "reduced dim must match srcShape in that dim");
    }

    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;

    VectorType newDstType =
        VectorType::get({sgShape}, dstType.getElementType());

    SmallVector<Value> newReductions;
    for (auto sgSrc : adaptor.getSource()) {
      auto newOp = vector::MultiDimReductionOp::create(
          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
          adaptor.getAcc()[0], op.getReductionDims());
      xegpu::setDistributeLayoutAttr(newOp->getResult(0),
                                     layout.dropSgLayoutAndData());
      newReductions.push_back(newOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newReductions});
    return success();
  }
};
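
/// Distributes vector.transpose. The result layout must be the transpose of
/// the source layout under the same permutation.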
struct WgToSgVectorTransposeOp
    : public OpConversionPattern<vector::TransposeOp> {
  using OpConversionPattern<vector::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    VectorType resultType = op.getResultVectorType();

    ArrayRef<int64_t> wgShape = resultType.getShape();
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    xegpu::DistributeLayoutAttr sourceLayout =
        xegpu::getDistributeLayoutAttr(op.getVector());
    if (!sourceLayout || !sourceLayout.isForWorkgroup())
      return failure();

    SmallVector<int64_t> sourceSgLayout =
        sourceLayout.getEffectiveSgLayoutAsInt();
    SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
    auto sourceOrder = sourceLayout.getOrder();
    auto resultOrder = layout.getOrder();

    if (!sourceOrder || !resultOrder) {
      return rewriter.notifyMatchFailure(
          op, "Both source and result must have order attributes");
    }

    ArrayRef<int64_t> permutation = op.getPermutation();
    size_t permutationSize = permutation.size();
    if (sourceSgLayout.size() != permutationSize ||
        resultSgLayout.size() != permutationSize) {
      return rewriter.notifyMatchFailure(
          op, "Layouts and permutation must have the same rank");
    }

    if (!layout.isTransposeOf(sourceLayout, permutation))
      return rewriter.notifyMatchFailure(
          op, "Result layout is not a valid transpose of source layout "
              "according to permutation");

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType newResultType =
        VectorType::get(sgShape, resultType.getElementType());
    SmallVector<Value> newTransposeOps;
    for (auto src : adaptor.getVector()) {
      auto newTranspose = vector::TransposeOp::create(
          rewriter, op.getLoc(), newResultType, src, permutation);
      xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
                                     layout.dropSgLayoutAndData());
      newTransposeOps.push_back(newTranspose.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newTransposeOps});
    return success();
  }
};
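
/// Distributes vector.constant_mask by turning it into per-subgroup
/// vector.create_mask ops whose sizes are the workgroup mask sizes clamped
/// into each subgroup's local tile.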
struct WgToSgVectorConstantMaskOp
    : public OpConversionPattern<vector::ConstantMaskOp> {
  using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    xegpu::DistributeLayoutAttr layout =
        xegpu::getDistributeLayoutAttr(op.getResult());
    if (!layout || !layout.isForWorkgroup())
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();
    ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes();

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto sgOffsets =
        layout.computeDistributedCoords(rewriter, loc, sgId, wgShape);
    if (failed(sgOffsets))
      return failure();

    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
    VectorType resultType = VectorType::get(sgShape, type.getElementType());

    // sgMaskSize = min(max(wgMaskSize - offset, 0), sgDimSize) per dimension.
    SmallVector<Value> newCreateMaskOps;
    for (auto offsetSet : *sgOffsets) {
      SmallVector<Value> maskOperands;

      for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) {
        Value wgMaskSizeVal =
            arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize);
        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
        Value dimSizeVal =
            arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
        Value offset = offsetSet[i];
        Value adjustedMaskSize =
            arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset);
        Value nonNegative =
            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
        Value sgMaskSize =
            arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal);
        maskOperands.push_back(sgMaskSize);
      }

      auto newCreateMaskOp =
          vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands);
      xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0),
                                     layout.dropSgLayoutAndData());
      newCreateMaskOps.push_back(newCreateMaskOp.getResult());
    }

    rewriter.replaceOpWithMultiple(op, {newCreateMaskOps});
    return success();
  }
};
void xegpu::populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
           WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
           WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
           WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
           WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
           WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp,
           WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp,
           WgToSgVectorConstantMaskOp>(patterns.getContext());
}
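
/// Pass driver for the workgroup-to-subgroup distribution: it first runs the
/// SCF structural type conversion and then applies the patterns above as a
/// partial conversion.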
struct XeGPUWgToSgDistributePass
    : public xegpu::impl::XeGPUWgToSgDistributeBase<XeGPUWgToSgDistributePass> {
  void runOnOperation() override;
};

void XeGPUWgToSgDistributePass::runOnOperation() {
  // Remember the UnrealizedConversionCastOps that already exist, so that only
  // the casts introduced below are cleaned up afterwards.
  SmallVector<Operation *> existingCastOps;
  getOperation()->walk([&](UnrealizedConversionCastOp castOp) {
    existingCastOps.push_back(castOp.getOperation());
  });

  // Step 1: SCF structural type conversion, with RankedTensorType (carrying
  // the layout as its encoding) as the intermediate carrier type.
  {
    TypeConverter converter;
    converter.addConversion([&](Type type) -> Type { return type; });
    converter.addConversion(
        [&](RankedTensorType type,
            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
          Type elemTy = type.getElementType();
          ArrayRef<int64_t> shape = type.getShape();

          int count;
          SmallVector<int64_t> subShape;
          std::tie(subShape, count) = getSgShapeAndCount(
              shape,
              dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()));

          auto newTy = VectorType::get(subShape, elemTy);
          result.append(count, newTy);
          return success();
        });
    xegpu::doSCFStructuralTypeConversionWithTensorType(getOperation(),
                                                       converter);
  }
  // Step 2: workgroup-to-subgroup distribution of the XeGPU and vector ops.
  MLIRContext *ctx = &getContext();
  ConversionTarget target(*ctx);
  TypeConverter converter;
  converter.addConversion([&](Type type) -> Type { return type; });
  converter.addConversion(
      [&](xegpu::TensorDescType type,
          SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
        Type elemTy = type.getElementType();
        ArrayRef<int64_t> shape = type.getShape();

        int count;
        SmallVector<int64_t> subShape;
        xegpu::LayoutAttr layout = type.getLayoutAttr();
        std::tie(subShape, count) = getSgShapeAndCount(shape, layout);

        if (layout)
          layout = layout.dropSgLayoutAndData();

        auto newTy = xegpu::TensorDescType::get(
            type.getContext(), subShape, elemTy, type.getEncoding(), layout);
        result.append(count, newTy);
        return success();
      });

  auto getTensorDescType = [](Operation *op) -> xegpu::TensorDescType {
    if (auto createOp = dyn_cast<xegpu::CreateNdDescOp>(op))
      return createOp.getType();
    if (auto loadOp = dyn_cast<xegpu::LoadNdOp>(op))
      return loadOp.getTensorDescType();
    if (auto storeOp = dyn_cast<xegpu::StoreNdOp>(op))
      return storeOp.getTensorDescType();
    if (auto updateOp = dyn_cast<xegpu::UpdateNdOffsetOp>(op))
      return updateOp.getTensorDescType();
    if (auto prefetchOp = dyn_cast<xegpu::PrefetchNdOp>(op))
      return prefetchOp.getTensorDescType();
    return xegpu::TensorDescType();
  };
  auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool {
    return !layout || !layout.isForWorkgroup();
  };

  target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp,
                               xegpu::StoreNdOp, xegpu::UpdateNdOffsetOp,
                               xegpu::PrefetchNdOp>([=](Operation *op) -> bool {
    auto tdescTy = getTensorDescType(op);
    auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(tdescTy.getLayout());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool {
    auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
    return isLegal(layout);
  });

  target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>(
      [=](xegpu::LoadMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>(
      [=](xegpu::StoreMatrixOp op) -> bool {
        return isLegal(op.getLayoutAttr());
      });

  target.addDynamicallyLegalOp<arith::ConstantOp>(
      [=](arith::ConstantOp op) -> bool {
        auto vecType = dyn_cast<VectorType>(op.getType());
        if (!vecType)
          return true;
        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<
      vector::ShapeCastOp, vector::StepOp, vector::TransposeOp,
      vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>(
      [=](Operation *op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::LoadGatherOp>(
      [=](xegpu::LoadGatherOp op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op.getResult());
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
      [=](xegpu::StoreScatterOp op) -> bool {
        auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
      [=](xegpu::ConvertLayoutOp op) -> bool {
        return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
      });

  target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
      [=](Operation *op) -> std::optional<bool> {
        // Only elementwise ops with a single vector result and matching
        // operand shapes are candidates for distribution.
        if (!OpTrait::hasElementwiseMappableTraits(op) ||
            op->getNumResults() != 1)
          return true;
        VectorType resultType =
            dyn_cast<VectorType>(op->getResult(0).getType());
        if (!resultType)
          return true;
        for (Value operand : op->getOperands()) {
          VectorType operandType = dyn_cast<VectorType>(operand.getType());
          if (!operandType || operandType.getShape() != resultType.getShape()) {
            return true;
          }
        }
        xegpu::DistributeLayoutAttr layout =
            xegpu::getDistributeLayoutAttr(op->getResult(0));
        return isLegal(layout);
      });

  target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
      [=](UnrealizedConversionCastOp op) {
        return llvm::is_contained(existingCastOps, op.getOperation());
      });

  target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
  RewritePatternSet patterns(ctx);
  xegpu::populateXeGPUWgToSgDistributePatterns(patterns);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    return signalPassFailure();

  // Finally, drop sg_layout/sg_data from layout attributes attached to op
  // results; SCF control-flow ops lose the attribute entirely.
  getOperation()->walk([](Operation *op) {
    for (OpResult result : op->getOpResults()) {
      std::string name = xegpu::getLayoutName(result);
      if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) {
        op->removeAttr(name);
        if (!isa<scf::IfOp, scf::ForOp, scf::WhileOp, scf::ConditionOp>(op)) {
          if (auto newLayout = layout.dropSgLayoutAndData())
            op->setAttr(name, newLayout);
        }
      }
    }
  });
}