XeVMToLLVM.cpp
1//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "mlir/Pass/Pass.h"
16#include "mlir/Support/LLVM.h"
17#include "llvm/Support/FormatVariadic.h"
18
20#include "mlir/IR/Types.h"
22
23#include "llvm/ADT/TypeSwitch.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27#include "mlir/Conversion/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace xevm;
32
33namespace {
34
35struct LLVMFuncAttributeOptions {
36 bool isConvergent = false;
37 bool isNoUnwind = false;
38 bool isWillReturn = false;
39 LLVM::MemoryEffectsAttr memEffectsAttr{};
40};
41static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42 false, true, false, {}};
43static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44 false, true, true, {}};
45static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46 true, true, true, {}};
47
48std::string getTypeMangling(Type ty, bool isUnsigned = false) {
49 return TypeSwitch<Type, std::string>(ty)
50 .Case([isUnsigned](VectorType ty) -> std::string {
51 return "Dv" + std::to_string(ty.getNumElements()) + "_" +
52 getTypeMangling(ty.getElementType(), isUnsigned);
53 })
54 .Case([](Float16Type) -> std::string { return "Dh"; })
55 .Case([](Float32Type) -> std::string { return "f"; })
56 .Case([](Float64Type) -> std::string { return "d"; })
57 .Case([isUnsigned](IntegerType ty) -> std::string {
58 switch (ty.getWidth()) {
59 case 8:
60 return isUnsigned ? "h" : "c";
61 case 16:
62 return isUnsigned ? "t" : "s";
63 case 32:
64 return isUnsigned ? "j" : "i";
65 case 64:
66 return isUnsigned ? "m" : "l";
67 default:
68 llvm_unreachable("unhandled integer type");
69 }
70 })
71 .DefaultUnreachable("unhandled type for mangling");
72}
73
74std::string mangle(StringRef baseName, ArrayRef<Type> types,
75 ArrayRef<bool> isUnsigned = {}) {
76 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
77 "Signedness info doesn't match");
78 std::string s;
79 llvm::raw_string_ostream os(s);
80 llvm::SmallDenseMap<Type, unsigned> substitutions;
81 os << "_Z" << baseName.size() << baseName;
82 for (auto [idx, type] : llvm::enumerate(types)) {
83 auto it = substitutions.find(type);
84 if (it != substitutions.end()) {
85 os << "S";
86 // First substitution is `S_`, second is `S0_`, and so on.
87 if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
88 os << firstIdx - 1;
89 os << "_";
90 } else {
91 if (!type.isIntOrFloat())
92 substitutions[type] = substitutions.size();
93 os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
94 }
95 }
96 return os.str();
97}
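// Illustrative examples (hypothetical base name "foo") derived from the rules
// above: mangle("foo", {i32, i32}) gives "_Z3fooii" (scalar types are never
// recorded as substitutions), whereas mangle("foo", {vector<8xi16>,
// vector<8xi16>}) gives "_Z3fooDv8_sS_", the repeated vector type being
// emitted as the "S_" substitution.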
98
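// The integer codes returned below appear to follow the SPV_INTEL_cache_controls
// encoding (an assumption based on the switch cases): for loads 0 = Uncached,
// 1 = Cached, 2 = Streaming, 3 = InvalidateAfterRead; for stores 0 = Uncached,
// 1 = WriteThrough, 2 = WriteBack, 3 = Streaming.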
99static int32_t getL1CacheControl(LoadCacheControl cc) {
100 int32_t control = 0;
101 switch (cc) {
102 case LoadCacheControl::L1C_L2UC_L3UC:
103 case LoadCacheControl::L1C_L2UC_L3C:
104 case LoadCacheControl::L1C_L2C_L3UC:
105 case LoadCacheControl::L1C_L2C_L3C:
106 control = 1;
107 break;
108 case LoadCacheControl::L1S_L2UC_L3UC:
109 case LoadCacheControl::L1S_L2UC_L3C:
110 case LoadCacheControl::L1S_L2C_L3UC:
111 case LoadCacheControl::L1S_L2C_L3C:
112 control = 2;
113 break;
114 case LoadCacheControl::INVALIDATE_READ:
115 control = 3;
116 break;
117 default:
118 break;
119 }
120 return control;
121}
122
123static int32_t getL1CacheControl(StoreCacheControl cc) {
124 int32_t control = 0;
125 switch (cc) {
126 case StoreCacheControl::L1WT_L2UC_L3UC:
127 case StoreCacheControl::L1WT_L2UC_L3WB:
128 case StoreCacheControl::L1WT_L2WB_L3UC:
129 case StoreCacheControl::L1WT_L2WB_L3WB:
130 control = 1;
131 break;
132 case StoreCacheControl::L1WB_L2UC_L3UC:
133 case StoreCacheControl::L1WB_L2WB_L3UC:
134 case StoreCacheControl::L1WB_L2UC_L3WB:
135 control = 2;
136 break;
137 case StoreCacheControl::L1S_L2UC_L3UC:
138 case StoreCacheControl::L1S_L2UC_L3WB:
139 case StoreCacheControl::L1S_L2WB_L3UC:
140 case StoreCacheControl::L1S_L2WB_L3WB:
141 control = 3;
142 break;
143 default:
144 break;
145 }
146 return control;
147}
148
149static int32_t getL3CacheControl(LoadCacheControl cc) {
150 int32_t control = 0;
151 switch (cc) {
152 case LoadCacheControl::L1UC_L2UC_L3C:
153 case LoadCacheControl::L1UC_L2C_L3C:
154 case LoadCacheControl::L1C_L2UC_L3C:
155 case LoadCacheControl::L1C_L2C_L3C:
156 case LoadCacheControl::L1S_L2UC_L3C:
157 case LoadCacheControl::L1S_L2C_L3C:
158 control = 1;
159 break;
160 case LoadCacheControl::INVALIDATE_READ:
161 control = 3;
162 break;
163 default:
164 break;
165 }
166 return control;
167}
168
169static int32_t getL3CacheControl(StoreCacheControl cc) {
170 int32_t control = 0;
171 switch (cc) {
172 case StoreCacheControl::L1UC_L2UC_L3WB:
173 case StoreCacheControl::L1UC_L2WB_L3WB:
174 case StoreCacheControl::L1WT_L2UC_L3WB:
175 case StoreCacheControl::L1WT_L2WB_L3WB:
176 case StoreCacheControl::L1S_L2UC_L3WB:
177 case StoreCacheControl::L1S_L2WB_L3WB:
178 case StoreCacheControl::L1WB_L2UC_L3WB:
179 control = 2;
180 break;
181 default:
182 break;
183 }
184 return control;
185}
186
187static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
188 return op.getCacheControl();
189}
190
191static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
192 return op.getCacheControl();
193}
194
195static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
196 return op.getCacheControl();
197}
198
199static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
200 return op.getCacheControl();
201}
202
203static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
204 return op.getCacheControl();
205}
206
207static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
208 return op.getCacheControl();
209}
210
211static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
212 if (op->hasAttr("cache_control")) {
213 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
214 if (!attr)
215 return std::nullopt;
216 return std::optional<LoadCacheControl>(attr.getValue());
217 }
218 return std::nullopt;
219}
220
221static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
222 if (op->hasAttr("cache_control")) {
223 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
224 if (!attr)
225 return std::nullopt;
226 return std::optional<StoreCacheControl>(attr.getValue());
227 }
228 return std::nullopt;
229}
230
231template <typename OpType>
232int32_t getL1CacheControl(OpType op) {
233 return getL1CacheControl(*getCacheControl(op));
234}
235
236template <typename OpType>
237int32_t getL3CacheControl(OpType op) {
238 return getL3CacheControl(*getCacheControl(op));
239}
240
241template <typename OpType>
242static std::optional<ArrayAttr>
243getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
244 if (!getCacheControl(op))
245 return {};
246 constexpr int32_t decorationCacheControlArity{3};
247 constexpr int32_t loadCacheControlKey{6442};
248 constexpr int32_t storeCacheControlKey{6443};
249 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
250 std::is_same_v<OpType, BlockPrefetch2dOp> ||
251 std::is_same_v<OpType, LLVM::LoadOp> ||
252 std::is_same_v<OpType, BlockLoadOp> ||
253 std::is_same_v<OpType, PrefetchOp>;
254 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
255 SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
256 controlKey, 0, getL1CacheControl<OpType>(op)};
257 SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
258 controlKey, 1, getL3CacheControl<OpType>(op)};
259 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
260 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
261
262 SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
263 return rewriter.getArrayAttr(combinedAttrs);
264}
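// The resulting attribute has the shape [[key, 0, l1Control], [key, 1, l3Control]],
// i.e. one i32 triple per cache level; the keys 6442/6443 presumably correspond
// to the SPIR-V CacheControlLoadINTEL / CacheControlStoreINTEL decoration ids.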
265
266static LLVM::CallOp createDeviceFunctionCall(
267 ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
268 ArrayRef<Type> argTypes, ArrayRef<Value> args,
269 mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
270 LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
271 auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
272 assert(moduleOp && "Expecting module");
273 Location loc = op->getLoc();
274
275 auto funcOpRes =
276 LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
277 assert(!failed(funcOpRes));
278 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
279 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
280 funcOp.setConvergent(funcAttributeOptions.isConvergent);
281 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
282 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
283
284 if (funcAttributeOptions.memEffectsAttr)
285 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
286
287 for (auto [idx, attrName] : paramAttrs)
288 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
289
290 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
291 callOp->setAttrs(funcOp->getAttrs());
292
293 return callOp;
294}
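// Note: the helper declares (or reuses) an llvm.func with the SPIR_FUNC calling
// convention, attaches the requested attributes to it, and mirrors those
// attributes onto the emitted llvm.call.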
295
296class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
297 using OpConversionPattern::OpConversionPattern;
298 LogicalResult
299 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
300 ConversionPatternRewriter &rewriter) const override {
301 if (!op.getC()) {
302 return rewriter.notifyMatchFailure(op, "OCL requires C operand");
303 }
304 auto precisionA = op.getTypes().getA();
305 auto precisionB = op.getTypes().getB();
306 auto precisionC = op.getTypes().getC();
307 auto precisionD = op.getTypes().getD();
308 if (precisionC != precisionD) {
309 return rewriter.notifyMatchFailure(op, "type of C and D need to match");
310 }
311 if (precisionC != xevm::ElemType::S32 &&
312 precisionC != xevm::ElemType::F32 &&
313 precisionC != xevm::ElemType::F16 &&
314 precisionC != xevm::ElemType::BF16) {
315 return rewriter.notifyMatchFailure(
316 op, "type of C and D must be S32, F32, F16 or BF16");
317 }
318 if (precisionA == xevm::ElemType::S32 ||
319 precisionA == xevm::ElemType::F32) {
320 return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
321 }
322 if (precisionB == xevm::ElemType::S32 ||
323 precisionB == xevm::ElemType::F32) {
324 return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
325 }
326 constexpr uint32_t bitWidthPackedA{16};
327 constexpr uint32_t bitWidthPackedB{32};
328 auto loc = op.getLoc();
329
330 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
331 VectorType origTy = cast<VectorType>(val.getType());
332 const uint32_t vecBitSize =
333 origTy.getNumElements() *
334 origTy.getElementType().getIntOrFloatBitWidth();
335 VectorType newTy = VectorType::get(
336 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
337 if (origTy != newTy)
338 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
339 return val;
340 };
341
342 Value a = op.getA();
343 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
344 ? cast<Type>(rewriter.getF32Type())
345 : rewriter.getIntegerType(bitWidthPackedA);
346 a = castIfNeeded(a, packedAType);
347
348 Value b = op.getB();
349 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
350 ? cast<Type>(rewriter.getF32Type())
351 : rewriter.getIntegerType(bitWidthPackedB);
352 b = castIfNeeded(b, packedBType);
353
354 Value c = op.getC();
355 VectorType cOrigTy = cast<VectorType>(c.getType());
356 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
357 assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
358 // OCL builtins encode bfloat16 as int16
359 VectorType cTy =
360 cOrigTy.getElementType().isBF16()
361 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
362 : cOrigTy;
363 VectorType resTy = cTy;
364 if (cOrigTy != cTy)
365 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
366
367 constexpr int32_t systolicDepth{8};
368 std::string fnName =
369 llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
370 stringifyElemType(op.getTypes().getA()).str(),
371 stringifyElemType(op.getTypes().getB()).str(),
372 systolicDepth *
373 getNumOperandsPerDword(op.getTypes().getA()))
374 .str();
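    // For example, with A = B = BF16 this should produce
    // "intel_sub_group_bf16_bf16_matrix_mad_k16" (k = systolic depth 8 x 2
    // operands per dword), which is then mangled with the packed operand types.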
375 SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
376 fnName = mangle(fnName, argTypes);
377 SmallVector<Value> args{a, b, c};
378
379 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
380 /*other=*/LLVM::ModRefInfo::NoModRef,
381 /*argMem=*/LLVM::ModRefInfo::NoModRef,
382 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
383 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
384 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
385 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
386 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
387 funcAttrs.memEffectsAttr = memAttr;
388 Value result =
389 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
390 funcAttrs, op.getOperation())
391 ->getResult(0);
392
393 if (resOrigTy != resTy)
394 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
395
396 rewriter.replaceOp(op, result);
397 return success();
398 }
399
400private:
401 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
402 switch (pTy) {
403 case xevm::ElemType::TF32:
404 return 1;
405 case xevm::ElemType::BF16:
406 case xevm::ElemType::F16:
407 return 2;
408 case xevm::ElemType::U8:
409 case xevm::ElemType::S8:
410 return 4;
411 default:
412 llvm_unreachable("unsupported xevm::ElemType");
413 }
414 }
415};
416
417class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
418 using OpConversionPattern::OpConversionPattern;
419 LogicalResult
420 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
421 ConversionPatternRewriter &rewriter) const override {
422 auto loc = op.getLoc();
423 const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
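    // Itanium-style decoding of the hard-coded name: 8-char "prefetch" taking a
    // pointer (P) to address-space-1 (U3AS1) const (K) char (c) plus an
    // unsigned long (m), i.e. prefetch(const char AS1*, ulong).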
424 Value one =
425 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
426 SmallVector<Value> args{op.getPtr(), one};
427 SmallVector<Type> argTypes;
428 for (auto arg : args)
429 argTypes.push_back(arg.getType());
430 auto funcAttr = noUnwindAttrs;
431 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
432 /*other=*/LLVM::ModRefInfo::NoModRef,
433 /*argMem=*/LLVM::ModRefInfo::Ref,
434 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
435 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
436 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
437 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
438 funcAttr.memEffectsAttr = memAttr;
439
440 LLVM::CallOp call = createDeviceFunctionCall(
441 rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
442 argTypes, args, {}, funcAttr, op.getOperation());
443 if (std::optional<ArrayAttr> optCacheControls =
444 getCacheControlMetadata(rewriter, op))
445 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
446 rewriter.eraseOp(op);
447 return success();
448 }
449};
450
451class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
452 using OpConversionPattern::OpConversionPattern;
453 LogicalResult
454 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
455 ConversionPatternRewriter &rewriter) const override {
456 auto loc = op.getLoc();
457 const std::string fnName{"atomic_work_item_fence"};
458 int memScope, addrSpace;
459 switch (op.getAddrspace()) {
460 case xevm::AddrSpace::SHARED:
461 addrSpace = 1; // CLK_LOCAL_MEM_FENCE
462 break;
463 case xevm::AddrSpace::GLOBAL:
464 addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
465 break;
466 default:
467 // GENERIC is not supported in OpenCL
468 return rewriter.notifyMatchFailure(
469 op, "Fence only supports global and shared address spaces.");
470 }
471 switch (op.getScope()) {
472 case xevm::MemScope::WORKGROUP:
473 memScope = 1;
474 break;
475 case xevm::MemScope::DEVICE:
476 memScope = 2;
477 break;
478 default:
479 // CLUSTER and SYSTEM are not supported in OpenCL
480 return rewriter.notifyMatchFailure(
481 op, "Fence only supports workgroup and device memory scopes.");
482 }
483 Type i32Type = rewriter.getI32Type();
484 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
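    // The constant 4 is assumed to encode memory_order_acq_rel for
    // atomic_work_item_fence.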
485 Value memScopeConst =
486 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
487 Value addrSpaceConst =
488 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
489 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
490 SmallVector<Type> argTypes{3, i32Type};
491 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
492 LLVM::LLVMVoidType::get(rewriter.getContext()),
493 argTypes, args, {}, noUnwindAttrs,
494 op.getOperation());
495 rewriter.eraseOp(op);
496 return success();
497 }
498};
499template <typename OpType>
500class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
501 using OpConversionPattern<OpType>::OpConversionPattern;
502 LogicalResult
503 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
504 ConversionPatternRewriter &rewriter) const override {
505 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
506 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
507
508 auto loc = op.getLoc();
509 VectorType vecType;
510 bool packReg = false;
511 bool transpose = false;
512 if constexpr (isLoad) {
513 vecType = op.getRes().getType();
514 packReg = op.getPackRegister();
515 transpose = op.getTranspose();
516 } else if constexpr (!isPrefetch) {
517 vecType = op.getStoredVal().getType();
518 }
519
520 auto i32Type = rewriter.getI32Type();
521 Value byteCoord =
522 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
523 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
524 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
525 byteCoord = LLVM::InsertElementOp::create(
526 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
527 byteCoord = LLVM::InsertElementOp::create(
528 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
529 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
530 op.getBasePitch(), byteCoord};
531 SmallVector<Type> retTypes;
532 Value spvLoadDstPtr;
533 std::string funcName{"intel_sub_group_2d_block_"};
534 std::string bitWidthId;
535 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
536 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
537 if constexpr (isPrefetch) { // Prefetch
538 funcName += "prefetch";
539 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
540 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
541 /*other=*/LLVM::ModRefInfo::NoModRef,
542 /*argMem=*/LLVM::ModRefInfo::Ref,
543 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
544 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
545 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
546 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
547 funcAttr = noUnwindAttrs;
548 funcAttr.memEffectsAttr = memAttr;
549 } else {
550 auto vecElemType = vecType.getElementType();
551 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
552 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
553 vecType.getNumElements());
554 auto dstOrSrcPtr = LLVM::AllocaOp::create(
555 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
556 vecElemType, numElems);
557 args.push_back(dstOrSrcPtr);
558 if constexpr (isLoad) { // Load
559 funcName += "read";
560 bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
561 if (packReg)
562 funcName += "_transform";
563 else if (transpose)
564 funcName += "_transpose";
565 spvLoadDstPtr = dstOrSrcPtr;
566 retTypes.push_back(vecType);
567 paramAttrs = {
568 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
569 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
570 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
571 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
572 };
573 } else { // Store
574 funcName += "write";
575 bitWidthId = (vecElemBitWidth == 32)
576 ? "j"
577 : ((vecElemBitWidth == 16) ? "t" : "h");
578 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
579 paramAttrs = {
580 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
581 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
582 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
583 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
584 };
585 }
586 }
587
588 funcName =
589 llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
590 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
591 .str();
592 std::string prefetchCode("");
593 if (!isPrefetch)
594 prefetchCode += "P";
595 funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
596 funcName, prefetchCode, bitWidthId)
597 .str();
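    // Illustrative result: a plain 2D block read of a 16-bit, 8-row, 16-column,
    // 2-block tile becomes "intel_sub_group_2d_block_read_16b_8r16x2c", mangled
    // (assuming an i16 element type) as
    // "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt".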
598 SmallVector<Type> argTypes;
599 for (auto arg : args) {
600 argTypes.push_back(arg.getType());
601 }
602 LLVM::CallOp call = createDeviceFunctionCall(
603 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
604 argTypes, args, paramAttrs, funcAttr, op.getOperation());
605 if (std::optional<ArrayAttr> optCacheControls =
606 getCacheControlMetadata(rewriter, op)) {
607 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
608 }
609 if constexpr (isLoad)
610 rewriter.replaceOp(
611 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
612 else
613 rewriter.eraseOp(op);
614 return success();
615 }
616};
617
618template <typename OpType>
619class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
620 using OpConversionPattern<OpType>::OpConversionPattern;
621 LogicalResult
622 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
623 ConversionPatternRewriter &rewriter) const override {
624 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
625 // Get OpenCL function name
626 // https://registry.khronos.org/OpenCL/extensions/
627 // intel/cl_intel_subgroup_local_block_io.html
628 std::string funcName{"intel_sub_group_block_"};
629 // Value or Result type can be vector or scalar
630 Type valOrResTy;
631 if constexpr (isStore) {
632 funcName += "write_u";
633 valOrResTy = op.getVal().getType();
634 } else {
635 funcName += "read_u";
636 valOrResTy = op.getType();
637 }
638 // Get element type of the vector/scalar
639 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
640 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
641 funcName += getTypeMangling(elemType);
642 if (vecTy)
643 funcName += std::to_string(vecTy.getNumElements());
644 SmallVector<Type, 2> argTypes{};
645 // XeVM BlockLoadOp/BlockStoreOp always use signless integer types, but the
646 // OpenCL builtins expect unsigned types,
647 // so use unsigned types for the mangling.
648 SmallVector<bool, 2> isUnsigned{};
649 // arg0: pointer to the src/dst address
650 // arg1 (store only): the vector or scalar value to store
651 // Prepare arguments
652 SmallVector<Value, 2> args{};
653 args.push_back(op.getPtr());
654 argTypes.push_back(op.getPtr().getType());
655 isUnsigned.push_back(true);
656 Type retType;
657 if constexpr (isStore) {
658 args.push_back(op.getVal());
659 argTypes.push_back(op.getVal().getType());
660 isUnsigned.push_back(true);
661 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
662 } else {
663 retType = valOrResTy;
664 }
665 funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
666 "PU3AS" +
667 std::to_string(op.getPtr().getType().getAddressSpace());
668 funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
669 if constexpr (isStore)
670 funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
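    // Illustrative result for a global (address space 1) read of vector<4xi32>:
    // "_Z30intel_sub_group_block_read_ui4PU3AS1j".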
671 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
672
673 LLVM::CallOp call =
674 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
675 {}, funcAttr, op.getOperation());
676 if (std::optional<ArrayAttr> optCacheControls =
677 getCacheControlMetadata(rewriter, op)) {
678 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
679 }
680 if constexpr (isStore)
681 rewriter.eraseOp(op);
682 else
683 rewriter.replaceOp(op, call->getResult(0));
684 return success();
685 }
686};
687
688template <typename OpType>
689class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
690 using OpConversionPattern<OpType>::OpConversionPattern;
691 LogicalResult
692 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
693 ConversionPatternRewriter &rewriter) const override {
694 if (!op->hasAttr("cache_control"))
695 return failure();
696 std::optional<ArrayAttr> optCacheControls =
697 getCacheControlMetadata(rewriter, op);
698 op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
699 op->removeAttr("cache_control");
700 return success();
701 }
702};
703
704//===----------------------------------------------------------------------===//
705// GPU index id operations
706//===----------------------------------------------------------------------===//
707/*
708// Launch Config ops
709// The dimension index argument (x, y, z) is fixed to i32.
710// The return type is set by the XeVM type converter.
711// get_local_id
712xevm::WorkitemIdXOp;
713xevm::WorkitemIdYOp;
714xevm::WorkitemIdZOp;
715// get_local_size
716xevm::WorkgroupDimXOp;
717xevm::WorkgroupDimYOp;
718xevm::WorkgroupDimZOp;
719// get_group_id
720xevm::WorkgroupIdXOp;
721xevm::WorkgroupIdYOp;
722xevm::WorkgroupIdZOp;
723// get_num_groups
724xevm::GridDimXOp;
725xevm::GridDimYOp;
726xevm::GridDimZOp;
727// get_global_id : to be added if needed
728*/
729
730// Helpers to get the OpenCL function name and dimension argument for each op.
731static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
732 return {"get_local_id", 0};
733}
734static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
735 return {"get_local_id", 1};
736}
737static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
738 return {"get_local_id", 2};
739}
740static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
741 return {"get_local_size", 0};
742}
743static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
744 return {"get_local_size", 1};
745}
746static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
747 return {"get_local_size", 2};
748}
749static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
750 return {"get_group_id", 0};
751}
752static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
753 return {"get_group_id", 1};
754}
755static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
756 return {"get_group_id", 2};
757}
758static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
759 return {"get_num_groups", 0};
760}
761static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
762 return {"get_num_groups", 1};
763}
764static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
765 return {"get_num_groups", 2};
766}
767/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
768/// a constant argument for the dimension - x, y or z.
769template <typename OpType>
770class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
771 using OpConversionPattern<OpType>::OpConversionPattern;
772 LogicalResult
773 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
774 ConversionPatternRewriter &rewriter) const override {
775 Location loc = op->getLoc();
776 auto [baseName, dim] = getConfig(op);
777 Type dimTy = rewriter.getI32Type();
778 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
779 static_cast<int64_t>(dim));
780 std::string func = mangle(baseName, {dimTy}, {true});
781 Type resTy = op.getType();
782 auto call =
783 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
784 noUnwindWillReturnAttrs, op.getOperation());
785 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
786 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
787 /*other=*/noModRef,
788 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
789 /*errnoMem=*/noModRef,
790 /*targetMem0=*/noModRef,
791 /*targetMem1=*/noModRef);
792 call.setMemoryEffectsAttr(memAttr);
793 rewriter.replaceOp(op, call);
794 return success();
795 }
796};
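// For example, WorkitemIdXOp lowers to a call of "_Z12get_local_idj"
// (get_local_id with an unsigned i32 dimension argument) passing the constant 0;
// the other launch-config ops differ only in the base name and dimension.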
797
798/*
799// Subgroup ops
800// get_sub_group_local_id
801xevm::LaneIdOp;
802// get_sub_group_id
803xevm::SubgroupIdOp;
804// get_sub_group_size
805xevm::SubgroupSizeOp;
806// get_num_sub_groups : to be added if needed
807*/
808
809// Helpers to get the OpenCL function name for each op.
810static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
811static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
812static StringRef getConfig(xevm::SubgroupSizeOp) {
813 return "get_sub_group_size";
814}
815template <typename OpType>
816class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
817 using OpConversionPattern<OpType>::OpConversionPattern;
818 LogicalResult
819 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
820 ConversionPatternRewriter &rewriter) const override {
821 std::string func = mangle(getConfig(op).str(), {});
822 Type resTy = op.getType();
823 auto call =
824 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
825 noUnwindWillReturnAttrs, op.getOperation());
826 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
827 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
828 /*other=*/noModRef,
829 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
830 /*errnoMem=*/noModRef,
831 /*targetMem0=*/noModRef,
832 /*targetMem1=*/noModRef);
833 call.setMemoryEffectsAttr(memAttr);
834 rewriter.replaceOp(op, call);
835 return success();
836 }
837};
838
839static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
840 if (op.getV1() != op.getV2())
841 return false;
842 auto maskAttr = op.getMask();
843 int64_t firstIndex = maskAttr[0];
844 for (int64_t i = 1; i < static_cast<int64_t>(maskAttr.size()); ++i) {
845 int64_t index = maskAttr[i];
846 if (index != firstIndex + i)
847 return false;
848 }
849 return true;
850}
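// E.g. with v1 == v2 the mask [4, 5, 6, 7] extracts a contiguous 4-element
// slice starting at element 4, whereas [0, 2, 4, 6] does not qualify.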
851
852// The input vector of a shuffle vector op that extracts a contiguous slice is
853// an illegal vector in a SPIR-V kernel if it has more than 16 elements.
854// To legalize this case, keep applying the following transformations until no
855// more match:
856// 1. hoist the shuffle vector op past unary element-wise operations
857//    (start with fpext, fptrunc and bitcast for now),
858// 2. merge it with another shuffle vector op,
859// 3. merge it with a load, turning it into a smaller load.
860class HandleVectorExtractPattern
861 : public OpRewritePattern<LLVM::ShuffleVectorOp> {
862 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
863
864 void initialize() { setHasBoundedRewriteRecursion(); }
865
866 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
867 PatternRewriter &rewriter) const override {
868
869 if (!isExtractingContiguousSlice(op))
870 return failure();
871
872 auto mask = op.getMask();
873 auto loc = op.getLoc();
874 auto ty = op.getType();
875 // Check source operand to determine rewrite pattern.
876 auto src = op.getV1();
877 // 1. Hoist past unary element-wise operations
878 if (auto srcOp = src.getDefiningOp()) {
879 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
880 Value srcInput = srcOp->getOperand(0);
881 // Create new shuffle vector op with unary input as source.
882 auto srcVecTy = dyn_cast<VectorType>(srcInput.getType());
883 auto newShuffleVecTy =
884 VectorType::get(mask.size(), srcVecTy.getElementType());
885 auto newShuffle = LLVM::ShuffleVectorOp::create(
886 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
887 // Create new unary op with new shuffle as input.
888 Value newUnaryOp;
889 if (isa<LLVM::FPExtOp>(srcOp)) {
890 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
891 } else {
892 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
893 }
894 rewriter.replaceOp(op, newUnaryOp);
895 } else if (isa<LLVM::BitcastOp>(srcOp)) {
896 Value srcInput = srcOp->getOperand(0);
897 // Create new shuffle vector op with unary input as source.
898 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.getType());
899 auto srcInputSize = srcInputVecTy.getNumElements();
900 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
901 auto srcResSize = srcResVecTy.getNumElements();
902 auto maskSize = static_cast<int32_t>(mask.size());
903 if (srcInputSize > srcResSize) {
904 return failure();
905 }
906 if (srcResSize % srcInputSize != 0) {
907 return failure();
908 }
909 auto maskScale = srcResSize / srcInputSize;
910 if (maskScale != 1) {
911 if (mask[0] % maskScale != 0) {
912 return failure();
913 }
914 // Create a new mask that maps to the source vector
915 SmallVector<int32_t> newMask;
916 int32_t newMaskSize = maskSize / maskScale;
917 int32_t maskStart = mask[0] / maskScale;
918 for (int32_t i = 0; i < newMaskSize; ++i) {
919 newMask.push_back(maskStart + i);
920 }
921 mask = newMask;
922 }
923 auto newShuffleVecTy =
924 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
925 auto newShuffle = LLVM::ShuffleVectorOp::create(
926 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
927 // Create new unary op with new shuffle as input.
928 auto newBitcast =
929 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
930 rewriter.replaceOp(op, newBitcast);
931 } else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
932 // 2. Merge with another shuffle vector op
933 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
934 auto srcMask = srcShuffle.getMask();
935 SmallVector<int32_t> combinedMask;
936 for (auto index : mask) {
937 combinedMask.push_back(srcMask[index]);
938 }
939 auto newShuffle = LLVM::ShuffleVectorOp::create(
940 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
941 DenseI32ArrayAttr::get(rewriter.getContext(), combinedMask));
942 rewriter.replaceOp(op, newShuffle);
943 } else if (isa<LLVM::LoadOp>(srcOp)) {
944 // 3. Merge with load as a smaller load
945 auto loadOp = cast<LLVM::LoadOp>(srcOp);
946 auto loadPtr = loadOp.getAddr();
947 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
948 auto elemTy = loadTy.getElementType();
949 auto firstIndex = mask[0];
950 auto newVecTy = VectorType::get(mask.size(), elemTy);
951 // GEPOp is needed if first index is not zero
952 if (firstIndex) {
953 auto newPtr = LLVM::GEPOp::create(
954 rewriter, loc,
955 LLVM::LLVMPointerType::get(rewriter.getContext(),
956 loadPtr.getType().getAddressSpace()),
957 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
958 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
959 rewriter.replaceOp(op, newLoad);
960 } else {
961 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
962 rewriter.replaceOp(op, newLoad);
963 }
964 } else {
965 return failure();
966 }
967 }
968 return success();
969 }
970};
971
972//===----------------------------------------------------------------------===//
973// Pass Definition
974//===----------------------------------------------------------------------===//
975
976struct ConvertXeVMToLLVMPass
977 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
978 using Base::Base;
979
980 void getDependentDialects(DialectRegistry &registry) const override {
981 registry.insert<LLVM::LLVMDialect, XeVMDialect>();
982 }
983
984 void runOnOperation() override {
985 ConversionTarget target(getContext());
986 RewritePatternSet patterns(&getContext());
987 populateXeVMToLLVMConversionPatterns(target, patterns);
988 if (failed(applyPartialConversion(getOperation(), target,
989 std::move(patterns))))
990 signalPassFailure();
991
992 // Apply in-dialect lowerings to handle illegal vectors
993 {
994 RewritePatternSet vectorPatterns(&getContext());
995 vectorPatterns.add<HandleVectorExtractPattern>(&getContext());
996 GreedyRewriteConfig config{};
997 // Folding can remove ops carrying the temporary attributes used to
998 // represent LLVM metadata, so disable it here.
999 // Effectively only this single pattern is applied, without any folding
1000 // patterns from other dialects.
1001 config.enableFolding(false);
1002 // config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
1003 // config.setMaxNumRewrites(GreedyRewriteConfig::kNoLimit);
1004 (void)applyPatternsGreedily(getOperation(), std::move(vectorPatterns),
1005 config);
1006 }
1007 }
1008};
1009} // namespace
1010
1011//===----------------------------------------------------------------------===//
1012// Pattern Population
1013//===----------------------------------------------------------------------===//
1014
1015void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
1016 RewritePatternSet &patterns) {
1017 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
1018 [](Operation *op) { return !op->hasAttr("cache_control"); });
1019 target.addIllegalDialect<XeVMDialect>();
1020 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1021 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1022 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1023 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1024 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1025 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1026 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1027 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1028 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1029 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1030 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1031 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1032 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1033 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1034 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1035 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1036 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1037 LaunchConfigOpToOCLPattern<GridDimXOp>,
1038 LaunchConfigOpToOCLPattern<GridDimYOp>,
1039 LaunchConfigOpToOCLPattern<GridDimZOp>,
1040 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1041 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1042 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
1043 patterns.getContext());
1044}